dynamic hybrid contraction works

This commit is contained in:
Christian Zimmermann 2020-08-30 14:43:53 +02:00
parent 371107cb5d
commit 782e8555cf
3 changed files with 59 additions and 15 deletions

View file

@ -131,10 +131,24 @@ namespace MultiArrayTools
std::array<size_t,1>({1}), std::array<size_t,1>({0}))
{}
/*
DynamicOuterOp(const std::shared_ptr<DynamicOperationBase<OpH<OperationRoot<typename Operatrion::value_type,Ranges...>>>>& dyn,
const Operation& op, const std::shared_ptr<Indices>&... inds )
: mThreadId(omp_get_thread_num()),
//mDyn(dyn),
mOp(op), mIndices(inds...),
mMa(std::make_shared<MultiArray<T,Ranges...>>(mkArray<T>(inds->range()...))),
mProto(OperationRoot<T,Ranges...>(*mMa,inds...)),
mL(std::make_tuple(*mProto.mOp,mOp), std::make_tuple(inds...),
std::make_tuple(mMa), std::make_tuple(mProto.mOp->assign( mOp, mkMIndex(inds...) )),
std::array<size_t,1>({1}), std::array<size_t,1>({0}))
{}
*/
template <typename T, class Operation, class... Ranges>
OpH<OperationRoot<T,Ranges...>> DynamicOuterOp<T,Operation,Ranges...>::get(const DExtT& pos) const
{
if(mPrev) mPrev.get(pos.expl<ET>());
//if(mPrev) mPrev.get(pos.expl<ET>());
mL(0,pos.expl<ET>());
return mProto; // empty
}
@ -164,7 +178,7 @@ namespace MultiArrayTools
{
return &mProto;
}
/*
template <class Op1, class Op2>
template <class ET>
inline T TwoOp<Op1,Op2>::get(const ET& pos) const
@ -173,7 +187,7 @@ namespace MultiArrayTools
}
*/
template <typename T, class Operation, class... Ranges>
std::shared_ptr<DynamicOperationBase<OpH<OperationRoot<T,Ranges...>>>>
DynamicOuterOp<T,Operation,Ranges...>::deepCopy() const

View file

@ -72,12 +72,12 @@ namespace MultiArrayTools
{
private:
size_t mThreadId;
//std::shared_ptr<DynamicOperationBase<OpH<OperationRoot<T,Ranges...>>>> mDyn;
Operation mOp;
//OperationRoot<T,Ranges...> mProto;
std::tuple<std::shared_ptr<typename Ranges::IndexType>...> mIndices;
std::shared_ptr<MultiArray<T,Ranges...>> mMa;
OpH<OperationRoot<T,Ranges...>> mProto;
std::shared_ptr<DynamicOperationBase<OpH<OperationRoot<T,Ranges...>>>> mPrev;
typedef ILoop<std::tuple<OperationRoot<T,Ranges...>,Operation>,
@ -100,7 +100,10 @@ namespace MultiArrayTools
DynamicOuterOp& operator=(DynamicOuterOp&& in);
DynamicOuterOp(const Operation& op, const std::shared_ptr<typename Ranges::IndexType>&... inds);
/*
DynamicOuterOp(const std::shared_ptr<DynamicOperationBase<OpH<OperationRoot<typename Operatrion::value_type,Ranges...>>>>& dyn,
const Operation& op, const std::shared_ptr<Indices>&... inds );
*/
virtual OpH<OperationRoot<T,Ranges...>> get(const DExtT& pos) const override final;
virtual DynamicOperationBase<OpH<OperationRoot<T,Ranges...>>>& set(const DExtT& pos) override final;
virtual DExtT rootSteps(std::intptr_t iPtrNum = 0) const override final;
@ -143,7 +146,7 @@ namespace MultiArrayTools
inline DynamicExpression loop(const DynamicExpression& exp) const { return mOp->loop(exp); }
inline const T* data() const { return mOp->data(); }
};
/*
template <class Op1>
class TwoOp : public OperationTemplate<typename Op2::value_type,TwoOp<Op1>>
{
@ -162,7 +165,7 @@ namespace MultiArrayTools
template <class ET>
inline T get(const ET& pos) const;
};
*/
template <class Operation, class... Indices>
auto mkDynOutOp(const Operation& op, const std::shared_ptr<Indices>&... inds)
{
@ -171,7 +174,17 @@ namespace MultiArrayTools
(DynamicOuterOp<typename Operation::value_type,Operation,
typename Indices::RangeType...>(op, inds...));
}
/*
template <class Operation, class... Indices>
auto mkDynOutOp(const std::shared_ptr<DynamicOperationBase<OpH<OperationRoot<typename Operatrion::value_type,Ranges...>>>>& dyn,
const Operation& op, const std::shared_ptr<Indices>&... inds)
{
return DynamicO<OpH<OperationRoot<typename Operation::value_type,
typename Indices::RangeType...>>>
(DynamicOuterOp<typename Operation::value_type,Operation,
typename Indices::RangeType...>(dyn, op, inds...));
}
*/
// Build plan
/*
template <class Operation>

View file

@ -71,9 +71,11 @@ namespace
cr1 = createRangeE<CR>(5);
auto cr2 = createRangeE<CR>(7);
//auto cr2 = createRangeE<CR>(2);
auto cr3 = createRangeE<CR>(11);
auto cr4 = createRangeE<CR>(3);
auto cr5 = createRangeE<CR>(13);
//auto cr5 = createRangeE<CR>(1);
dr1 = createRangeE<DR>(cr2,cr2,cr3,cr4);
//dr1a = createRangeE<DR>(cr2,cr2,cr3);
@ -81,7 +83,7 @@ namespace
dr3 = createRangeE<DR>(cr2,cr5);
dr5 = createRangeE<DR>(cr5);
dr6 = createRangeE<DR>(cr3,cr4);
dr6a = createRangeE<DR>(cr3);
dr6a = createRangeE<DR>(cr3,cr2,cr5);
dr4 = createRangeE<DR>(cr2,cr3,cr4,cr4);
dr4a = createRangeE<DR>(cr2,cr3);
@ -200,7 +202,7 @@ namespace
(*di3)({imap["i2_1"],imap["i5_1"]});
(*di5)({imap["i5_1"]});
(*di6)({imap["i3_1"],imap["i4_1"]});
(*di6a)({imap["i3_1"]});
(*di6a)({imap["i3_1"],imap["i2_1"],imap["i5_1"]});
auto resx1 = res2;
auto resx2 = res2;
@ -211,15 +213,28 @@ namespace
resx2(i1,di6) += mkDynOp((ma1(i1,di1) * ma5(di5)).c(di3));
resx3(i1,di6) += mkDynOp((mkDynOp(ma1(i1,di1)) * mkDynOp(ma5(di5))).c(di3));
auto xx = std::make_shared<decltype(resx4)>(resx4);
auto mi = mkMIndex(i1,di6a);
auto op1 = ma1(i1,di1);
auto op2 = ma5(di5);
auto dop1 = mkDynOutOp(op1 * op2, ci4_1);
auto op3 = *dop1.data()->mOp;
auto dop2 = mkDynOutOp( dop1.c(di3), op3.c(di3), ci4_1 );
auto opr = resx4(i1,di6);
//resx2(i1,di6) += mkDynOp((ma1(i1,di1) * ma5(di5)).c(di3));
auto loop = mkPILoop
( [&opr,&op1,&op2,&xx,&di3,this](){
auto dop1 = mkDynOutOp(op1 * op2, ci4_1);
auto op3 = *dop1.data()->mOp;
auto dop2 = mkDynOutOp( op3, ci4_1 );
return mkGetExpr
(dop1,mkGetExpr
(dop2,mkILoop
(std::make_tuple(opr,*dop2.data()->mOp), std::make_tuple(ci4_1),
std::make_tuple(xx),
std::make_tuple(opr.plus( *dop2.data()->mOp, mkMIndex(ci4_1) )),
std::array<size_t,1>({1}), std::array<size_t,1>({0})))); } );
mi->pifor(1,loop)();
auto i2_1 = imap.at("i2_1");
auto i3_1 = imap.at("i3_1");
@ -238,16 +253,18 @@ namespace
vv += ma1.vdata()[j1] * ma5.vdata()[j2];
}
}
auto resv = xround(res2.vdata()[jr]);
auto resx1v = xround(resx1.vdata()[jr]);
auto resx2v = xround(resx2.vdata()[jr]);
auto resx3v = xround(resx3.vdata()[jr]);
auto resx4v = xround(resx4.vdata()[jr]);
auto x12 = xround(vv);
EXPECT_EQ( resv, x12 );
EXPECT_EQ( resx1v, x12 );
EXPECT_EQ( resx2v, x12 );
EXPECT_EQ( resx3v, x12 );
EXPECT_EQ( resx4v, x12 );
}
}
//std::cout << std::endl;