diff --git a/src/base_def.h b/src/base_def.h index cd71a3c..6552a4b 100644 --- a/src/base_def.h +++ b/src/base_def.h @@ -114,7 +114,7 @@ namespace MultiArrayTools class MultiArrayOperation; // multi_array_operation.h - template + template class MultiArrayContraction; // slice.h diff --git a/src/multi_array.h b/src/multi_array.h index c3b0ec9..d0d1242 100644 --- a/src/multi_array.h +++ b/src/multi_array.h @@ -174,6 +174,12 @@ namespace MultiArrayTools virtual bool isConst() const override; + template + ConstMultiArrayOperationRoot operator()(bool x, const NameTypes&... str) const + { + return MAB::operator()(str...); + } + template MultiArrayOperationRoot operator()(const NameTypes&... str); diff --git a/src/multi_array_operation.cc b/src/multi_array_operation.cc index 7df054c..e5f6ef1 100644 --- a/src/multi_array_operation.cc +++ b/src/multi_array_operation.cc @@ -147,17 +147,21 @@ namespace MultiArrayTools MultiArrayOperationRoot& MultiArrayOperationRoot::operator=(const MultiArrayOperation& in) { - //CHECK; - //if(mArrayRef.isSlice() and not mArrayRef.isInit()){ - // NO SLICE CREATION !!! (total array not initialized!!) - // throw ! - // assert(0); - //} performAssignment(in); freeIndex(); return *this; } - + + template + template + MultiArrayOperationRoot& + MultiArrayOperationRoot::operator=(const MultiArrayContraction& in) + { + performAssignment(in); + freeIndex(); + return *this; + } + template template MultiArrayOperation, MAOps...> @@ -293,6 +297,12 @@ namespace MultiArrayTools return 1; } + template + IndefinitIndexBase* MultiArrayOperationRoot::getLinked(const std::string& name) const + { + return mIndex.getLinked(name); + } + template void MultiArrayOperationRoot::linkIndicesTo(IndefinitIndexBase* target) const { @@ -534,6 +544,12 @@ namespace MultiArrayTools return 1; } + template + IndefinitIndexBase* ConstMultiArrayOperationRoot::getLinked(const std::string& name) const + { + return mIndex.getLinked(name); + } + template void ConstMultiArrayOperationRoot::linkIndicesTo(IndefinitIndexBase* target) const { @@ -630,6 +646,37 @@ namespace MultiArrayTools } }; + template + struct LinkedIndexGetter + { + template + static IndefinitIndexBase* getLinked(const Tuple& optuple, + const std::string& name, + IndefinitIndexBase* current) + { + if(current == nullptr){ + current = std::get(optuple).getLinked(name); + LinkedIndexGetter::getLinked(optuple, name, current); + } + return current; + } + }; + + template <> + struct LinkedIndexGetter<0> + { + template + static IndefinitIndexBase* getLinked(const Tuple& optuple, + const std::string& name, + IndefinitIndexBase* current) + { + if(current == nullptr){ + current = std::get<0>(optuple).getLinked(name); + } + return current; + } + }; + template MultiArrayOperation:: MultiArrayOperation(Operation& op, const MAOps&... args) : @@ -662,37 +709,17 @@ namespace MultiArrayTools template - template - MultiArrayContraction > + template + MultiArrayContraction,MAOps2...> MultiArrayOperation:: - contract(const ContractOperation& cop, const std::string& indexName) const + contract(const ContractOperation& cop, + const std::string& indexName, + const MAOps2&... mao) const { -#error "HERE" - typename Range2::IndexType* ind = dynamic_cast( mIndex.getLinked(indexName) ); - //typename Range2::IndexType ind = Range2().begin(); - //ind.name(indexName); - return MultiArrayContraction >(cop, *this, *ind); - + typename Range2::IndexType* ind = dynamic_cast( getLinked(indexName) ); + return MultiArrayContraction, + MAOps2...>(cop, *ind, *this, mao...); } - - /* - template - template - MultiArrayContraction > - MultiArrayOperation:: - contract(const ContractOperation& cop, const std::string& indexName, - const typename Range2::IndexType& begin, - const typename Range2::IndexType& end) const - { - typename Range2::IndexType* ind = dynamic_cast( mIndex.getLinked(indexName) ); - //typename Range2::IndexType ind = Range2().begin(); - //ind.name(indexName); - return MultiArrayContraction >(cop, *this, *ind, begin, end); - - } - */ template template @@ -732,6 +759,12 @@ namespace MultiArrayTools return sizeof...(MAOps) + 1; } + template + IndefinitIndexBase* MultiArrayOperation::getLinked(const std::string& name) const + { + return LinkedIndexGetter::getLinked(mArgs, name, nullptr); + } + template void MultiArrayOperation::linkIndicesTo(IndefinitIndexBase* target) const { @@ -744,19 +777,131 @@ namespace MultiArrayTools { mVal = OperationCall:: template callOperation(mOp, mArgs); + std::cout << mVal << std::endl; return mVal; } + /******************************* * MultiArrayContraction * *******************************/ + template + MultiArrayContraction:: + MultiArrayContraction(const ContractOperation& op, + const typename Range::IndexType& runIndex, + const MAOps&... args) : + mOp(op), + mArgs(std::make_tuple(args...)) {} + + // !!!!! + + template + template + MultiArrayOperation,MAOps2...> + MultiArrayContraction::operator()(Operation2& op, const MAOps2&... secs) const + { + return MultiArrayOperation, + MAOps2...>(op, *this, secs...); + } + + template + template + MultiArrayOperation,MAOps2...> + MultiArrayContraction::operator()(const Operation2& op, const MAOps2&... secs) const + { + return MultiArrayOperation, + MAOps2...>(op, *this, secs...); + } + + + template + template + MultiArrayContraction,MAOps2...> + MultiArrayContraction:: + contract(const ContractOperation2& cop, + const std::string& indexName, + const MAOps2&... mao) const + { + typename Range2::IndexType* ind = dynamic_cast( getLinked(indexName) ); + return MultiArrayContraction, + MAOps2...>(cop, *ind, *this, mao...); + } + + template + template + auto MultiArrayContraction::operator+(const MAOp2& sec) + -> decltype(operator()(std::plus(), sec)) + { + return operator()(std::plus(), sec); + } + + template + template + auto MultiArrayContraction::operator-(const MAOp2& sec) + -> decltype(operator()(std::minus(), sec)) + { + return operator()(std::minus(), sec); + } + + template + template + auto MultiArrayContraction::operator*(const MAOp2& sec) + -> decltype(operator()(std::multiplies(), sec)) + { + return operator()(std::multiplies(), sec); + } + + template + template + auto MultiArrayContraction::operator/(const MAOp2& sec) + -> decltype(operator()(std::divides(), sec)) + { + return operator()(std::divides(), sec); + } + + template + size_t MultiArrayContraction::argNum() const + { + return sizeof...(MAOps) + 1; + } + + template + IndefinitIndexBase* MultiArrayContraction::getLinked(const std::string& name) const + { + return LinkedIndexGetter::getLinked(mArgs, name, nullptr); + } + + template + void MultiArrayContraction::linkIndicesTo(IndefinitIndexBase* target) const + { + TupleIndicesLinker::linkTupleIndicesTo(mArgs, target); + } + + + template + const T& MultiArrayContraction::get() const + { + mOp.reset(); + for(mRunIndex.copyPos( mBeginIndex ); mRunIndex.pos() != mEndIndex.pos(); ++mRunIndex){ + OperationCall:: + template callOperation(mOp, mArgs); + //MAO::mOp(std::get<0>(MAO::mArgs).get() ); + } + mOp.endOp(); + std::cout << MAO::mOp() << std::endl; + return mOp(); + } + + + /* template MultiArrayContraction:: MultiArrayContraction(const ContractOperation& cop, const typename Range::IndexType& runIndex, const MAOps&... mao) : - MultiArrayOperation(cop, mao...), + MultiArrayContraction(cop, mao...), mBeginIndex(runIndex), mEndIndex(runIndex), mRunIndex(runIndex) { @@ -772,7 +917,7 @@ namespace MultiArrayTools size_t begin, size_t end, const MAOps&... mao) : - MultiArrayOperation(cop, mao...), + MultiArrayContraction(cop, mao...), mBeginIndex(runIndex), mEndIndex(runIndex), mRunIndex(runIndex) { @@ -793,6 +938,7 @@ namespace MultiArrayTools //MAO::mOp(std::get<0>(MAO::mArgs).get() ); } MAO::mOp.endOp(); + std::cout << MAO::mOp() << std::endl; return MAO::mOp(); - } + }*/ } diff --git a/src/multi_array_operation.h b/src/multi_array_operation.h index be6e6c6..162bedd 100644 --- a/src/multi_array_operation.h +++ b/src/multi_array_operation.h @@ -22,6 +22,7 @@ namespace MultiArrayTools virtual size_t argNum() const = 0; const IndefinitIndexBase& index() const; + virtual IndefinitIndexBase* getLinked(const std::string& name) const = 0; virtual void linkIndicesTo(IndefinitIndexBase* target) const = 0; virtual const T& get() const = 0; @@ -65,7 +66,11 @@ namespace MultiArrayTools template MultiArrayOperationRoot& operator=(const MultiArrayOperation& in); - + + template + MultiArrayOperationRoot& + operator=(const MultiArrayContraction& in); + //template //MultiArrayOperation, MAOps...> //operator()(Operation& op, const MAOps&... secs) const; @@ -122,7 +127,8 @@ namespace MultiArrayTools // set index -> implement !!!!! MultiArrayOperationRoot& operator[](const IndexType& ind); const MultiArrayOperationRoot& operator[](const IndexType& ind) const; - + + virtual IndefinitIndexBase* getLinked(const std::string& name) const override; virtual void linkIndicesTo(IndefinitIndexBase* target) const override; virtual T& get() override; @@ -207,7 +213,8 @@ namespace MultiArrayTools // set index -> implement !!!!! const ConstMultiArrayOperationRoot& operator[](const IndexType& ind) const; - + + virtual IndefinitIndexBase* getLinked(const std::string& name) const override; virtual void linkIndicesTo(IndefinitIndexBase* target) const override; virtual const T& get() const override; @@ -259,17 +266,11 @@ namespace MultiArrayTools operator()(const Operation2& op, const MAOps2&... secs) const; - template - MultiArrayContraction > - contract(const ContractOperation& cop, const std::string& indexName) const; - - /* - template - MultiArrayContraction > + template + MultiArrayContraction,MAOps2...> contract(const ContractOperation& cop, const std::string& indexName, - const typename Range2::IndexType& begin, - const typename Range2::IndexType& end) const; - */ + const MAOps2&... mao) const; + template auto operator+(const MAOp2& sec) -> decltype(operator()(std::plus(), sec)); @@ -284,7 +285,8 @@ namespace MultiArrayTools auto operator/(const MAOp2& sec) -> decltype(operator()(std::divides(), sec)); virtual size_t argNum() const override; - + + virtual IndefinitIndexBase* getLinked(const std::string& name) const override; virtual void linkIndicesTo(IndefinitIndexBase* target) const override; virtual const T& get() const override; @@ -296,6 +298,64 @@ namespace MultiArrayTools OBT mArgs; // include first arg also here !!! }; + template + class MultiArrayContraction : public MultiArrayOperationBase + { + public: + + typedef MultiArrayOperationBase MAOB; + typedef std::tuple OBT; + + MultiArrayContraction(ContractOperation& op, const MAOps&... secs); + MultiArrayContraction(const ContractOperation& op, const MAOps&... secs); + + template + MultiArrayOperation,MAOps2...> + operator()(Operation2& op, const MAOps2&... secs) const; + + template + MultiArrayOperation,MAOps2...> + operator()(const Operation2& op, const MAOps2&... secs) const; + + + template + MultiArrayContraction,MAOps2...> + contract(const ContractOperation2& cop, const std::string& indexName, + const MAOps2&... mao) const; + + + template + auto operator+(const MAOp2& sec) -> decltype(operator()(std::plus(), sec)); + + template + auto operator-(const MAOp2& sec) -> decltype(operator()(std::minus(), sec)); + + template + auto operator*(const MAOp2& sec) -> decltype(operator()(std::multiplies(), sec)); + + template + auto operator/(const MAOp2& sec) -> decltype(operator()(std::divides(), sec)); + + virtual size_t argNum() const override; + + virtual IndefinitIndexBase* getLinked(const std::string& name) const override; + virtual void linkIndicesTo(IndefinitIndexBase* target) const override; + + virtual const T& get() const override; + + protected: + + mutable T mVal; + ContractOperation mOp; + OBT mArgs; // include first arg also here !!! + typename Range::IndexType mBeginIndex; + typename Range::IndexType mEndIndex; + mutable typename Range::IndexType mRunIndex; + + }; + + /* template class MultiArrayContraction : public MultiArrayOperation { @@ -312,7 +372,7 @@ namespace MultiArrayTools size_t begin, size_t end, const MAOps&... mao); - + virtual const T& get() const override; protected: @@ -320,7 +380,7 @@ namespace MultiArrayTools typename Range::IndexType mEndIndex; mutable typename Range::IndexType mRunIndex; }; - + */ } #include "multi_array_operation.cc" diff --git a/src/unit_test.cc b/src/unit_test.cc index ae64fec..8e2b690 100644 --- a/src/unit_test.cc +++ b/src/unit_test.cc @@ -178,7 +178,8 @@ namespace { rb(r1,r2), r3d(r1,r2,r3), ma(r3d, {-5,6,2,1,9,54,27,-7,-13,32,90,-67, - -10,16,-2,101,39,-64,81,-22,14,34,95,-62}) {} + -10,16,-2,101,39,-64,81,-22,14,34,95,-62}), + max(ra, {-5,6,2,1,9,54}){} Range1dAny r1; Range1dAny r2; @@ -187,8 +188,10 @@ namespace { Range2dAny rb; Range3dAny r3d; MultiArray3dAny ma; + MultiArray2dAny max; }; - + + /* TEST_F(OneDimTest, CorrectExtensions) { EXPECT_EQ(ma.size(), 5); @@ -507,7 +510,7 @@ namespace { EXPECT_EQ(sl[j(j1 = 2, j2 = 0)], 14); EXPECT_EQ(sl[j(j1 = 2, j2 = 1)], 34); } - + */ TEST_F(ContractionTest, ContractionWorks) { MultiArray2dAny ma2(ra); @@ -528,6 +531,27 @@ namespace { EXPECT_EQ(ma2[i(i1 = 2, i2 = 1)], -114); } + TEST_F(ContractionTest, ContractionWorks_2) + { + MultiArray2dAny ma2(ra); + + ma2("alpha","gamma") = ma("alpha","beta","gamma").contract(sum(),"beta") + * max("alpha","gamma"); + + auto i = ma2.beginIndex(); + auto i1 = i.template getIndex<0>(); + auto i2 = i.template getIndex<1>(); + + EXPECT_EQ(ma2[i(i1 = 0, i2 = 0)], -275); + EXPECT_EQ(ma2[i(i1 = 0, i2 = 1)], 324); + + EXPECT_EQ(ma2[i(i1 = 1, i2 = 0)], 130); + EXPECT_EQ(ma2[i(i1 = 1, i2 = 1)], 82); + + EXPECT_EQ(ma2[i(i1 = 2, i2 = 0)], 2061); + EXPECT_EQ(ma2[i(i1 = 2, i2 = 1)], -6156); + } + } // end namespace int main(int argc, char** argv)