allow value operations in high level operations
This commit is contained in:
parent
5a309afac6
commit
995b16b51d
5 changed files with 127 additions and 124 deletions
|
@ -74,6 +74,44 @@ namespace MultiArrayTools
|
|||
return &mOp;
|
||||
}
|
||||
|
||||
template <class ROP>
|
||||
auto HighLevelOpRoot<ROP>::vget()
|
||||
-> VOP*
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <class ROP>
|
||||
HighLevelOpValue<ROP>::HighLevelOpValue(const VOP& op) : mOp(op) {}
|
||||
|
||||
template <class ROP>
|
||||
bool HighLevelOpValue<ROP>::root() const
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ROP>
|
||||
template <class... Inds>
|
||||
auto HighLevelOpValue<ROP>::xcreate(const std::shared_ptr<Inds>&... inds)
|
||||
-> typename B::template RetT<Inds...>
|
||||
{
|
||||
assert(0);
|
||||
return typename B::template RetT<Inds...>();
|
||||
}
|
||||
|
||||
template <class ROP>
|
||||
ROP* HighLevelOpValue<ROP>::get()
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <class ROP>
|
||||
auto HighLevelOpValue<ROP>::vget()
|
||||
-> VOP*
|
||||
{
|
||||
return &mOp;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
template <size_t N>
|
||||
|
@ -105,10 +143,18 @@ namespace MultiArrayTools
|
|||
(res, in, inds..., op, ops..., dop, dops...);
|
||||
}
|
||||
else {
|
||||
auto& op = *inn->get();
|
||||
typedef typename std::remove_reference<decltype(op)>::type OP;
|
||||
auto op = inn->get();
|
||||
auto vop = inn->vget();
|
||||
typedef typename std::remove_reference<decltype(*op)>::type OP;
|
||||
typedef typename std::remove_reference<decltype(*vop)>::type VOP;
|
||||
if(op != nullptr){
|
||||
Create<N-1>::template cx<Indices...>::template ccx<ROP,OpF,OP,OPs...>::template cccx<M>
|
||||
(res, in, inds..., op, ops..., dops...);
|
||||
(res, in, inds..., *op, ops..., dops...);
|
||||
}
|
||||
else {
|
||||
Create<N-1>::template cx<Indices...>::template ccx<ROP,OpF,VOP,OPs...>::template cccx<M>
|
||||
(res, in, inds..., *vop, ops..., dops...);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -142,8 +188,14 @@ namespace MultiArrayTools
|
|||
res.appendOuterM(dop.op,dops.op...);
|
||||
}
|
||||
else {
|
||||
auto& op = *inn->get();
|
||||
res.op = mkDynOutOp(mkFOp<OpF>(op,ops...), inds...);
|
||||
auto op = inn->get();
|
||||
auto vop = inn->vget();
|
||||
if(op != nullptr){
|
||||
res.op = mkDynOutOp(mkFOp<OpF>(*op,ops...), inds...);
|
||||
}
|
||||
else {
|
||||
res.op = mkDynOutOp(mkFOp<OpF>(*vop,ops...), inds...);
|
||||
}
|
||||
res.appendOuterM(dops.op...);
|
||||
}
|
||||
}
|
||||
|
@ -168,6 +220,14 @@ namespace MultiArrayTools
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
template <class ROP, class OpF, size_t N>
|
||||
auto HighLevelOp<ROP,OpF,N>::vget()
|
||||
-> VOP*
|
||||
{
|
||||
assert(0);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <class ROP, class OpF, size_t N>
|
||||
template <class... Inds>
|
||||
auto HighLevelOp<ROP,OpF,N>::xcreate(const std::shared_ptr<Inds>&... inds)
|
||||
|
@ -394,6 +454,14 @@ namespace MultiArrayTools
|
|||
return HighLevelOpHolder<ROP>(std::make_shared<HighLevelOpRoot<ROP>>( op ) );
|
||||
}
|
||||
|
||||
template <class ROP>
|
||||
HighLevelOpHolder<ROP> mkHLOV(double val)
|
||||
{
|
||||
return HighLevelOpHolder<ROP>(std::make_shared<HighLevelOpValue<ROP>>
|
||||
( OperationValue<double>(val) ) );
|
||||
}
|
||||
|
||||
|
||||
#define SP " "
|
||||
#define regFunc1(fff) template <class ROP> \
|
||||
HighLevelOpHolder<ROP> hl_##fff (const HighLevelOpHolder<ROP>& in) \
|
||||
|
@ -404,116 +472,5 @@ namespace MultiArrayTools
|
|||
#undef regFunc1
|
||||
#undef SP
|
||||
|
||||
/*
|
||||
template <size_t N>
|
||||
template <class ITuple>
|
||||
inline void SetLInds<N>::mkLIT(const ITuple& itp, const std::shared_ptr<DynamicIndex>& di)
|
||||
{
|
||||
constexpr size_t NN = std::tuple_size<ITuple>::value-N-1;
|
||||
const size_t nn = di->dim()-N-1;
|
||||
typedef typename std::remove_reference<decltype(*std::get<NN>(itp))>::type T;
|
||||
std::get<NN>(itp) =
|
||||
std::dynamic_pointer_cast<T>(di->get(nn))->getIndex();
|
||||
SetLInds<N-1>::mkLIT(itp, di);
|
||||
}
|
||||
|
||||
template <size_t N>
|
||||
template <class Tar, class ITp, typename... Args>
|
||||
template <class... Is>
|
||||
inline void SetLInds<N>::xx<Tar,ITp,Args...>::
|
||||
assign(Tar& tar, const Args&... args, const ITp& itp, const std::shared_ptr<Is>&... is)
|
||||
{
|
||||
SetLInds<N-1>::template xx<ITp,Args...>::assign(tar, args..., itp, std::get<N>(itp), is...);
|
||||
}
|
||||
|
||||
template <size_t N>
|
||||
template <class Tar, class ITp, typename... Args>
|
||||
template <class... Is>
|
||||
inline void SetLInds<N>::xx<Tar,ITp,Args...>::
|
||||
plus(Tar& tar, const Args&... args, const ITp& itp, const std::shared_ptr<Is>&... is)
|
||||
{
|
||||
SetLInds<N-1>::template xx<ITp,Args...>::plus(tar, args..., itp, std::get<N>(itp), is...);
|
||||
}
|
||||
|
||||
//template <>
|
||||
template <class ITuple>
|
||||
inline void SetLInds<0>::mkLIT(const ITuple& itp, const std::shared_ptr<DynamicIndex>& di)
|
||||
{
|
||||
constexpr size_t NN = std::tuple_size<ITuple>::value-1;
|
||||
const size_t nn = di->dim()-1;
|
||||
typedef typename std::remove_reference<decltype(*std::get<NN>(itp))>::type T;
|
||||
std::get<NN>(itp) =
|
||||
std::dynamic_pointer_cast<T>(di->get(nn))->getIndex();
|
||||
}
|
||||
|
||||
//template <>
|
||||
template <class Tar, class ITp, typename... Args>
|
||||
template <class... Is>
|
||||
inline void SetLInds<0>::xx<Tar,ITp,Args...>::
|
||||
assign(Tar& tar, const Args&... args, const ITp& itp, const std::shared_ptr<Is>&... is)
|
||||
{
|
||||
tar.assign(args..., std::get<0>(itp), is...);
|
||||
}
|
||||
|
||||
//template <>
|
||||
template <class Tar, class ITp, typename... Args>
|
||||
template <class... Is>
|
||||
inline void SetLInds<0>::xx<Tar,ITp,Args...>::
|
||||
plus(Tar& tar, const Args&... args, const ITp& itp, const std::shared_ptr<Is>&... is)
|
||||
{
|
||||
tar.plus(args..., std::get<0>(itp), is...);
|
||||
}
|
||||
|
||||
template <class ROP, class... Indices>
|
||||
size_t INDS<ROP,Indices...>::CallHLOpBase::depth() const
|
||||
{
|
||||
return mDepth;
|
||||
}
|
||||
|
||||
|
||||
template <class ROP, class... Indices>
|
||||
template <class... LIndices>
|
||||
void INDS<ROP,Indices...>::CallHLOp<LIndices...>::
|
||||
assign(HighLevelOpHolder<ROP>& target, const HighLevelOpHolder<ROP>& source,
|
||||
const std::shared_ptr<Indices>&... is,
|
||||
const std::shared_ptr<DynamicIndex>& di) const
|
||||
{
|
||||
auto ip = di->get(di->dim() - this->depth());
|
||||
auto iregn = ip->regN();
|
||||
if(iregn.type >= 0 and iregn.depth > sizeof...(LIndices)){
|
||||
sNext[iregn.type]->assign(target, source, is..., di);
|
||||
}
|
||||
else {
|
||||
ITuple itp;
|
||||
SetLInds<sizeof...(LIndices)-1>::mkLIT(itp,di);
|
||||
auto mi = mkIndex(is...,mkSubSpaceX(di, di->dim() - this->depth()));
|
||||
SetLInds<sizeof...(LIndices)-1>::
|
||||
template xx<HighLevelOpHolder<ROP>,ITuple,HighLevelOpHolder<ROP>,decltype(mi)>::
|
||||
assign(target, source, mi, itp);
|
||||
}
|
||||
}
|
||||
|
||||
template <class ROP, class... Indices>
|
||||
template <class... LIndices>
|
||||
void INDS<ROP,Indices...>::CallHLOp<LIndices...>::
|
||||
plus(HighLevelOpHolder<ROP>& target, const HighLevelOpHolder<ROP>& source,
|
||||
const std::shared_ptr<Indices>&... is,
|
||||
const std::shared_ptr<DynamicIndex>& di) const
|
||||
{
|
||||
auto ip = di->get(di->dim() - this->depth());
|
||||
auto iregn = ip->regN();
|
||||
if(iregn.type >= 0 and iregn.depth > sizeof...(LIndices)){
|
||||
sNext[iregn.type]->plus(target, source, is..., di);
|
||||
}
|
||||
else {
|
||||
ITuple itp;
|
||||
SetLInds<sizeof...(LIndices)-1>::mkLIT(itp,di);
|
||||
auto mi = mkIndex(is...,mkSubSpaceX(di, di->dim() - this->depth()));
|
||||
SetLInds<sizeof...(LIndices)-1>::
|
||||
template xx<HighLevelOpHolder<ROP>,ITuple,HighLevelOpHolder<ROP>,decltype(mi)>::
|
||||
plus(target, source, mi, itp);
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
|
|
|
@ -28,6 +28,8 @@ namespace MultiArrayTools
|
|||
{
|
||||
public:
|
||||
|
||||
typedef OperationValue<double> VOP;
|
||||
|
||||
template <class... Indices>
|
||||
struct RetT
|
||||
{
|
||||
|
@ -59,6 +61,7 @@ namespace MultiArrayTools
|
|||
#undef reg_ind3
|
||||
|
||||
virtual ROP* get() = 0;
|
||||
virtual VOP* vget() = 0;
|
||||
|
||||
};
|
||||
|
||||
|
@ -67,6 +70,7 @@ namespace MultiArrayTools
|
|||
{
|
||||
private:
|
||||
typedef HighLevelOpBase<ROP> B;
|
||||
typedef typename B::VOP VOP;
|
||||
|
||||
template <class... Inds>
|
||||
typename B::template RetT<Inds...> xcreate(const std::shared_ptr<Inds>&... inds);
|
||||
|
@ -91,7 +95,7 @@ namespace MultiArrayTools
|
|||
#include "hl_reg_ind.h"
|
||||
|
||||
virtual ROP* get() override final;
|
||||
|
||||
virtual VOP* vget() override final;
|
||||
|
||||
};
|
||||
|
||||
|
@ -100,6 +104,30 @@ namespace MultiArrayTools
|
|||
extern template class HighLevelOpRoot<OpCD>;
|
||||
extern template class HighLevelOpRoot<OpD>;
|
||||
|
||||
template <class ROP>
|
||||
class HighLevelOpValue : public HighLevelOpBase<ROP>
|
||||
{
|
||||
private:
|
||||
typedef HighLevelOpBase<ROP> B;
|
||||
typedef typename B::VOP VOP;
|
||||
|
||||
template <class... Inds>
|
||||
typename B::template RetT<Inds...> xcreate(const std::shared_ptr<Inds>&... inds);
|
||||
|
||||
VOP mOp;
|
||||
public:
|
||||
|
||||
HighLevelOpValue(const VOP& vop);
|
||||
|
||||
virtual bool root() const override final;
|
||||
|
||||
#include "hl_reg_ind.h"
|
||||
|
||||
virtual ROP* get() override final;
|
||||
virtual VOP* vget() override final;
|
||||
|
||||
};
|
||||
|
||||
template <class OpF, class... Ops>
|
||||
auto mkFOp(const Ops&... ops)
|
||||
{
|
||||
|
@ -113,6 +141,7 @@ namespace MultiArrayTools
|
|||
{
|
||||
public:
|
||||
typedef HighLevelOpBase<ROP> B;
|
||||
typedef typename B::VOP VOP;
|
||||
|
||||
private:
|
||||
std::array<std::shared_ptr<HighLevelOpBase<ROP>>,N> mIn;
|
||||
|
@ -127,6 +156,7 @@ namespace MultiArrayTools
|
|||
virtual bool root() const override final;
|
||||
|
||||
virtual ROP* get() override final;
|
||||
virtual VOP* vget() override final;
|
||||
|
||||
#include "hl_reg_ind.h"
|
||||
|
||||
|
@ -211,6 +241,14 @@ namespace MultiArrayTools
|
|||
template <class ROP>
|
||||
HighLevelOpHolder<ROP> mkHLO(const ROP& op);
|
||||
|
||||
template <class ROP>
|
||||
HighLevelOpHolder<ROP> mkHLOV(double val);
|
||||
|
||||
extern template HighLevelOpHolder<OpCD> mkHLO(const OpCD& op);
|
||||
extern template HighLevelOpHolder<OpD> mkHLO(const OpD& op);
|
||||
extern template HighLevelOpHolder<OpCD> mkHLOV(double val);
|
||||
extern template HighLevelOpHolder<OpD> mkHLOV(double val);
|
||||
|
||||
#define regFunc1(fff) template <class ROP> \
|
||||
HighLevelOpHolder<ROP> hl_##fff (const HighLevelOpHolder<ROP>& in);
|
||||
#include "extensions/math.h"
|
||||
|
|
|
@ -119,7 +119,7 @@ namespace MultiArrayHelper
|
|||
static inline void setOpPos(OpTuple& ot, const ETuple& et)
|
||||
{
|
||||
typedef typename std::remove_reference<decltype(std::get<N>(ot))>::type NextOpType;
|
||||
static_assert(LAST > NextOpType::SIZE, "inconsistent array positions");
|
||||
static_assert(LAST >= NextOpType::SIZE, "inconsistent array positions");
|
||||
static constexpr size_t NEXT = LAST - NextOpType::SIZE;
|
||||
std::get<N>( ot ).set( Getter<NEXT>::template getX<ETuple>( et ) );
|
||||
PackNum<N-1>::template setOpPos<NEXT,OpTuple,ETuple>(ot, et);
|
||||
|
|
|
@ -32,4 +32,9 @@ namespace MultiArrayTools
|
|||
template class HighLevelOpRoot<OpCD>;
|
||||
template class HighLevelOpRoot<OpD>;
|
||||
|
||||
template HighLevelOpHolder<OpCD> mkHLO(const OpCD& op);
|
||||
template HighLevelOpHolder<OpD> mkHLO(const OpD& op);
|
||||
template HighLevelOpHolder<OpCD> mkHLOV(double val);
|
||||
template HighLevelOpHolder<OpD> mkHLOV(double val);
|
||||
|
||||
}
|
||||
|
|
|
@ -307,8 +307,10 @@ namespace
|
|||
auto hop2 = hl_exp(hop1);
|
||||
auto hop4 = hop3 * hop2;
|
||||
auto hopr = mkHLO(resx4(i1,di4));
|
||||
auto hop5 = mkHLOV<OpCD>(1.);
|
||||
auto hop6 = hop4 - hop5;
|
||||
//hopr.assign( hop4, mi, ic_1, ic_2 );
|
||||
hopr.xassign( hop4, di4, i1 );
|
||||
hopr.xassign( hop6, di4, i1 );
|
||||
|
||||
auto i2_1 = imap.at("i2_1");
|
||||
auto i2_2 = imap.at("i2_2");
|
||||
|
@ -331,11 +333,12 @@ namespace
|
|||
auto resx3v = xround(resx3.vdata()[jr]);
|
||||
auto resx4v = xround(resx4.vdata()[jr]);
|
||||
auto x12 = xround(ma1.vdata()[j1]*exp(ma2.vdata()[j2]));
|
||||
auto x121 = xround(ma1.vdata()[j1]*exp(ma2.vdata()[j2])-1.);
|
||||
EXPECT_EQ( resv, x12 );
|
||||
EXPECT_EQ( resx1v, x12 );
|
||||
EXPECT_EQ( resx2v, x12 );
|
||||
EXPECT_EQ( resx3v, x12 );
|
||||
EXPECT_EQ( resx4v, x12 );
|
||||
EXPECT_EQ( resx4v, x121 );
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue