diff --git a/src/tests/op4_unit_test.cc b/src/tests/op4_unit_test.cc index c1ab03a..b3f889c 100644 --- a/src/tests/op4_unit_test.cc +++ b/src/tests/op4_unit_test.cc @@ -280,10 +280,7 @@ namespace auto di4 = getIndex(dr4); auto di4a = getIndex(dr4a); - //(*di1)({imap["i2_1"],imap["i2_2"],imap["i3_1"],imap["i4_1"]}); - //(*di2)({imap["i3_1"],imap["i3_1"],imap["i4_2"]}); - //(*di4)({imap["i2_1"],imap["i3_1"],imap["i4_1"],imap["i4_2"]}); - (*di1)({"ia_1","ia_2","ib_1","ic_1"}); + (*di1)({"ia_1","ia_2","ib_1","ic_1"}); (*di2)({"ib_1","ib_1","ic_2"}); (*di4)({"ia_1","ib_1","ic_1","ic_2"}); (*di4a)(svec({"ia_1","ib_1"})); @@ -299,7 +296,6 @@ namespace resx2(i1,di4) = mkDynOp(ma1(i1,di1) * exp(ma2(i1,di2))); resx3(i1,di4) = mkDynOp(mkDynOp(ma1(i1,di1)) * mkDynOp(exp(mkDynOp(ma2(i1,di2))))); - //auto xx = std::make_shared(resx4); auto xx = mkArrayPtr(nullr()); auto mi = mkMIndex(i1,di4a); @@ -307,20 +303,9 @@ namespace auto hop3 = mkHLO(ma1(i1,di1)); auto hop2 = exp(hop1); auto hop4 = hop3 * hop2; - - auto opr = resx4(i1,di4); - auto loop = mkPILoop - ( [&opr,&hop4,&xx,&ic_1,&ic_2,this](){ - auto hop4x = hop4; - auto dop2 = hop4x.create(ic_1,ic_2); - auto gexp = mkDynOp1(mkMOp(dop2.outer,dop2.op)); - auto xloop = mkILoop(std::make_tuple(*dop2.op.data()->mOp), std::make_tuple(ic_1, ic_2), - std::make_tuple(xx), - std::make_tuple(opr.assign( *dop2.op.data()->mOp, mkMIndex(ic_1, ic_2) )), - std::array({1}), std::array({0})); - return mkGetExpr(gexp, xloop); }); - mi->pifor(1,loop)(); - + auto hopr = mkHLO(resx4(i1,di4)); + hopr.assign( hop4, mi, ic_1, ic_2 ); + auto i2_1 = imap.at("i2_1"); auto i2_2 = imap.at("i2_2"); auto i3_1 = imap.at("i3_1"); diff --git a/src/tests/op4_unit_test.h b/src/tests/op4_unit_test.h index 61f7eda..54d8b23 100644 --- a/src/tests/op4_unit_test.h +++ b/src/tests/op4_unit_test.h @@ -13,6 +13,7 @@ DynamicO mkDynOp1(const Op& op) return DynamicO(op); } +template class HighLevelOpBase { public: @@ -66,65 +67,37 @@ public: virtual RetT create(const std::shared_ptr ind1, const std::shared_ptr ind2) = 0; - virtual const OperationRoot* get1() const = 0; - virtual const OperationRoot* get2() const = 0; + //virtual const OperationRoot* get2() const = 0; + virtual const ROP* get() const = 0; }; -template -struct Fwd -{ - template - static inline const O1* fwd(const O2* in) - { - assert(0); - return nullptr; - } -}; - -template <> -struct Fwd -{ - template - static inline const O1* fwd(const O2* in) - { - return in; - } -}; - -template -class HighLevelOpRoot : public HighLevelOpBase +template +class HighLevelOpRoot : public HighLevelOpBase { private: - typedef OperationRoot OType1; - typedef OperationRoot OType2; - typedef HighLevelOpBase B; + typedef HighLevelOpBase B; - OR mOp; + ROP mOp; public: - HighLevelOpRoot(const OR& op) : mOp(op) {} + HighLevelOpRoot(const ROP& op) : mOp(op) {} virtual bool root() const override final { return true; } - virtual B::RetT create(const std::shared_ptr ind1, - const std::shared_ptr ind2) override final + virtual typename B::RetT create(const std::shared_ptr ind1, + const std::shared_ptr ind2) override final { assert(0); - return B::RetT(); + return typename B::RetT(); } - virtual const OType1* get1() const override final + virtual const ROP* get() const override final { - return Fwd::value>::template fwd(&mOp); - } - - virtual const OType2* get2() const override final - { - return Fwd::value>::template fwd(&mOp); + return &mOp; } @@ -143,13 +116,13 @@ struct Create template struct cx { - template + template struct ccx { template static inline void - cccx(HighLevelOpBase::RetT& res, - const std::array,M>& in, + cccx(typename HighLevelOpBase::RetT& res, + const std::array>,M>& in, const std::shared_ptr&... inds, const OPs&... ops, const DOPs&... dops) @@ -161,13 +134,13 @@ struct Create auto op = *dop.op.data()->mOp; typedef decltype(op) OP; res.appendOuter(dop); - Create::template cx::template ccx::template cccx + Create::template cx::template ccx::template cccx (res, in, inds..., op, ops..., dop, dops...); } else { - auto& op = *inn->get2(); + auto& op = *inn->get(); typedef typename std::remove_reference::type OP; - Create::template cx::template ccx::template cccx + Create::template cx::template ccx::template cccx (res, in, inds..., op, ops..., dops...); } } @@ -181,13 +154,13 @@ struct Create<0> template struct cx { - template + template struct ccx { template static inline void - cccx(HighLevelOpBase::RetT& res, - const std::array,M>& in, + cccx(typename HighLevelOpBase::RetT& res, + const std::array>,M>& in, const std::shared_ptr&... inds, const OPs&... ops, const DOPs&... dops) @@ -201,7 +174,7 @@ struct Create<0> res.appendOuterM(dop.op,dops.op...); } else { - auto& op = *inn->get2(); + auto& op = *inn->get(); res.op = mkDynOutOp(mkFOp(op,ops...), inds...); res.appendOuterM(dops.op...); } @@ -210,42 +183,41 @@ struct Create<0> }; }; -template -class HighLevelOp : public HighLevelOpBase +template +class HighLevelOp : public HighLevelOpBase { private: - std::array,N> mIn; + std::array>,N> mIn; public: - typedef HighLevelOpBase B; + typedef HighLevelOpBase B; - HighLevelOp(std::array,N> in) : mIn(in) {} + HighLevelOp(std::array>,N> in) : mIn(in) {} virtual bool root() const override final { return false; } - virtual const OperationRoot* get1() const override final - { assert(0); return nullptr; } - virtual const OperationRoot* get2() const override final + virtual const ROP* get() const override final { assert(0); return nullptr; } - virtual B::RetT create(const std::shared_ptr ind1, + virtual typename B::RetT create(const std::shared_ptr ind1, const std::shared_ptr ind2) override final { - B::RetT res; - Create::template cx::template ccx::template cccx(res,mIn,ind1,ind2); + typename B::RetT res; + Create::template cx::template ccx::template cccx(res,mIn,ind1,ind2); return res; } }; +template class HighLevelOpHolder { private: - std::shared_ptr mOp; + std::shared_ptr> mOp; public: HighLevelOpHolder() = default; @@ -254,34 +226,104 @@ public: HighLevelOpHolder& operator=(const HighLevelOpHolder& in) = default; HighLevelOpHolder& operator=(HighLevelOpHolder&& in) = default; - HighLevelOpHolder(const std::shared_ptr& op) : mOp(op) {} + HighLevelOpHolder(const std::shared_ptr>& op) : mOp(op) {} bool root() const { return mOp->root(); } auto create(const std::shared_ptr ind1, const std::shared_ptr ind2) const { return mOp->create(ind1,ind2); } - auto get1() const { return mOp->get1(); } - auto get2() const { return mOp->get2(); } + auto get() const { return mOp->get(); } - std::shared_ptr op() const { return mOp; } + std::shared_ptr> op() const { return mOp; } - HighLevelOpHolder operator*(const HighLevelOpHolder in) const + HighLevelOpHolder operator*(const HighLevelOpHolder& in) const { - return HighLevelOpHolder - ( std::make_shared,2>> - ( std::array,2>({mOp, in.mOp}) ) ); + return HighLevelOpHolder + ( std::make_shared,2>> + ( std::array>,2>({mOp, in.mOp}) ) ); + } + + HighLevelOpHolder operator+(const HighLevelOpHolder& in) const + { + return HighLevelOpHolder + ( std::make_shared,2>> + ( std::array>,2>({mOp, in.mOp}) ) ); + } + + HighLevelOpHolder operator-(const HighLevelOpHolder& in) const + { + return HighLevelOpHolder + ( std::make_shared,2>> + ( std::array>,2>({mOp, in.mOp}) ) ); + } + + HighLevelOpHolder operator/(const HighLevelOpHolder& in) const + { + return HighLevelOpHolder + ( std::make_shared,2>> + ( std::array>,2>({mOp, in.mOp}) ) ); + } + + template + HighLevelOpHolder& assign(const HighLevelOpHolder& in, + const std::shared_ptr& mi, + const std::shared_ptr&... inds) + { + auto xx = mkArrayPtr(nullr()); + auto& opr = *mOp->get(); + auto loop = mkPILoop + ( [&opr,&in,&xx,&inds...,this](){ + auto inx = in; + auto dop = inx.create(inds...); + auto gexp = mkDynOp1(mkMOp(dop.outer,dop.op)); + auto xloop = mkILoop(std::make_tuple(*dop.op.data()->mOp), + std::make_tuple(inds...), + std::make_tuple(xx), + std::make_tuple(opr.assign( *dop.op.data()->mOp, + mkMIndex(inds...) )), + std::array({1}), std::array({0})); + return mkGetExpr(gexp, xloop); }); + mi->pifor(1,loop)(); + return *this; + } + + template + HighLevelOpHolder& plus(const HighLevelOpHolder& in, + const std::shared_ptr& mi, + const std::shared_ptr&... inds) + { + auto xx = mkArrayPtr(nullr()); + auto& opr = *mOp->get(); + auto loop = mkPILoop + ( [&opr,&in,&xx,&inds...,this](){ + auto inx = in; + auto dop = inx.create(inds...); + auto gexp = mkDynOp1(mkMOp(dop.outer,dop.op)); + auto xloop = mkILoop(std::make_tuple(*dop.op.data()->mOp), + std::make_tuple(inds...), + std::make_tuple(xx), + std::make_tuple(opr.plus( *dop.op.data()->mOp, + mkMIndex(inds...) )), + std::array({1}), std::array({0})); + return mkGetExpr(gexp, xloop); }); + mi->pifor(1,loop)(); + return *this; } }; -template -HighLevelOpHolder mkHLO(const OR& op) +typedef OperationRoot OpType1; + +template +HighLevelOpHolder mkHLO(const ROP& op) { - return HighLevelOpHolder(std::make_shared>( op ) ); + return HighLevelOpHolder(std::make_shared>( op ) ); } -HighLevelOpHolder exp(const HighLevelOpHolder& in) +template +HighLevelOpHolder exp(const HighLevelOpHolder& in) { - return HighLevelOpHolder( std::make_shared,1>> - ( std::array,1>( {in.op()} ) ) ); + return HighLevelOpHolder + ( std::make_shared,1>> + ( std::array>,1>( {in.op()} ) ) ); }