OpTest_Dyn, Multiply: add remaining tests: work

This commit is contained in:
Christian Zimmermann 2020-08-24 17:39:56 +02:00
parent d2e7cb3a63
commit 6444f971a6
6 changed files with 27 additions and 16 deletions

View file

@ -4,7 +4,7 @@
namespace MultiArrayTools
{
template <typename T, class Operation>
const T& DynamicOperation<T,Operation>::get(const DExtT& pos) const
T DynamicOperation<T,Operation>::get(const DExtT& pos) const
{
return mOp.get(pos.expl<ET>());
}

View file

@ -24,7 +24,7 @@ namespace MultiArrayTools
DynamicOperationBase& operator=(const DynamicOperationBase& in) = default;
DynamicOperationBase& operator=(DynamicOperationBase&& in) = default;
virtual const T& get(const DExtT& pos) const = 0;
virtual T get(const DExtT& pos) const = 0;
virtual DynamicOperationBase& set(const DExtT& pos) = 0;
virtual DExtT rootSteps(std::intptr_t iPtrNum = 0) const = 0;
virtual DynamicExpression loop(const DynamicExpression& exp) const = 0;
@ -50,7 +50,7 @@ namespace MultiArrayTools
DynamicOperation(const Operation& op) : mOp(op) {}
virtual const T& get(const DExtT& pos) const override final;
virtual T get(const DExtT& pos) const override final;
virtual DynamicOperationBase<T>& set(const DExtT& pos) override final;
virtual DExtT rootSteps(std::intptr_t iPtrNum = 0) const override final;
virtual DynamicExpression loop(const DynamicExpression& exp) const override final;
@ -81,8 +81,9 @@ namespace MultiArrayTools
DynamicO(const Op& op) : mOp(std::make_shared<DynamicOperation<T,Op>>(op)) {}
template <class X>
inline const T& get(const DExtTX<X>& pos) const { return mOp->get(pos.reduce()); }
inline DynamicO& set(const DExtT& pos) { return mOp->set(pos); }
inline T get(const DExtTX<X>& pos) const { return mOp->get(pos.reduce()); }
template <class X>
inline DynamicO& set(const DExtTX<X>& pos) { mOp->set(pos.reduce()); return *this; }
inline DExtT rootSteps(std::intptr_t iPtrNum = 0) const { return mOp->rootSteps(iPtrNum); }
inline DynamicExpression loop(const DynamicExpression& exp) const { return mOp->loop(exp); }
inline const T* data() const { return mOp->data(); }

View file

@ -753,7 +753,8 @@ namespace MultiArrayTools
template <class ET>
inline Operation<T,OpFunction,Ops...>& Operation<T,OpFunction,Ops...>::set(ET pos)
{
PackNum<sizeof...(Ops)-1>::setOpPos(mOps,pos);
typedef std::tuple<Ops...> OpTuple;
PackNum<sizeof...(Ops)-1>::template setOpPos<SIZE,OpTuple,ET>(mOps,pos);
return *this;
}

View file

@ -530,6 +530,7 @@ namespace MultiArrayTools
auto loop(Expr exp) const
-> decltype(PackNum<sizeof...(Ops)-1>::mkLoop( mOps, exp));
T* data() const { assert(0); return nullptr; }
};
namespace

View file

@ -122,7 +122,7 @@ namespace MultiArrayHelper
static_assert(LAST > NextOpType::SIZE, "inconsistent array positions");
static constexpr size_t NEXT = LAST - NextOpType::SIZE;
std::get<N>( ot ).set( Getter<NEXT>::template getX<ETuple>( et ) );
PackNum<N-1>::template setOpPos<NEXT>(ot, et);
PackNum<N-1>::template setOpPos<NEXT,OpTuple,ETuple>(ot, et);
}
};
@ -205,9 +205,9 @@ namespace MultiArrayHelper
template <size_t LAST,class OpTuple, class ETuple>
static inline void setOpPos(OpTuple& ot, const ETuple& et)
{
typedef typename std::remove_reference<decltype(std::get<0>(et))>::type NextOpType;
static_assert(LAST > NextOpType::SIZE, "inconsistent array positions");
typedef typename std::remove_reference<decltype(std::get<0>(ot))>::type NextOpType;
static constexpr size_t NEXT = LAST - NextOpType::SIZE;
static_assert(NEXT == 0, "inconsistent array positions");
std::get<0>( ot ).set( Getter<NEXT>::template getX<ETuple>( et ) );
}

View file

@ -104,9 +104,12 @@ namespace
(*di4)({imap["i2_1"],imap["i3_1"],imap["i4_1"],imap["i4_2"]});
auto resx1 = res1;
auto resx2 = res1;
auto resx3 = res1;
res1(i1,di4) = ma1(i1,di1) * ma2(i1,di2);
resx1(i1,di4) = mkDynOp(ma1(i1,di1)) * mkDynOp(ma2(i1,di2));
resx2(i1,di4) = mkDynOp(ma1(i1,di1) * ma2(i1,di2));
resx3(i1,di4) = mkDynOp(mkDynOp(ma1(i1,di1)) * mkDynOp(ma2(i1,di2)));
auto i2_1 = imap.at("i2_1");
auto i2_2 = imap.at("i2_2");
@ -123,10 +126,15 @@ namespace
const size_t jr = (((ii1*i2_1->max() + ii2_1)*i3_1->max() + ii3_1)*i4_1->max() + ii4_1)*i4_2->max() + ii4_2;
const size_t j1 = (((ii1*i2_1->max() + ii2_1)*i2_2->max() + ii2_2)*i3_1->max() + ii3_1)*i4_1->max() + ii4_1;
const size_t j2 = ((ii1*i3_1->max() + ii3_1)*i3_1->max() + ii3_1)*i4_2->max() + ii4_2;
auto resx = xround(res1.vdata()[jr]);
//std::cout << resx << " ";
auto resv = xround(res1.vdata()[jr]);
auto resx1v = xround(resx1.vdata()[jr]);
auto resx2v = xround(resx2.vdata()[jr]);
auto resx3v = xround(resx3.vdata()[jr]);
auto x12 = xround(ma1.vdata()[j1]*ma2.vdata()[j2]);
EXPECT_EQ( resx, x12 );
EXPECT_EQ( resv, x12 );
EXPECT_EQ( resx1v, x12 );
EXPECT_EQ( resx2v, x12 );
EXPECT_EQ( resx3v, x12 );
}
}
}