allow value operations in high level operations

This commit is contained in:
Christian Zimmermann 2020-09-20 13:37:44 +02:00
parent 5a309afac6
commit 995b16b51d
5 changed files with 127 additions and 124 deletions

View file

@ -74,6 +74,44 @@ namespace MultiArrayTools
return &mOp; return &mOp;
} }
template <class ROP>
auto HighLevelOpRoot<ROP>::vget()
-> VOP*
{
return nullptr;
}
template <class ROP>
HighLevelOpValue<ROP>::HighLevelOpValue(const VOP& op) : mOp(op) {}
template <class ROP>
bool HighLevelOpValue<ROP>::root() const
{
return true;
}
template <class ROP>
template <class... Inds>
auto HighLevelOpValue<ROP>::xcreate(const std::shared_ptr<Inds>&... inds)
-> typename B::template RetT<Inds...>
{
assert(0);
return typename B::template RetT<Inds...>();
}
template <class ROP>
ROP* HighLevelOpValue<ROP>::get()
{
return nullptr;
}
template <class ROP>
auto HighLevelOpValue<ROP>::vget()
-> VOP*
{
return &mOp;
}
namespace namespace
{ {
template <size_t N> template <size_t N>
@ -105,10 +143,18 @@ namespace MultiArrayTools
(res, in, inds..., op, ops..., dop, dops...); (res, in, inds..., op, ops..., dop, dops...);
} }
else { else {
auto& op = *inn->get(); auto op = inn->get();
typedef typename std::remove_reference<decltype(op)>::type OP; auto vop = inn->vget();
Create<N-1>::template cx<Indices...>::template ccx<ROP,OpF,OP,OPs...>::template cccx<M> typedef typename std::remove_reference<decltype(*op)>::type OP;
(res, in, inds..., op, ops..., dops...); typedef typename std::remove_reference<decltype(*vop)>::type VOP;
if(op != nullptr){
Create<N-1>::template cx<Indices...>::template ccx<ROP,OpF,OP,OPs...>::template cccx<M>
(res, in, inds..., *op, ops..., dops...);
}
else {
Create<N-1>::template cx<Indices...>::template ccx<ROP,OpF,VOP,OPs...>::template cccx<M>
(res, in, inds..., *vop, ops..., dops...);
}
} }
} }
}; };
@ -142,8 +188,14 @@ namespace MultiArrayTools
res.appendOuterM(dop.op,dops.op...); res.appendOuterM(dop.op,dops.op...);
} }
else { else {
auto& op = *inn->get(); auto op = inn->get();
res.op = mkDynOutOp(mkFOp<OpF>(op,ops...), inds...); auto vop = inn->vget();
if(op != nullptr){
res.op = mkDynOutOp(mkFOp<OpF>(*op,ops...), inds...);
}
else {
res.op = mkDynOutOp(mkFOp<OpF>(*vop,ops...), inds...);
}
res.appendOuterM(dops.op...); res.appendOuterM(dops.op...);
} }
} }
@ -167,6 +219,14 @@ namespace MultiArrayTools
assert(0); assert(0);
return nullptr; return nullptr;
} }
template <class ROP, class OpF, size_t N>
auto HighLevelOp<ROP,OpF,N>::vget()
-> VOP*
{
assert(0);
return nullptr;
}
template <class ROP, class OpF, size_t N> template <class ROP, class OpF, size_t N>
template <class... Inds> template <class... Inds>
@ -394,6 +454,14 @@ namespace MultiArrayTools
return HighLevelOpHolder<ROP>(std::make_shared<HighLevelOpRoot<ROP>>( op ) ); return HighLevelOpHolder<ROP>(std::make_shared<HighLevelOpRoot<ROP>>( op ) );
} }
template <class ROP>
HighLevelOpHolder<ROP> mkHLOV(double val)
{
return HighLevelOpHolder<ROP>(std::make_shared<HighLevelOpValue<ROP>>
( OperationValue<double>(val) ) );
}
#define SP " " #define SP " "
#define regFunc1(fff) template <class ROP> \ #define regFunc1(fff) template <class ROP> \
HighLevelOpHolder<ROP> hl_##fff (const HighLevelOpHolder<ROP>& in) \ HighLevelOpHolder<ROP> hl_##fff (const HighLevelOpHolder<ROP>& in) \
@ -404,116 +472,5 @@ namespace MultiArrayTools
#undef regFunc1 #undef regFunc1
#undef SP #undef SP
/*
template <size_t N>
template <class ITuple>
inline void SetLInds<N>::mkLIT(const ITuple& itp, const std::shared_ptr<DynamicIndex>& di)
{
constexpr size_t NN = std::tuple_size<ITuple>::value-N-1;
const size_t nn = di->dim()-N-1;
typedef typename std::remove_reference<decltype(*std::get<NN>(itp))>::type T;
std::get<NN>(itp) =
std::dynamic_pointer_cast<T>(di->get(nn))->getIndex();
SetLInds<N-1>::mkLIT(itp, di);
}
template <size_t N>
template <class Tar, class ITp, typename... Args>
template <class... Is>
inline void SetLInds<N>::xx<Tar,ITp,Args...>::
assign(Tar& tar, const Args&... args, const ITp& itp, const std::shared_ptr<Is>&... is)
{
SetLInds<N-1>::template xx<ITp,Args...>::assign(tar, args..., itp, std::get<N>(itp), is...);
}
template <size_t N>
template <class Tar, class ITp, typename... Args>
template <class... Is>
inline void SetLInds<N>::xx<Tar,ITp,Args...>::
plus(Tar& tar, const Args&... args, const ITp& itp, const std::shared_ptr<Is>&... is)
{
SetLInds<N-1>::template xx<ITp,Args...>::plus(tar, args..., itp, std::get<N>(itp), is...);
}
//template <>
template <class ITuple>
inline void SetLInds<0>::mkLIT(const ITuple& itp, const std::shared_ptr<DynamicIndex>& di)
{
constexpr size_t NN = std::tuple_size<ITuple>::value-1;
const size_t nn = di->dim()-1;
typedef typename std::remove_reference<decltype(*std::get<NN>(itp))>::type T;
std::get<NN>(itp) =
std::dynamic_pointer_cast<T>(di->get(nn))->getIndex();
}
//template <>
template <class Tar, class ITp, typename... Args>
template <class... Is>
inline void SetLInds<0>::xx<Tar,ITp,Args...>::
assign(Tar& tar, const Args&... args, const ITp& itp, const std::shared_ptr<Is>&... is)
{
tar.assign(args..., std::get<0>(itp), is...);
}
//template <>
template <class Tar, class ITp, typename... Args>
template <class... Is>
inline void SetLInds<0>::xx<Tar,ITp,Args...>::
plus(Tar& tar, const Args&... args, const ITp& itp, const std::shared_ptr<Is>&... is)
{
tar.plus(args..., std::get<0>(itp), is...);
}
template <class ROP, class... Indices>
size_t INDS<ROP,Indices...>::CallHLOpBase::depth() const
{
return mDepth;
}
template <class ROP, class... Indices>
template <class... LIndices>
void INDS<ROP,Indices...>::CallHLOp<LIndices...>::
assign(HighLevelOpHolder<ROP>& target, const HighLevelOpHolder<ROP>& source,
const std::shared_ptr<Indices>&... is,
const std::shared_ptr<DynamicIndex>& di) const
{
auto ip = di->get(di->dim() - this->depth());
auto iregn = ip->regN();
if(iregn.type >= 0 and iregn.depth > sizeof...(LIndices)){
sNext[iregn.type]->assign(target, source, is..., di);
}
else {
ITuple itp;
SetLInds<sizeof...(LIndices)-1>::mkLIT(itp,di);
auto mi = mkIndex(is...,mkSubSpaceX(di, di->dim() - this->depth()));
SetLInds<sizeof...(LIndices)-1>::
template xx<HighLevelOpHolder<ROP>,ITuple,HighLevelOpHolder<ROP>,decltype(mi)>::
assign(target, source, mi, itp);
}
}
template <class ROP, class... Indices>
template <class... LIndices>
void INDS<ROP,Indices...>::CallHLOp<LIndices...>::
plus(HighLevelOpHolder<ROP>& target, const HighLevelOpHolder<ROP>& source,
const std::shared_ptr<Indices>&... is,
const std::shared_ptr<DynamicIndex>& di) const
{
auto ip = di->get(di->dim() - this->depth());
auto iregn = ip->regN();
if(iregn.type >= 0 and iregn.depth > sizeof...(LIndices)){
sNext[iregn.type]->plus(target, source, is..., di);
}
else {
ITuple itp;
SetLInds<sizeof...(LIndices)-1>::mkLIT(itp,di);
auto mi = mkIndex(is...,mkSubSpaceX(di, di->dim() - this->depth()));
SetLInds<sizeof...(LIndices)-1>::
template xx<HighLevelOpHolder<ROP>,ITuple,HighLevelOpHolder<ROP>,decltype(mi)>::
plus(target, source, mi, itp);
}
}
*/
} }

View file

@ -27,7 +27,9 @@ namespace MultiArrayTools
class HighLevelOpBase class HighLevelOpBase
{ {
public: public:
typedef OperationValue<double> VOP;
template <class... Indices> template <class... Indices>
struct RetT struct RetT
{ {
@ -59,7 +61,8 @@ namespace MultiArrayTools
#undef reg_ind3 #undef reg_ind3
virtual ROP* get() = 0; virtual ROP* get() = 0;
virtual VOP* vget() = 0;
}; };
template <class ROP> template <class ROP>
@ -67,7 +70,8 @@ namespace MultiArrayTools
{ {
private: private:
typedef HighLevelOpBase<ROP> B; typedef HighLevelOpBase<ROP> B;
typedef typename B::VOP VOP;
template <class... Inds> template <class... Inds>
typename B::template RetT<Inds...> xcreate(const std::shared_ptr<Inds>&... inds); typename B::template RetT<Inds...> xcreate(const std::shared_ptr<Inds>&... inds);
@ -91,7 +95,7 @@ namespace MultiArrayTools
#include "hl_reg_ind.h" #include "hl_reg_ind.h"
virtual ROP* get() override final; virtual ROP* get() override final;
virtual VOP* vget() override final;
}; };
@ -99,6 +103,30 @@ namespace MultiArrayTools
extern template class HighLevelOpBase<OpD>; extern template class HighLevelOpBase<OpD>;
extern template class HighLevelOpRoot<OpCD>; extern template class HighLevelOpRoot<OpCD>;
extern template class HighLevelOpRoot<OpD>; extern template class HighLevelOpRoot<OpD>;
template <class ROP>
class HighLevelOpValue : public HighLevelOpBase<ROP>
{
private:
typedef HighLevelOpBase<ROP> B;
typedef typename B::VOP VOP;
template <class... Inds>
typename B::template RetT<Inds...> xcreate(const std::shared_ptr<Inds>&... inds);
VOP mOp;
public:
HighLevelOpValue(const VOP& vop);
virtual bool root() const override final;
#include "hl_reg_ind.h"
virtual ROP* get() override final;
virtual VOP* vget() override final;
};
template <class OpF, class... Ops> template <class OpF, class... Ops>
auto mkFOp(const Ops&... ops) auto mkFOp(const Ops&... ops)
@ -113,6 +141,7 @@ namespace MultiArrayTools
{ {
public: public:
typedef HighLevelOpBase<ROP> B; typedef HighLevelOpBase<ROP> B;
typedef typename B::VOP VOP;
private: private:
std::array<std::shared_ptr<HighLevelOpBase<ROP>>,N> mIn; std::array<std::shared_ptr<HighLevelOpBase<ROP>>,N> mIn;
@ -127,6 +156,7 @@ namespace MultiArrayTools
virtual bool root() const override final; virtual bool root() const override final;
virtual ROP* get() override final; virtual ROP* get() override final;
virtual VOP* vget() override final;
#include "hl_reg_ind.h" #include "hl_reg_ind.h"
@ -211,6 +241,14 @@ namespace MultiArrayTools
template <class ROP> template <class ROP>
HighLevelOpHolder<ROP> mkHLO(const ROP& op); HighLevelOpHolder<ROP> mkHLO(const ROP& op);
template <class ROP>
HighLevelOpHolder<ROP> mkHLOV(double val);
extern template HighLevelOpHolder<OpCD> mkHLO(const OpCD& op);
extern template HighLevelOpHolder<OpD> mkHLO(const OpD& op);
extern template HighLevelOpHolder<OpCD> mkHLOV(double val);
extern template HighLevelOpHolder<OpD> mkHLOV(double val);
#define regFunc1(fff) template <class ROP> \ #define regFunc1(fff) template <class ROP> \
HighLevelOpHolder<ROP> hl_##fff (const HighLevelOpHolder<ROP>& in); HighLevelOpHolder<ROP> hl_##fff (const HighLevelOpHolder<ROP>& in);
#include "extensions/math.h" #include "extensions/math.h"

View file

@ -119,7 +119,7 @@ namespace MultiArrayHelper
static inline void setOpPos(OpTuple& ot, const ETuple& et) static inline void setOpPos(OpTuple& ot, const ETuple& et)
{ {
typedef typename std::remove_reference<decltype(std::get<N>(ot))>::type NextOpType; typedef typename std::remove_reference<decltype(std::get<N>(ot))>::type NextOpType;
static_assert(LAST > NextOpType::SIZE, "inconsistent array positions"); static_assert(LAST >= NextOpType::SIZE, "inconsistent array positions");
static constexpr size_t NEXT = LAST - NextOpType::SIZE; static constexpr size_t NEXT = LAST - NextOpType::SIZE;
std::get<N>( ot ).set( Getter<NEXT>::template getX<ETuple>( et ) ); std::get<N>( ot ).set( Getter<NEXT>::template getX<ETuple>( et ) );
PackNum<N-1>::template setOpPos<NEXT,OpTuple,ETuple>(ot, et); PackNum<N-1>::template setOpPos<NEXT,OpTuple,ETuple>(ot, et);

View file

@ -32,4 +32,9 @@ namespace MultiArrayTools
template class HighLevelOpRoot<OpCD>; template class HighLevelOpRoot<OpCD>;
template class HighLevelOpRoot<OpD>; template class HighLevelOpRoot<OpD>;
template HighLevelOpHolder<OpCD> mkHLO(const OpCD& op);
template HighLevelOpHolder<OpD> mkHLO(const OpD& op);
template HighLevelOpHolder<OpCD> mkHLOV(double val);
template HighLevelOpHolder<OpD> mkHLOV(double val);
} }

View file

@ -307,8 +307,10 @@ namespace
auto hop2 = hl_exp(hop1); auto hop2 = hl_exp(hop1);
auto hop4 = hop3 * hop2; auto hop4 = hop3 * hop2;
auto hopr = mkHLO(resx4(i1,di4)); auto hopr = mkHLO(resx4(i1,di4));
auto hop5 = mkHLOV<OpCD>(1.);
auto hop6 = hop4 - hop5;
//hopr.assign( hop4, mi, ic_1, ic_2 ); //hopr.assign( hop4, mi, ic_1, ic_2 );
hopr.xassign( hop4, di4, i1 ); hopr.xassign( hop6, di4, i1 );
auto i2_1 = imap.at("i2_1"); auto i2_1 = imap.at("i2_1");
auto i2_2 = imap.at("i2_2"); auto i2_2 = imap.at("i2_2");
@ -331,11 +333,12 @@ namespace
auto resx3v = xround(resx3.vdata()[jr]); auto resx3v = xround(resx3.vdata()[jr]);
auto resx4v = xround(resx4.vdata()[jr]); auto resx4v = xround(resx4.vdata()[jr]);
auto x12 = xround(ma1.vdata()[j1]*exp(ma2.vdata()[j2])); auto x12 = xround(ma1.vdata()[j1]*exp(ma2.vdata()[j2]));
auto x121 = xround(ma1.vdata()[j1]*exp(ma2.vdata()[j2])-1.);
EXPECT_EQ( resv, x12 ); EXPECT_EQ( resv, x12 );
EXPECT_EQ( resx1v, x12 ); EXPECT_EQ( resx1v, x12 );
EXPECT_EQ( resx2v, x12 ); EXPECT_EQ( resx2v, x12 );
EXPECT_EQ( resx3v, x12 ); EXPECT_EQ( resx3v, x12 );
EXPECT_EQ( resx4v, x12 ); EXPECT_EQ( resx4v, x121 );
} }
} }
} }