diff --git a/src/include/dynamic_operation.cc.h b/src/include/dynamic_operation.cc.h index 020ad4e..056628d 100644 --- a/src/include/dynamic_operation.cc.h +++ b/src/include/dynamic_operation.cc.h @@ -4,7 +4,7 @@ namespace MultiArrayTools { template - const T& DynamicOperation::get(const DExtT& pos) const + T DynamicOperation::get(const DExtT& pos) const { return mOp.get(pos.expl()); } diff --git a/src/include/dynamic_operation.h b/src/include/dynamic_operation.h index 4982333..704b3a6 100644 --- a/src/include/dynamic_operation.h +++ b/src/include/dynamic_operation.h @@ -24,7 +24,7 @@ namespace MultiArrayTools DynamicOperationBase& operator=(const DynamicOperationBase& in) = default; DynamicOperationBase& operator=(DynamicOperationBase&& in) = default; - virtual const T& get(const DExtT& pos) const = 0; + virtual T get(const DExtT& pos) const = 0; virtual DynamicOperationBase& set(const DExtT& pos) = 0; virtual DExtT rootSteps(std::intptr_t iPtrNum = 0) const = 0; virtual DynamicExpression loop(const DynamicExpression& exp) const = 0; @@ -50,7 +50,7 @@ namespace MultiArrayTools DynamicOperation(const Operation& op) : mOp(op) {} - virtual const T& get(const DExtT& pos) const override final; + virtual T get(const DExtT& pos) const override final; virtual DynamicOperationBase& set(const DExtT& pos) override final; virtual DExtT rootSteps(std::intptr_t iPtrNum = 0) const override final; virtual DynamicExpression loop(const DynamicExpression& exp) const override final; @@ -81,8 +81,9 @@ namespace MultiArrayTools DynamicO(const Op& op) : mOp(std::make_shared>(op)) {} template - inline const T& get(const DExtTX& pos) const { return mOp->get(pos.reduce()); } - inline DynamicO& set(const DExtT& pos) { return mOp->set(pos); } + inline T get(const DExtTX& pos) const { return mOp->get(pos.reduce()); } + template + inline DynamicO& set(const DExtTX& pos) { mOp->set(pos.reduce()); return *this; } inline DExtT rootSteps(std::intptr_t iPtrNum = 0) const { return mOp->rootSteps(iPtrNum); } inline DynamicExpression loop(const DynamicExpression& exp) const { return mOp->loop(exp); } inline const T* data() const { return mOp->data(); } diff --git a/src/include/multi_array_operation.cc.h b/src/include/multi_array_operation.cc.h index 5e1199c..337ce0f 100644 --- a/src/include/multi_array_operation.cc.h +++ b/src/include/multi_array_operation.cc.h @@ -753,7 +753,8 @@ namespace MultiArrayTools template inline Operation& Operation::set(ET pos) { - PackNum::setOpPos(mOps,pos); + typedef std::tuple OpTuple; + PackNum::template setOpPos(mOps,pos); return *this; } diff --git a/src/include/multi_array_operation.h b/src/include/multi_array_operation.h index 5ba3b1d..2301996 100644 --- a/src/include/multi_array_operation.h +++ b/src/include/multi_array_operation.h @@ -529,7 +529,8 @@ namespace MultiArrayTools template auto loop(Expr exp) const -> decltype(PackNum::mkLoop( mOps, exp)); - + + T* data() const { assert(0); return nullptr; } }; namespace diff --git a/src/include/pack_num.h b/src/include/pack_num.h index 2797243..aa7f60a 100644 --- a/src/include/pack_num.h +++ b/src/include/pack_num.h @@ -122,7 +122,7 @@ namespace MultiArrayHelper static_assert(LAST > NextOpType::SIZE, "inconsistent array positions"); static constexpr size_t NEXT = LAST - NextOpType::SIZE; std::get( ot ).set( Getter::template getX( et ) ); - PackNum::template setOpPos(ot, et); + PackNum::template setOpPos(ot, et); } }; @@ -205,9 +205,9 @@ namespace MultiArrayHelper template static inline void setOpPos(OpTuple& ot, const ETuple& et) { - typedef typename std::remove_reference(et))>::type NextOpType; - static_assert(LAST > NextOpType::SIZE, "inconsistent array positions"); + typedef typename std::remove_reference(ot))>::type NextOpType; static constexpr size_t NEXT = LAST - NextOpType::SIZE; + static_assert(NEXT == 0, "inconsistent array positions"); std::get<0>( ot ).set( Getter::template getX( et ) ); } diff --git a/src/tests/op4_unit_test.cc b/src/tests/op4_unit_test.cc index 4623943..7c9a973 100644 --- a/src/tests/op4_unit_test.cc +++ b/src/tests/op4_unit_test.cc @@ -104,10 +104,13 @@ namespace (*di4)({imap["i2_1"],imap["i3_1"],imap["i4_1"],imap["i4_2"]}); auto resx1 = res1; + auto resx2 = res1; + auto resx3 = res1; res1(i1,di4) = ma1(i1,di1) * ma2(i1,di2); - resx1(i1,di4) = mkDynOp(ma1(i1,di1)) * mkDynOp(ma2(i1,di2)); - + resx2(i1,di4) = mkDynOp(ma1(i1,di1) * ma2(i1,di2)); + resx3(i1,di4) = mkDynOp(mkDynOp(ma1(i1,di1)) * mkDynOp(ma2(i1,di2))); + auto i2_1 = imap.at("i2_1"); auto i2_2 = imap.at("i2_2"); auto i3_1 = imap.at("i3_1"); @@ -123,10 +126,15 @@ namespace const size_t jr = (((ii1*i2_1->max() + ii2_1)*i3_1->max() + ii3_1)*i4_1->max() + ii4_1)*i4_2->max() + ii4_2; const size_t j1 = (((ii1*i2_1->max() + ii2_1)*i2_2->max() + ii2_2)*i3_1->max() + ii3_1)*i4_1->max() + ii4_1; const size_t j2 = ((ii1*i3_1->max() + ii3_1)*i3_1->max() + ii3_1)*i4_2->max() + ii4_2; - auto resx = xround(res1.vdata()[jr]); - //std::cout << resx << " "; - auto x12 = xround(ma1.vdata()[j1]*ma2.vdata()[j2]); - EXPECT_EQ( resx, x12 ); + auto resv = xround(res1.vdata()[jr]); + auto resx1v = xround(resx1.vdata()[jr]); + auto resx2v = xround(resx2.vdata()[jr]); + auto resx3v = xround(resx3.vdata()[jr]); + auto x12 = xround(ma1.vdata()[j1]*ma2.vdata()[j2]); + EXPECT_EQ( resv, x12 ); + EXPECT_EQ( resx1v, x12 ); + EXPECT_EQ( resx2v, x12 ); + EXPECT_EQ( resx3v, x12 ); } } }