slice contraction (tested, works)

This commit is contained in:
Christian Zimmermann 2018-09-15 01:58:17 +02:00
parent 6ba140aa97
commit b7e40ca71b
4 changed files with 187 additions and 22 deletions

View file

@ -56,6 +56,10 @@ namespace MultiArrayTools
template <typename T, class Op, class IndexType>
class Contraction;
// multi_array_operation.h
template <typename T, class Op, class... Indices>
class SliceContraction;
// slice.h
template <typename T, class... SRanges>
class Slice;

View file

@ -58,8 +58,8 @@ namespace MultiArrayTools
-> ConstSlice<T,typename Indices::RangeType...>;
template <class... Indices>
auto slc(const std::shared_ptr<Indices>&... inds) const
-> SliceContraction<T,typename Indices::RangeType...>;
auto slc(const std::shared_ptr<Indices>&... inds)
-> SliceContraction<T,OperationClass,Indices...>;
private:
friend OperationClass;
@ -157,6 +157,9 @@ namespace MultiArrayTools
template <class ET>
inline T get(ET pos) const;
template <class ET>
inline ConstOperationRoot& set(ET pos);
MExt<void> rootSteps(std::intptr_t iPtrNum = 0) const; // nullptr for simple usage with decltype
template <class Expr>
@ -169,6 +172,7 @@ namespace MultiArrayTools
//MultiArrayBase<T,Ranges...> const& mArrayRef;
const T* mDataPtr;
mutable IndexType mIndex;
size_t mOff = 0;
//std::shared_ptr<MultiArrayBase<T,Ranges...> > mMaPtr;
};
@ -192,6 +196,9 @@ namespace MultiArrayTools
template <class ET>
inline T get(ET pos) const;
template <class ET>
inline StaticCast& set(ET pos);
auto rootSteps(std::intptr_t iPtrNum = 0) const
-> decltype(mOp.rootSteps(iPtrNum));
@ -225,6 +232,9 @@ namespace MultiArrayTools
template <class ET>
inline value_type get(ET pos) const;
template <class ET>
inline MetaOperationRoot& set(ET pos);
MExt<void> rootSteps(std::intptr_t iPtrNum = 0) const; // nullptr for simple usage with decltype
template <class Expr>
@ -260,6 +270,9 @@ namespace MultiArrayTools
template <class ET>
inline T get(ET pos) const;
template <class ET>
inline OperationRoot& set(ET pos);
MExt<void> rootSteps(std::intptr_t iPtrNum = 0) const; // nullptr for simple usage with decltype
template <class Expr>
@ -276,6 +289,7 @@ namespace MultiArrayTools
//MutableMultiArrayBase<T,Ranges...>& mArrayRef;
T* mDataPtr;
mutable IndexType mIndex;
size_t mOff = 0;
};
template <typename T>
@ -294,6 +308,9 @@ namespace MultiArrayTools
template <class ET>
inline T get(ET pos) const;
template <class ET>
inline OperationValue& set(ET pos);
MExt<void> rootSteps(std::intptr_t iPtrNum = 0) const; // nullptr for simple usage with decltype
template <class Expr>
@ -368,6 +385,9 @@ namespace MultiArrayTools
template <class ET>
inline T get(ET pos) const;
template <class ET>
inline Operation& set(ET pos);
auto rootSteps(std::intptr_t iPtrNum = 0) const // nullptr for simple usage with decltype
-> decltype(PackNum<sizeof...(Ops)-1>::mkSteps(iPtrNum, mOps));
@ -435,6 +455,9 @@ namespace MultiArrayTools
template <class ET>
inline T get(ET pos) const;
template <class ET>
inline Contraction& set(ET pos);
auto rootSteps(std::intptr_t iPtrNum = 0) const // nullptr for simple usage with decltype
-> decltype(mOp.rootSteps(iPtrNum));
@ -443,30 +466,34 @@ namespace MultiArrayTools
};
template <typename T, class Op, class... Indices>
class SliceContraction : public OperationTemplate<MultiArray<T,Indices...>,
SliceContraction<MultiArray<T,Indices...>,Op,Indices...> >
// class SliceContraction : public OperationTemplate
//<MultiArray<T,typename Indices::RangeType...>,
//SliceContraction<MultiArray<T,typename Indices::RangeType...>,Op,Indices...> >
class SliceContraction : public OperationTemplate<T,SliceContraction<T,Op,Indices...> >
{
public:
typedef MultiArray<T,Indices...> value_type;
typedef OperationTemplate<ConstSlice<T,Indices...>,
SliceContraction<ConstSlice<T,Indices...>,Op,Indices...> > OT;
typedef MultiArray<T,typename Indices::RangeType...> value_type;
typedef OperationTemplate<T,SliceContraction<T,Op,Indices...> > OT;
static constexpr size_t SIZE = Op::SIZE;
private:
const Op& mOp;
MultiArray<T,Indices...> mCont;
OperationRoot<T,Indices...> mTarOp;
Op& mOp;
mutable MultiArray<T,typename Indices::RangeType...> mCont;
mutable OperationRoot<T,typename Indices::RangeType...> mTarOp;
public:
typedef decltype(mOp.rootSteps(0)) ETuple;
SliceContraction(const Op& op, const std::shared_ptr<Indices>&... ind);
SliceContraction(Op& op, std::shared_ptr<Indices>... ind);
template <class ET>
inline const value_type& get(ET pos) const;
template <class ET>
inline SliceContraction& set(ET pos);
auto rootSteps(std::intptr_t iPtrNum = 0) const // nullptr for simple usage with decltype
-> decltype(mOp.rootSteps(iPtrNum));
@ -542,6 +569,15 @@ namespace MultiArrayTools
return out;
}
template <typename T, class OperationClass>
template <class... Indices>
auto OperationBase<T,OperationClass>::slc(const std::shared_ptr<Indices>&... inds)
-> SliceContraction<T,OperationClass,Indices...>
{
return SliceContraction<T,OperationClass,Indices...>
(THIS(), inds...);
}
/*****************************************
* OperationMaster::AssignmentExpr *
*****************************************/
@ -643,10 +679,16 @@ namespace MultiArrayTools
template <class ET>
inline T ConstOperationRoot<T,Ranges...>::get(ET pos) const
{
//VCHECK(pos.val());
//VCHECK(mDataPtr);
//VCHECK(mDataPtr[pos.val()])
return mDataPtr[pos.val()];
return mDataPtr[pos.val()+mOff];
}
template <typename T, class... Ranges>
template <class ET>
inline ConstOperationRoot<T,Ranges...>& ConstOperationRoot<T,Ranges...>::set(ET pos)
{
mIndex = pos.val();
mOff = mIndex.pos();
return *this;
}
template <typename T, class... Ranges>
@ -684,6 +726,14 @@ namespace MultiArrayTools
return static_cast<T>( mOp.get(pos) );
}
template <typename T, class Op>
template <class ET>
inline StaticCast<T,Op>& StaticCast<T,Op>::set(ET pos)
{
mOp.set(pos);
return *this;
}
template <typename T, class Op>
auto StaticCast<T,Op>::rootSteps(std::intptr_t iPtrNum) const
-> decltype(mOp.rootSteps(iPtrNum))
@ -719,6 +769,14 @@ namespace MultiArrayTools
return mIndex.meta(pos.val());
}
template <class... Ranges>
template <class ET>
inline MetaOperationRoot<Ranges...>& MetaOperationRoot<Ranges...>::set(ET pos)
{
mIndex = pos.val();
return *this;
}
template <class... Ranges>
MExt<void> MetaOperationRoot<Ranges...>::rootSteps(std::intptr_t iPtrNum) const
{
@ -765,7 +823,16 @@ namespace MultiArrayTools
template <class ET>
inline T OperationRoot<T,Ranges...>::get(ET pos) const
{
return mDataPtr[pos.val()];
return mDataPtr[pos.val()+mOff];
}
template <typename T, class... Ranges>
template <class ET>
inline OperationRoot<T,Ranges...>& OperationRoot<T,Ranges...>::set(ET pos)
{
mIndex = pos.val();
mOff = mIndex.pos();
return *this;
}
template <typename T, class... Ranges>
@ -813,6 +880,13 @@ namespace MultiArrayTools
return mVal;
}
template <typename T>
template <class ET>
inline OperationValue<T>& OperationValue<T>::set(ET pos)
{
return *this;
}
template <typename T>
MExt<void> OperationValue<T>::rootSteps(std::intptr_t iPtrNum) const
{
@ -856,6 +930,14 @@ namespace MultiArrayTools
template mkOpExpr<SIZE,T,ET,OpTuple,OpFunction>(mF, pos, mOps);
}
template <typename T, class OpFunction, class... Ops>
template <class ET>
inline Operation<T,OpFunction,Ops...>& Operation<T,OpFunction,Ops...>::set(ET pos)
{
PackNum<sizeof...(Ops)-1>::setOpPos(mOps,pos);
return *this;
}
template <typename T, class OpFunction, class... Ops>
auto Operation<T,OpFunction,Ops...>::rootSteps(std::intptr_t iPtrNum) const
-> decltype(PackNum<sizeof...(Ops)-1>::mkSteps(iPtrNum, mOps))
@ -889,6 +971,14 @@ namespace MultiArrayTools
return mOp.template get<ET>(pos);
}
template <typename T, class Op, class IndexType>
template <class ET>
inline Contraction<T,Op,IndexType>& Contraction<T,Op,IndexType>::set(ET pos)
{
mOp.set(pos);
return *this;
}
template <typename T, class Op, class IndexType>
auto Contraction<T,Op,IndexType>::rootSteps(std::intptr_t iPtrNum) const
-> decltype(mOp.rootSteps(iPtrNum))
@ -908,19 +998,30 @@ namespace MultiArrayTools
**************************/
template <typename T, class Op, class... Indices>
SliceContraction<T,Op,Indices...>::SliceContraction(const Op& op, const std::shared_ptr<Indices>&... ind) :
SliceContraction<T,Op,Indices...>::SliceContraction(Op& op, std::shared_ptr<Indices>... ind) :
mOp(op),
mInds(ind...) {}
mCont(ind->range()...),
mTarOp(mCont,ind...) {}
// forward loop !!!!
template <typename T, class Op, class... Indices>
template <class ET>
inline const MultiArray<T,Indices...>& SliceContraction<T,Op,Indices...>::get(ET pos) const
inline const MultiArray<T,typename Indices::RangeType...>&
SliceContraction<T,Op,Indices...>::get(ET pos) const
{
mTarOp(mInds) = mOp.set(pos); // SET FUNCTION!!
mCont *= 0; // grrr
mTarOp = mOp.set(pos); // SET FUNCTION!!
return mCont;
}
template <typename T, class Op, class... Indices>
template <class ET>
inline SliceContraction<T,Op,Indices...>& SliceContraction<T,Op,Indices...>::set(ET pos)
{
mOp.set(pos);
return *this;
}
template <typename T, class Op, class... Indices>
auto SliceContraction<T,Op,Indices...>::rootSteps(std::intptr_t iPtrNum) const
-> decltype(mOp.rootSteps(iPtrNum))

View file

@ -114,6 +114,16 @@ namespace MultiArrayHelper
{
return PackNum<N-1>::mkMapOp(ma, itp, std::get<N>(itp), inds...);
}
template <size_t LAST,class OpTuple, class ETuple>
static inline void setOpPos(const OpTuple& ot, const ETuple& et)
{
typedef typename std::remove_reference<decltype(std::get<N>(ot))>::type NextOpType;
static_assert(LAST > NextOpType::SIZE, "inconsistent array positions");
static constexpr size_t NEXT = LAST - NextOpType::SIZE;
std::get<N>( ot ).set( Getter<NEXT>::template getX<ETuple>( et ) );
PackNum<N-1>::template setOpPos<NEXT>(ot, et);
}
};
template<>
@ -192,6 +202,15 @@ namespace MultiArrayHelper
return ma.exec(std::get<0>(itp), inds...);
}
template <size_t LAST,class OpTuple, class ETuple>
static inline void setOpPos(const OpTuple& ot, const ETuple& et)
{
typedef typename std::remove_reference<decltype(std::get<0>(et))>::type NextOpType;
static_assert(LAST > NextOpType::SIZE, "inconsistent array positions");
static constexpr size_t NEXT = LAST - NextOpType::SIZE;
std::get<0>( ot ).set( Getter<NEXT>::template getX<ETuple>( et ) );
}
};

View file

@ -668,7 +668,11 @@ namespace {
MultiArray<double,MRange> ma1(mr1ptr, v3);
MultiArray<double,SRange> ma2(sr2ptr, v1);
MultiArray<double,SRange> ma3(sr4ptr, v4);
MultiArray<double,SRange,SRange,SRange,SRange> ma4(sr1ptr,sr2ptr,sr3ptr,sr4ptr);
MultiArray<double,SRange> ma5(sr1ptr, v3);
MultiArray<double,SRange> ma6(sr3ptr, v2);
auto si0 = MAT::getIndex( sr1ptr );
auto si1 = MAT::getIndex( sr2ptr );
auto si2 = MAT::getIndex( sr3ptr );
auto si3 = MAT::getIndex( sr4ptr );
@ -676,6 +680,7 @@ namespace {
mi->operator()(si1,si2);
res(mi,si3) = ma1(mi) + ma2(si1) + ma3(si3);
ma4(si0,si1,si2,si3) = ma5(si0)*ma2(si1)*ma6(si2)*ma3(si3);
EXPECT_EQ( xround( res.at(mkt(mkt('1','a'),'A')) ), xround(0.353 + 2.917 + 1.470) );
EXPECT_EQ( xround( res.at(mkt(mkt('1','a'),'B')) ), xround(0.353 + 2.917 + 2.210) );
@ -691,6 +696,42 @@ namespace {
EXPECT_EQ( xround( res.at(mkt(mkt('3','a'),'B')) ), xround(9.243 + 0.373 + 2.210) );
EXPECT_EQ( xround( res.at(mkt(mkt('3','b'),'A')) ), xround(2.911 + 0.373 + 1.470) );
EXPECT_EQ( xround( res.at(mkt(mkt('3','b'),'B')) ), xround(2.911 + 0.373 + 2.210) );
MultiArray<MultiArray<double,SRange,SRange>,SRange,SRange> ma7(sr2ptr,sr4ptr);
ma7(si1,si3) = ma4(si0,si1,si2,si3).slc(si0,si2);
si1->at('1');
si3->at('A');
Slice<double,SRange,SRange> sl(sr1ptr,sr3ptr);
sl.define(si2,si3) = ma4(si0,si1,si2,si3);
EXPECT_EQ( xround( ma7.at(mkt('1','A')).at(mkt('x','a')) ),
xround( ma4.at(mkt('x','1','a','A')) ) );
EXPECT_EQ( xround( ma7.at(mkt('2','A')).at(mkt('x','a')) ),
xround( ma4.at(mkt('x','2','a','A')) ) );
EXPECT_EQ( xround( ma7.at(mkt('3','A')).at(mkt('x','a')) ),
xround( ma4.at(mkt('x','3','a','A')) ) );
EXPECT_EQ( xround( ma7.at(mkt('1','A')).at(mkt('x','b')) ),
xround( ma4.at(mkt('x','1','b','A')) ) );
EXPECT_EQ( xround( ma7.at(mkt('2','A')).at(mkt('x','b')) ),
xround( ma4.at(mkt('x','2','b','A')) ) );
EXPECT_EQ( xround( ma7.at(mkt('3','A')).at(mkt('x','b')) ),
xround( ma4.at(mkt('x','3','b','A')) ) );
EXPECT_EQ( xround( ma7.at(mkt('1','A')).at(mkt('l','b')) ),
xround( ma4.at(mkt('l','1','b','A')) ) );
EXPECT_EQ( xround( ma7.at(mkt('2','A')).at(mkt('l','b')) ),
xround( ma4.at(mkt('l','2','b','A')) ) );
EXPECT_EQ( xround( ma7.at(mkt('3','A')).at(mkt('l','b')) ),
xround( ma4.at(mkt('l','3','b','A')) ) );
EXPECT_EQ( xround( ma7.at(mkt('1','B')).at(mkt('l','b')) ),
xround( ma4.at(mkt('l','1','b','B')) ) );
EXPECT_EQ( xround( ma7.at(mkt('2','B')).at(mkt('l','b')) ),
xround( ma4.at(mkt('l','2','b','B')) ) );
EXPECT_EQ( xround( ma7.at(mkt('3','B')).at(mkt('l','b')) ),
xround( ma4.at(mkt('l','3','b','B')) ) );
}
/*
TEST_F(OpTest_MDim, ExecAnonOp1)