diff --git a/src/include/high_level_operation.cc.h b/src/include/high_level_operation.cc.h index 02dc46d..db97d29 100644 --- a/src/include/high_level_operation.cc.h +++ b/src/include/high_level_operation.cc.h @@ -74,6 +74,44 @@ namespace MultiArrayTools return &mOp; } + template + auto HighLevelOpRoot::vget() + -> VOP* + { + return nullptr; + } + + template + HighLevelOpValue::HighLevelOpValue(const VOP& op) : mOp(op) {} + + template + bool HighLevelOpValue::root() const + { + return true; + } + + template + template + auto HighLevelOpValue::xcreate(const std::shared_ptr&... inds) + -> typename B::template RetT + { + assert(0); + return typename B::template RetT(); + } + + template + ROP* HighLevelOpValue::get() + { + return nullptr; + } + + template + auto HighLevelOpValue::vget() + -> VOP* + { + return &mOp; + } + namespace { template @@ -105,10 +143,18 @@ namespace MultiArrayTools (res, in, inds..., op, ops..., dop, dops...); } else { - auto& op = *inn->get(); - typedef typename std::remove_reference::type OP; - Create::template cx::template ccx::template cccx - (res, in, inds..., op, ops..., dops...); + auto op = inn->get(); + auto vop = inn->vget(); + typedef typename std::remove_reference::type OP; + typedef typename std::remove_reference::type VOP; + if(op != nullptr){ + Create::template cx::template ccx::template cccx + (res, in, inds..., *op, ops..., dops...); + } + else { + Create::template cx::template ccx::template cccx + (res, in, inds..., *vop, ops..., dops...); + } } } }; @@ -142,8 +188,14 @@ namespace MultiArrayTools res.appendOuterM(dop.op,dops.op...); } else { - auto& op = *inn->get(); - res.op = mkDynOutOp(mkFOp(op,ops...), inds...); + auto op = inn->get(); + auto vop = inn->vget(); + if(op != nullptr){ + res.op = mkDynOutOp(mkFOp(*op,ops...), inds...); + } + else { + res.op = mkDynOutOp(mkFOp(*vop,ops...), inds...); + } res.appendOuterM(dops.op...); } } @@ -167,6 +219,14 @@ namespace MultiArrayTools assert(0); return nullptr; } + + template + auto HighLevelOp::vget() + -> VOP* + { + assert(0); + return nullptr; + } template template @@ -394,6 +454,14 @@ namespace MultiArrayTools return HighLevelOpHolder(std::make_shared>( op ) ); } + template + HighLevelOpHolder mkHLOV(double val) + { + return HighLevelOpHolder(std::make_shared> + ( OperationValue(val) ) ); + } + + #define SP " " #define regFunc1(fff) template \ HighLevelOpHolder hl_##fff (const HighLevelOpHolder& in) \ @@ -404,116 +472,5 @@ namespace MultiArrayTools #undef regFunc1 #undef SP - /* - template - template - inline void SetLInds::mkLIT(const ITuple& itp, const std::shared_ptr& di) - { - constexpr size_t NN = std::tuple_size::value-N-1; - const size_t nn = di->dim()-N-1; - typedef typename std::remove_reference(itp))>::type T; - std::get(itp) = - std::dynamic_pointer_cast(di->get(nn))->getIndex(); - SetLInds::mkLIT(itp, di); - } - - template - template - template - inline void SetLInds::xx:: - assign(Tar& tar, const Args&... args, const ITp& itp, const std::shared_ptr&... is) - { - SetLInds::template xx::assign(tar, args..., itp, std::get(itp), is...); - } - - template - template - template - inline void SetLInds::xx:: - plus(Tar& tar, const Args&... args, const ITp& itp, const std::shared_ptr&... is) - { - SetLInds::template xx::plus(tar, args..., itp, std::get(itp), is...); - } - - //template <> - template - inline void SetLInds<0>::mkLIT(const ITuple& itp, const std::shared_ptr& di) - { - constexpr size_t NN = std::tuple_size::value-1; - const size_t nn = di->dim()-1; - typedef typename std::remove_reference(itp))>::type T; - std::get(itp) = - std::dynamic_pointer_cast(di->get(nn))->getIndex(); - } - - //template <> - template - template - inline void SetLInds<0>::xx:: - assign(Tar& tar, const Args&... args, const ITp& itp, const std::shared_ptr&... is) - { - tar.assign(args..., std::get<0>(itp), is...); - } - - //template <> - template - template - inline void SetLInds<0>::xx:: - plus(Tar& tar, const Args&... args, const ITp& itp, const std::shared_ptr&... is) - { - tar.plus(args..., std::get<0>(itp), is...); - } - - template - size_t INDS::CallHLOpBase::depth() const - { - return mDepth; - } - - - template - template - void INDS::CallHLOp:: - assign(HighLevelOpHolder& target, const HighLevelOpHolder& source, - const std::shared_ptr&... is, - const std::shared_ptr& di) const - { - auto ip = di->get(di->dim() - this->depth()); - auto iregn = ip->regN(); - if(iregn.type >= 0 and iregn.depth > sizeof...(LIndices)){ - sNext[iregn.type]->assign(target, source, is..., di); - } - else { - ITuple itp; - SetLInds::mkLIT(itp,di); - auto mi = mkIndex(is...,mkSubSpaceX(di, di->dim() - this->depth())); - SetLInds:: - template xx,ITuple,HighLevelOpHolder,decltype(mi)>:: - assign(target, source, mi, itp); - } - } - - template - template - void INDS::CallHLOp:: - plus(HighLevelOpHolder& target, const HighLevelOpHolder& source, - const std::shared_ptr&... is, - const std::shared_ptr& di) const - { - auto ip = di->get(di->dim() - this->depth()); - auto iregn = ip->regN(); - if(iregn.type >= 0 and iregn.depth > sizeof...(LIndices)){ - sNext[iregn.type]->plus(target, source, is..., di); - } - else { - ITuple itp; - SetLInds::mkLIT(itp,di); - auto mi = mkIndex(is...,mkSubSpaceX(di, di->dim() - this->depth())); - SetLInds:: - template xx,ITuple,HighLevelOpHolder,decltype(mi)>:: - plus(target, source, mi, itp); - } - } - */ } diff --git a/src/include/high_level_operation.h b/src/include/high_level_operation.h index 68a823a..fa1ffc1 100644 --- a/src/include/high_level_operation.h +++ b/src/include/high_level_operation.h @@ -27,7 +27,9 @@ namespace MultiArrayTools class HighLevelOpBase { public: - + + typedef OperationValue VOP; + template struct RetT { @@ -59,7 +61,8 @@ namespace MultiArrayTools #undef reg_ind3 virtual ROP* get() = 0; - + virtual VOP* vget() = 0; + }; template @@ -67,7 +70,8 @@ namespace MultiArrayTools { private: typedef HighLevelOpBase B; - + typedef typename B::VOP VOP; + template typename B::template RetT xcreate(const std::shared_ptr&... inds); @@ -91,7 +95,7 @@ namespace MultiArrayTools #include "hl_reg_ind.h" virtual ROP* get() override final; - + virtual VOP* vget() override final; }; @@ -99,6 +103,30 @@ namespace MultiArrayTools extern template class HighLevelOpBase; extern template class HighLevelOpRoot; extern template class HighLevelOpRoot; + + template + class HighLevelOpValue : public HighLevelOpBase + { + private: + typedef HighLevelOpBase B; + typedef typename B::VOP VOP; + + template + typename B::template RetT xcreate(const std::shared_ptr&... inds); + + VOP mOp; + public: + + HighLevelOpValue(const VOP& vop); + + virtual bool root() const override final; + +#include "hl_reg_ind.h" + + virtual ROP* get() override final; + virtual VOP* vget() override final; + + }; template auto mkFOp(const Ops&... ops) @@ -113,6 +141,7 @@ namespace MultiArrayTools { public: typedef HighLevelOpBase B; + typedef typename B::VOP VOP; private: std::array>,N> mIn; @@ -127,6 +156,7 @@ namespace MultiArrayTools virtual bool root() const override final; virtual ROP* get() override final; + virtual VOP* vget() override final; #include "hl_reg_ind.h" @@ -211,6 +241,14 @@ namespace MultiArrayTools template HighLevelOpHolder mkHLO(const ROP& op); + template + HighLevelOpHolder mkHLOV(double val); + + extern template HighLevelOpHolder mkHLO(const OpCD& op); + extern template HighLevelOpHolder mkHLO(const OpD& op); + extern template HighLevelOpHolder mkHLOV(double val); + extern template HighLevelOpHolder mkHLOV(double val); + #define regFunc1(fff) template \ HighLevelOpHolder hl_##fff (const HighLevelOpHolder& in); #include "extensions/math.h" diff --git a/src/include/pack_num.h b/src/include/pack_num.h index aa7f60a..5a185c5 100644 --- a/src/include/pack_num.h +++ b/src/include/pack_num.h @@ -119,7 +119,7 @@ namespace MultiArrayHelper static inline void setOpPos(OpTuple& ot, const ETuple& et) { typedef typename std::remove_reference(ot))>::type NextOpType; - static_assert(LAST > NextOpType::SIZE, "inconsistent array positions"); + static_assert(LAST >= NextOpType::SIZE, "inconsistent array positions"); static constexpr size_t NEXT = LAST - NextOpType::SIZE; std::get( ot ).set( Getter::template getX( et ) ); PackNum::template setOpPos(ot, et); diff --git a/src/lib/high_level_operation.cc b/src/lib/high_level_operation.cc index 50a0299..6f9f706 100644 --- a/src/lib/high_level_operation.cc +++ b/src/lib/high_level_operation.cc @@ -32,4 +32,9 @@ namespace MultiArrayTools template class HighLevelOpRoot; template class HighLevelOpRoot; + template HighLevelOpHolder mkHLO(const OpCD& op); + template HighLevelOpHolder mkHLO(const OpD& op); + template HighLevelOpHolder mkHLOV(double val); + template HighLevelOpHolder mkHLOV(double val); + } diff --git a/src/tests/op4_unit_test.cc b/src/tests/op4_unit_test.cc index ae1e7ae..2cd67f5 100644 --- a/src/tests/op4_unit_test.cc +++ b/src/tests/op4_unit_test.cc @@ -307,8 +307,10 @@ namespace auto hop2 = hl_exp(hop1); auto hop4 = hop3 * hop2; auto hopr = mkHLO(resx4(i1,di4)); + auto hop5 = mkHLOV(1.); + auto hop6 = hop4 - hop5; //hopr.assign( hop4, mi, ic_1, ic_2 ); - hopr.xassign( hop4, di4, i1 ); + hopr.xassign( hop6, di4, i1 ); auto i2_1 = imap.at("i2_1"); auto i2_2 = imap.at("i2_2"); @@ -331,11 +333,12 @@ namespace auto resx3v = xround(resx3.vdata()[jr]); auto resx4v = xround(resx4.vdata()[jr]); auto x12 = xround(ma1.vdata()[j1]*exp(ma2.vdata()[j2])); + auto x121 = xround(ma1.vdata()[j1]*exp(ma2.vdata()[j2])-1.); EXPECT_EQ( resv, x12 ); EXPECT_EQ( resx1v, x12 ); EXPECT_EQ( resx2v, x12 ); EXPECT_EQ( resx3v, x12 ); - EXPECT_EQ( resx4v, x12 ); + EXPECT_EQ( resx4v, x121 ); } } }