OpTest_Dyn, Multiply: add remaining tests: work
This commit is contained in:
parent
d2e7cb3a63
commit
6444f971a6
6 changed files with 27 additions and 16 deletions
|
@ -4,7 +4,7 @@
|
|||
namespace MultiArrayTools
|
||||
{
|
||||
template <typename T, class Operation>
|
||||
const T& DynamicOperation<T,Operation>::get(const DExtT& pos) const
|
||||
T DynamicOperation<T,Operation>::get(const DExtT& pos) const
|
||||
{
|
||||
return mOp.get(pos.expl<ET>());
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ namespace MultiArrayTools
|
|||
DynamicOperationBase& operator=(const DynamicOperationBase& in) = default;
|
||||
DynamicOperationBase& operator=(DynamicOperationBase&& in) = default;
|
||||
|
||||
virtual const T& get(const DExtT& pos) const = 0;
|
||||
virtual T get(const DExtT& pos) const = 0;
|
||||
virtual DynamicOperationBase& set(const DExtT& pos) = 0;
|
||||
virtual DExtT rootSteps(std::intptr_t iPtrNum = 0) const = 0;
|
||||
virtual DynamicExpression loop(const DynamicExpression& exp) const = 0;
|
||||
|
@ -50,7 +50,7 @@ namespace MultiArrayTools
|
|||
|
||||
DynamicOperation(const Operation& op) : mOp(op) {}
|
||||
|
||||
virtual const T& get(const DExtT& pos) const override final;
|
||||
virtual T get(const DExtT& pos) const override final;
|
||||
virtual DynamicOperationBase<T>& set(const DExtT& pos) override final;
|
||||
virtual DExtT rootSteps(std::intptr_t iPtrNum = 0) const override final;
|
||||
virtual DynamicExpression loop(const DynamicExpression& exp) const override final;
|
||||
|
@ -81,8 +81,9 @@ namespace MultiArrayTools
|
|||
DynamicO(const Op& op) : mOp(std::make_shared<DynamicOperation<T,Op>>(op)) {}
|
||||
|
||||
template <class X>
|
||||
inline const T& get(const DExtTX<X>& pos) const { return mOp->get(pos.reduce()); }
|
||||
inline DynamicO& set(const DExtT& pos) { return mOp->set(pos); }
|
||||
inline T get(const DExtTX<X>& pos) const { return mOp->get(pos.reduce()); }
|
||||
template <class X>
|
||||
inline DynamicO& set(const DExtTX<X>& pos) { mOp->set(pos.reduce()); return *this; }
|
||||
inline DExtT rootSteps(std::intptr_t iPtrNum = 0) const { return mOp->rootSteps(iPtrNum); }
|
||||
inline DynamicExpression loop(const DynamicExpression& exp) const { return mOp->loop(exp); }
|
||||
inline const T* data() const { return mOp->data(); }
|
||||
|
|
|
@ -753,7 +753,8 @@ namespace MultiArrayTools
|
|||
template <class ET>
|
||||
inline Operation<T,OpFunction,Ops...>& Operation<T,OpFunction,Ops...>::set(ET pos)
|
||||
{
|
||||
PackNum<sizeof...(Ops)-1>::setOpPos(mOps,pos);
|
||||
typedef std::tuple<Ops...> OpTuple;
|
||||
PackNum<sizeof...(Ops)-1>::template setOpPos<SIZE,OpTuple,ET>(mOps,pos);
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
|
|
@ -530,6 +530,7 @@ namespace MultiArrayTools
|
|||
auto loop(Expr exp) const
|
||||
-> decltype(PackNum<sizeof...(Ops)-1>::mkLoop( mOps, exp));
|
||||
|
||||
T* data() const { assert(0); return nullptr; }
|
||||
};
|
||||
|
||||
namespace
|
||||
|
|
|
@ -122,7 +122,7 @@ namespace MultiArrayHelper
|
|||
static_assert(LAST > NextOpType::SIZE, "inconsistent array positions");
|
||||
static constexpr size_t NEXT = LAST - NextOpType::SIZE;
|
||||
std::get<N>( ot ).set( Getter<NEXT>::template getX<ETuple>( et ) );
|
||||
PackNum<N-1>::template setOpPos<NEXT>(ot, et);
|
||||
PackNum<N-1>::template setOpPos<NEXT,OpTuple,ETuple>(ot, et);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -205,9 +205,9 @@ namespace MultiArrayHelper
|
|||
template <size_t LAST,class OpTuple, class ETuple>
|
||||
static inline void setOpPos(OpTuple& ot, const ETuple& et)
|
||||
{
|
||||
typedef typename std::remove_reference<decltype(std::get<0>(et))>::type NextOpType;
|
||||
static_assert(LAST > NextOpType::SIZE, "inconsistent array positions");
|
||||
typedef typename std::remove_reference<decltype(std::get<0>(ot))>::type NextOpType;
|
||||
static constexpr size_t NEXT = LAST - NextOpType::SIZE;
|
||||
static_assert(NEXT == 0, "inconsistent array positions");
|
||||
std::get<0>( ot ).set( Getter<NEXT>::template getX<ETuple>( et ) );
|
||||
}
|
||||
|
||||
|
|
|
@ -104,9 +104,12 @@ namespace
|
|||
(*di4)({imap["i2_1"],imap["i3_1"],imap["i4_1"],imap["i4_2"]});
|
||||
|
||||
auto resx1 = res1;
|
||||
auto resx2 = res1;
|
||||
auto resx3 = res1;
|
||||
res1(i1,di4) = ma1(i1,di1) * ma2(i1,di2);
|
||||
|
||||
resx1(i1,di4) = mkDynOp(ma1(i1,di1)) * mkDynOp(ma2(i1,di2));
|
||||
resx2(i1,di4) = mkDynOp(ma1(i1,di1) * ma2(i1,di2));
|
||||
resx3(i1,di4) = mkDynOp(mkDynOp(ma1(i1,di1)) * mkDynOp(ma2(i1,di2)));
|
||||
|
||||
auto i2_1 = imap.at("i2_1");
|
||||
auto i2_2 = imap.at("i2_2");
|
||||
|
@ -123,10 +126,15 @@ namespace
|
|||
const size_t jr = (((ii1*i2_1->max() + ii2_1)*i3_1->max() + ii3_1)*i4_1->max() + ii4_1)*i4_2->max() + ii4_2;
|
||||
const size_t j1 = (((ii1*i2_1->max() + ii2_1)*i2_2->max() + ii2_2)*i3_1->max() + ii3_1)*i4_1->max() + ii4_1;
|
||||
const size_t j2 = ((ii1*i3_1->max() + ii3_1)*i3_1->max() + ii3_1)*i4_2->max() + ii4_2;
|
||||
auto resx = xround(res1.vdata()[jr]);
|
||||
//std::cout << resx << " ";
|
||||
auto x12 = xround(ma1.vdata()[j1]*ma2.vdata()[j2]);
|
||||
EXPECT_EQ( resx, x12 );
|
||||
auto resv = xround(res1.vdata()[jr]);
|
||||
auto resx1v = xround(resx1.vdata()[jr]);
|
||||
auto resx2v = xround(resx2.vdata()[jr]);
|
||||
auto resx3v = xround(resx3.vdata()[jr]);
|
||||
auto x12 = xround(ma1.vdata()[j1]*ma2.vdata()[j2]);
|
||||
EXPECT_EQ( resv, x12 );
|
||||
EXPECT_EQ( resx1v, x12 );
|
||||
EXPECT_EQ( resx2v, x12 );
|
||||
EXPECT_EQ( resx3v, x12 );
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue