diff --git a/src/base_def.h b/src/base_def.h index 321b288..6409a6e 100644 --- a/src/base_def.h +++ b/src/base_def.h @@ -113,6 +113,10 @@ namespace MultiArrayTools template class MultiArrayOperation; + // multi_array_operation.h + template + class MultiArrayContraction; + // slice.h template class Slice; diff --git a/src/multi_array_operation.cc b/src/multi_array_operation.cc index 4f296db..113605a 100644 --- a/src/multi_array_operation.cc +++ b/src/multi_array_operation.cc @@ -34,7 +34,7 @@ namespace MultiArrayTools template void MultiArrayOperationRoot::performAssignment(const MultiArrayOperationBase& in) { -#error "WRITE MAOR INTRINSIC CONTRACT FUNCTION" + //#error "WRITE MAOR INTRINSIC CONTRACT FUNCTION" //CHECK; in.linkIndicesTo(MAOB::mIibPtr); //CHECK; @@ -42,6 +42,7 @@ namespace MultiArrayTools //CHECK; const size_t endPos = mArrayRef.endIndex().pos(); std::cout << "assignment: " << endPos << " elements" << std::endl; + // assignment loop for(iref = mArrayRef.beginIndex().pos(); iref != mArrayRef.endIndex(); ++iref){ std::cout << iref.pos() << '\r' << std::flush; get() = in.get(); @@ -50,6 +51,7 @@ namespace MultiArrayTools MAOB::mIibPtr->freeLinked(); } + /* template template MultiArrayOperationRoot& @@ -70,7 +72,7 @@ namespace MultiArrayTools sl.setConst(in.mArrayRef, name(), dynamic_cast( in.index() ), in.name()); return *this; } - + */ // CONST SLICE !!!!! @@ -102,10 +104,9 @@ namespace MultiArrayTools MultiArrayOperationRoot::operator=(MultiArrayOperationRoot& in) { //CHECK; - maketurnSlice(in); - if(mArrayRef.isSlice() and not mArrayRef.isInit()){ - return makeSlice(in); - } + // if(mArrayRef.isSlice() and not mArrayRef.isInit()){ + // return makeSlice(in); + //} performAssignment(in); freeIndex(); return *this; @@ -117,9 +118,9 @@ namespace MultiArrayTools MultiArrayOperationRoot::operator=(MultiArrayOperationRoot& in) { //CHECK; - if(mArrayRef.isSlice() and not mArrayRef.isInit()){ - return makeSlice(in); - } + //if(mArrayRef.isSlice() and not mArrayRef.isInit()){ + // return makeSlice(in); + //} performAssignment(in); freeIndex(); return *this; @@ -132,10 +133,10 @@ namespace MultiArrayTools MultiArrayOperationRoot::operator=(const MultiArrayOperationRoot& in) { //CHECK; - if(mArrayRef.isSlice() and not mArrayRef.isInit()){ + //if(mArrayRef.isSlice() and not mArrayRef.isInit()){ //CHECK; - return makeConstSlice(in); - } + // return makeConstSlice(in); + //} performAssignment(in); freeIndex(); return *this; @@ -147,11 +148,11 @@ namespace MultiArrayTools MultiArrayOperationRoot::operator=(const MultiArrayOperation& in) { //CHECK; - if(mArrayRef.isSlice() and not mArrayRef.isInit()){ + //if(mArrayRef.isSlice() and not mArrayRef.isInit()){ // NO SLICE CREATION !!! (total array not initialized!!) // throw ! - assert(0); - } + // assert(0); + //} performAssignment(in); freeIndex(); return *this; @@ -166,6 +167,31 @@ namespace MultiArrayTools return MultiArrayOperation, MAOps...>(op, *this, secs...); } + template + template + MultiArrayContraction > + MultiArrayOperationRoot::contract(const ContractOperation& cop, + const std::string& indexName) const + { + typename Range2::IndexType* ind = dynamic_cast( mIndex.getLinked(indexName) ); + return MultiArrayContraction >(cop, *this, *ind); + } + + template + template + MultiArrayContraction > + MultiArrayOperationRoot::contract(const ContractOperation& cop, + const std::string& indexName, + const typename Range2::IndexType& begin, + const typename Range2::IndexType& end) const + { + typename Range2::IndexType* ind = dynamic_cast( mIndex.getLinked(indexName) ); + return MultiArrayContraction >(cop, *this, *ind, begin, end); + } + + template template auto MultiArrayOperationRoot::operator+(const MAOp& sec) @@ -411,6 +437,31 @@ namespace MultiArrayTools return MultiArrayOperation, MAOps...>(op, *this, secs...); } + template + template + MultiArrayContraction > + ConstMultiArrayOperationRoot::contract(const ContractOperation& cop, + const std::string& indexName) const + { + typename Range2::IndexType* ind = dynamic_cast( mIndex.getLinked(indexName) ); + return MultiArrayContraction >(cop, *this, *ind); + } + + template + template + MultiArrayContraction > + ConstMultiArrayOperationRoot::contract(const ContractOperation& cop, + const std::string& indexName, + const typename Range2::IndexType& begin, + const typename Range2::IndexType& end) const + { + typename Range2::IndexType* ind = dynamic_cast( mIndex.getLinked(indexName) ); + return MultiArrayContraction >(cop, *this, *ind, begin, end); + } + + template template auto ConstMultiArrayOperationRoot::operator+(const MAOp& sec) const @@ -639,20 +690,55 @@ namespace MultiArrayTools TupleIndicesLinker::linkTupleIndicesTo(mArgs, target); } - /* - template - T& MultiArrayOperation::get() - { - mVal = OperationCall:: - template callOperation(mOp, mArgs); - return mVal; - }*/ - - template + + template const T& MultiArrayOperation::get() const { mVal = OperationCall:: template callOperation(mOp, mArgs); return mVal; } + + /******************************* + * MultiArrayContraction * + *******************************/ + + template + MultiArrayContraction:: + MultiArrayContraction(const ContractOperation& cop, const MAOp& mao, + const typename Range::IndexType& runIndex) : + MultiArrayOperation(cop, mao), + mBeginIndex(runIndex), mEndIndex(runIndex), + mRunIndex(runIndex) + { + mBeginIndex.setPos(0); + mEndIndex.setPos(mRunIndex.max()); + MAO::linkIndicesTo(&mRunIndex); + } + + template + MultiArrayContraction:: + MultiArrayContraction(const ContractOperation& cop, const MAOp& mao, + const typename Range::IndexType& runIndex, + const typename Range::IndexType& beginIndex, + const typename Range::IndexType& endIndex) : + MultiArrayOperation(cop, mao), + mBeginIndex(beginIndex), mEndIndex(endIndex), + mRunIndex(runIndex) + { + MAO::linkIndicesTo(&mRunIndex); + } + + + // for the moment simplest case only: + template + const T& MultiArrayContraction::get() const + { + MAO::mOp.reset(); + for(mRunIndex.copyPos( mBeginIndex ); mRunIndex.pos() != mEndIndex.pos(); ++mRunIndex){ + MAO::mOp(std::get<0>(MAO::mArgs).get() ); + } + MAO::mOp.endOp(MAO::mVal); + return MAO::mOp(); + } } diff --git a/src/multi_array_operation.h b/src/multi_array_operation.h index 64c7ecc..def9563 100644 --- a/src/multi_array_operation.h +++ b/src/multi_array_operation.h @@ -73,6 +73,17 @@ namespace MultiArrayTools template MultiArrayOperation, MAOps...> operator()(const Operation& op, const MAOps&... secs) const; + + template < class Range2, class ContractOperation> + MultiArrayContraction > + contract(const ContractOperation& cop, const std::string& indexName) const; + + template + MultiArrayContraction > + contract(const ContractOperation& cop, const std::string& indexName, + const typename Range2::IndexType& begin, + const typename Range2::IndexType& end) const; + template auto operator+(const MAOp& sec) -> decltype(operator()(std::plus(), sec)); @@ -131,13 +142,13 @@ namespace MultiArrayTools void performAssignment(const MultiArrayOperationBase& in); + /* template MultiArrayOperationRoot& makeSlice(MultiArrayOperationRoot& in); - template const MultiArrayOperationRoot& makeConstSlice(const MultiArrayOperationRoot& in); - + */ MutableMultiArrayBase& mArrayRef; mutable IndexType mIndex; Name mNm; @@ -155,23 +166,20 @@ namespace MultiArrayTools ConstMultiArrayOperationRoot(const MultiArrayBase& ma, const Name& nm); ConstMultiArrayOperationRoot(const MultiArrayOperationRoot& in); - /* - const ConstMultiArrayOperationRoot& operator=(const ConstMultiArrayOperationRoot& in); - - template - const ConstMultiArrayOperationRoot& operator=(const ConstMultiArrayOperationRoot& in); - - template - const ConstMultiArrayOperationRoot& operator=(const MultiArrayOperationRoot& in); - */ - - //template - //MultiArrayOperation, MAOps...> - //operator()(Operation& op, const MAOps&... secs) const; - template MultiArrayOperation, MAOps...> operator()(const Operation& op, const MAOps&... secs) const; + + template + MultiArrayContraction > + contract(const ContractOperation& cop, const std::string& indexName) const; + + + template + MultiArrayContraction > + contract(const ContractOperation& cop, const std::string& indexName, + const typename Range2::IndexType& begin, + const typename Range2::IndexType& end) const; template auto operator+(const MAOp& sec) const -> decltype(operator()(std::plus(), sec)); @@ -262,7 +270,6 @@ namespace MultiArrayTools virtual void linkIndicesTo(IndefinitIndexBase* target) const override; - //virtual T& get() override; virtual const T& get() const override; protected: @@ -272,8 +279,30 @@ namespace MultiArrayTools OBT mArgs; // include first arg also here !!! }; + template + class MultiArrayContraction : public MultiArrayOperation + { + public: + typedef MultiArrayOperationBase MAOB; + typedef MultiArrayOperation MAO; + MultiArrayContraction(const ContractOperation& cop, const MAOp& mao, + const typename Range::IndexType& runIndex); + + MultiArrayContraction(const ContractOperation& cop, const MAOp& mao, + const typename Range::IndexType& runIndex, + const typename Range::IndexType& beginIndex, + const typename Range::IndexType& endIndex); + + virtual const T& get() const override; + + protected: + typename Range::IndexType mBeginIndex; + typename Range::IndexType mEndIndex; + mutable typename Range::IndexType mRunIndex; + }; + } #include "multi_array_operation.cc" diff --git a/src/unit_test.cc b/src/unit_test.cc index 769e71a..f1c7298 100644 --- a/src/unit_test.cc +++ b/src/unit_test.cc @@ -9,6 +9,33 @@ namespace MAT = MultiArrayTools; namespace { + + template + struct sum + { + public: + sum() = default; + + T& operator()() const + { + return res; + } + + T& operator()(const T& a) const + { + return res += a; + } + + void endOp(T& res) const {} + + void reset() const + { + res = static_cast(0); + } + + private: + mutable T res = static_cast(0); + }; class OneDimTest : public ::testing::Test { @@ -135,6 +162,32 @@ namespace { MultiArray3dAny ma; //Slice2d3dAny sl; }; + + class ContractionTest : public ::testing::Test + { + protected: + typedef MAT::SingleRange Range1dAny; + typedef MAT::MultiRange Range2dAny; + typedef MAT::MultiRange Range3dAny; + typedef MAT::MultiArray MultiArray3dAny; + typedef MAT::Slice Slice2d3dAny; + typedef MAT::MultiArray MultiArray2dAny; + + ContractionTest() : r1({'a','b','c'}), r2({'a','b','c','d'}), r3({'a','b'}), + ra(r1,r3), + rb(r1,r2), + r3d(r1,r2,r3), + ma(r3d, {-5,6,2,1,9,54,27,-7,-13,32,90,-67, + -10,16,-2,101,39,-64,81,-22,14,34,95,-62}) {} + + Range1dAny r1; + Range1dAny r2; + Range1dAny r3; + Range2dAny ra; + Range2dAny rb; + Range3dAny r3d; + MultiArray3dAny ma; + }; TEST_F(OneDimTest, CorrectExtensions) { @@ -454,6 +507,26 @@ namespace { EXPECT_EQ(sl[j(j1 = 2, j2 = 0)], 14); EXPECT_EQ(sl[j(j1 = 2, j2 = 1)], 34); } + + TEST_F(ContractionTest, ContractionWorks) + { + MultiArray2dAny ma2(ra); + + ma2("alpha","gamma") = ma("alpha","beta","gamma").contract(sum(),"beta"); + + auto i = ma2.beginIndex(); + auto i1 = i.template getIndex<0>(); + auto i2 = i.template getIndex<1>(); + + EXPECT_EQ(ma2[i(i1 = 0, i2 = 0)], 33); + EXPECT_EQ(ma2[i(i1 = 0, i2 = 1)], 54); + + EXPECT_EQ(ma2[i(i1 = 1, i2 = 0)], 65); + EXPECT_EQ(ma2[i(i1 = 1, i2 = 1)], 82); + + EXPECT_EQ(ma2[i(i1 = 2, i2 = 0)], 229); + EXPECT_EQ(ma2[i(i1 = 2, i2 = 1)], -114); + } } // end namespace