OpTest_Dyn, Multiply: add remaining tests: work

This commit is contained in:
Christian Zimmermann 2020-08-24 17:39:56 +02:00
parent d2e7cb3a63
commit 6444f971a6
6 changed files with 27 additions and 16 deletions

View file

@ -4,7 +4,7 @@
namespace MultiArrayTools namespace MultiArrayTools
{ {
template <typename T, class Operation> template <typename T, class Operation>
const T& DynamicOperation<T,Operation>::get(const DExtT& pos) const T DynamicOperation<T,Operation>::get(const DExtT& pos) const
{ {
return mOp.get(pos.expl<ET>()); return mOp.get(pos.expl<ET>());
} }

View file

@ -24,7 +24,7 @@ namespace MultiArrayTools
DynamicOperationBase& operator=(const DynamicOperationBase& in) = default; DynamicOperationBase& operator=(const DynamicOperationBase& in) = default;
DynamicOperationBase& operator=(DynamicOperationBase&& in) = default; DynamicOperationBase& operator=(DynamicOperationBase&& in) = default;
virtual const T& get(const DExtT& pos) const = 0; virtual T get(const DExtT& pos) const = 0;
virtual DynamicOperationBase& set(const DExtT& pos) = 0; virtual DynamicOperationBase& set(const DExtT& pos) = 0;
virtual DExtT rootSteps(std::intptr_t iPtrNum = 0) const = 0; virtual DExtT rootSteps(std::intptr_t iPtrNum = 0) const = 0;
virtual DynamicExpression loop(const DynamicExpression& exp) const = 0; virtual DynamicExpression loop(const DynamicExpression& exp) const = 0;
@ -50,7 +50,7 @@ namespace MultiArrayTools
DynamicOperation(const Operation& op) : mOp(op) {} DynamicOperation(const Operation& op) : mOp(op) {}
virtual const T& get(const DExtT& pos) const override final; virtual T get(const DExtT& pos) const override final;
virtual DynamicOperationBase<T>& set(const DExtT& pos) override final; virtual DynamicOperationBase<T>& set(const DExtT& pos) override final;
virtual DExtT rootSteps(std::intptr_t iPtrNum = 0) const override final; virtual DExtT rootSteps(std::intptr_t iPtrNum = 0) const override final;
virtual DynamicExpression loop(const DynamicExpression& exp) const override final; virtual DynamicExpression loop(const DynamicExpression& exp) const override final;
@ -81,8 +81,9 @@ namespace MultiArrayTools
DynamicO(const Op& op) : mOp(std::make_shared<DynamicOperation<T,Op>>(op)) {} DynamicO(const Op& op) : mOp(std::make_shared<DynamicOperation<T,Op>>(op)) {}
template <class X> template <class X>
inline const T& get(const DExtTX<X>& pos) const { return mOp->get(pos.reduce()); } inline T get(const DExtTX<X>& pos) const { return mOp->get(pos.reduce()); }
inline DynamicO& set(const DExtT& pos) { return mOp->set(pos); } template <class X>
inline DynamicO& set(const DExtTX<X>& pos) { mOp->set(pos.reduce()); return *this; }
inline DExtT rootSteps(std::intptr_t iPtrNum = 0) const { return mOp->rootSteps(iPtrNum); } inline DExtT rootSteps(std::intptr_t iPtrNum = 0) const { return mOp->rootSteps(iPtrNum); }
inline DynamicExpression loop(const DynamicExpression& exp) const { return mOp->loop(exp); } inline DynamicExpression loop(const DynamicExpression& exp) const { return mOp->loop(exp); }
inline const T* data() const { return mOp->data(); } inline const T* data() const { return mOp->data(); }

View file

@ -753,7 +753,8 @@ namespace MultiArrayTools
template <class ET> template <class ET>
inline Operation<T,OpFunction,Ops...>& Operation<T,OpFunction,Ops...>::set(ET pos) inline Operation<T,OpFunction,Ops...>& Operation<T,OpFunction,Ops...>::set(ET pos)
{ {
PackNum<sizeof...(Ops)-1>::setOpPos(mOps,pos); typedef std::tuple<Ops...> OpTuple;
PackNum<sizeof...(Ops)-1>::template setOpPos<SIZE,OpTuple,ET>(mOps,pos);
return *this; return *this;
} }

View file

@ -530,6 +530,7 @@ namespace MultiArrayTools
auto loop(Expr exp) const auto loop(Expr exp) const
-> decltype(PackNum<sizeof...(Ops)-1>::mkLoop( mOps, exp)); -> decltype(PackNum<sizeof...(Ops)-1>::mkLoop( mOps, exp));
T* data() const { assert(0); return nullptr; }
}; };
namespace namespace

View file

@ -122,7 +122,7 @@ namespace MultiArrayHelper
static_assert(LAST > NextOpType::SIZE, "inconsistent array positions"); static_assert(LAST > NextOpType::SIZE, "inconsistent array positions");
static constexpr size_t NEXT = LAST - NextOpType::SIZE; static constexpr size_t NEXT = LAST - NextOpType::SIZE;
std::get<N>( ot ).set( Getter<NEXT>::template getX<ETuple>( et ) ); std::get<N>( ot ).set( Getter<NEXT>::template getX<ETuple>( et ) );
PackNum<N-1>::template setOpPos<NEXT>(ot, et); PackNum<N-1>::template setOpPos<NEXT,OpTuple,ETuple>(ot, et);
} }
}; };
@ -205,9 +205,9 @@ namespace MultiArrayHelper
template <size_t LAST,class OpTuple, class ETuple> template <size_t LAST,class OpTuple, class ETuple>
static inline void setOpPos(OpTuple& ot, const ETuple& et) static inline void setOpPos(OpTuple& ot, const ETuple& et)
{ {
typedef typename std::remove_reference<decltype(std::get<0>(et))>::type NextOpType; typedef typename std::remove_reference<decltype(std::get<0>(ot))>::type NextOpType;
static_assert(LAST > NextOpType::SIZE, "inconsistent array positions");
static constexpr size_t NEXT = LAST - NextOpType::SIZE; static constexpr size_t NEXT = LAST - NextOpType::SIZE;
static_assert(NEXT == 0, "inconsistent array positions");
std::get<0>( ot ).set( Getter<NEXT>::template getX<ETuple>( et ) ); std::get<0>( ot ).set( Getter<NEXT>::template getX<ETuple>( et ) );
} }

View file

@ -104,9 +104,12 @@ namespace
(*di4)({imap["i2_1"],imap["i3_1"],imap["i4_1"],imap["i4_2"]}); (*di4)({imap["i2_1"],imap["i3_1"],imap["i4_1"],imap["i4_2"]});
auto resx1 = res1; auto resx1 = res1;
auto resx2 = res1;
auto resx3 = res1;
res1(i1,di4) = ma1(i1,di1) * ma2(i1,di2); res1(i1,di4) = ma1(i1,di1) * ma2(i1,di2);
resx1(i1,di4) = mkDynOp(ma1(i1,di1)) * mkDynOp(ma2(i1,di2)); resx1(i1,di4) = mkDynOp(ma1(i1,di1)) * mkDynOp(ma2(i1,di2));
resx2(i1,di4) = mkDynOp(ma1(i1,di1) * ma2(i1,di2));
resx3(i1,di4) = mkDynOp(mkDynOp(ma1(i1,di1)) * mkDynOp(ma2(i1,di2)));
auto i2_1 = imap.at("i2_1"); auto i2_1 = imap.at("i2_1");
auto i2_2 = imap.at("i2_2"); auto i2_2 = imap.at("i2_2");
@ -123,10 +126,15 @@ namespace
const size_t jr = (((ii1*i2_1->max() + ii2_1)*i3_1->max() + ii3_1)*i4_1->max() + ii4_1)*i4_2->max() + ii4_2; const size_t jr = (((ii1*i2_1->max() + ii2_1)*i3_1->max() + ii3_1)*i4_1->max() + ii4_1)*i4_2->max() + ii4_2;
const size_t j1 = (((ii1*i2_1->max() + ii2_1)*i2_2->max() + ii2_2)*i3_1->max() + ii3_1)*i4_1->max() + ii4_1; const size_t j1 = (((ii1*i2_1->max() + ii2_1)*i2_2->max() + ii2_2)*i3_1->max() + ii3_1)*i4_1->max() + ii4_1;
const size_t j2 = ((ii1*i3_1->max() + ii3_1)*i3_1->max() + ii3_1)*i4_2->max() + ii4_2; const size_t j2 = ((ii1*i3_1->max() + ii3_1)*i3_1->max() + ii3_1)*i4_2->max() + ii4_2;
auto resx = xround(res1.vdata()[jr]); auto resv = xround(res1.vdata()[jr]);
//std::cout << resx << " "; auto resx1v = xround(resx1.vdata()[jr]);
auto resx2v = xround(resx2.vdata()[jr]);
auto resx3v = xround(resx3.vdata()[jr]);
auto x12 = xround(ma1.vdata()[j1]*ma2.vdata()[j2]); auto x12 = xround(ma1.vdata()[j1]*ma2.vdata()[j2]);
EXPECT_EQ( resx, x12 ); EXPECT_EQ( resv, x12 );
EXPECT_EQ( resx1v, x12 );
EXPECT_EQ( resx2v, x12 );
EXPECT_EQ( resx3v, x12 );
} }
} }
} }