slice contraction (tested, works)
This commit is contained in:
parent
6ba140aa97
commit
b7e40ca71b
4 changed files with 187 additions and 22 deletions
|
@ -56,6 +56,10 @@ namespace MultiArrayTools
|
|||
template <typename T, class Op, class IndexType>
|
||||
class Contraction;
|
||||
|
||||
// multi_array_operation.h
|
||||
template <typename T, class Op, class... Indices>
|
||||
class SliceContraction;
|
||||
|
||||
// slice.h
|
||||
template <typename T, class... SRanges>
|
||||
class Slice;
|
||||
|
|
|
@ -58,8 +58,8 @@ namespace MultiArrayTools
|
|||
-> ConstSlice<T,typename Indices::RangeType...>;
|
||||
|
||||
template <class... Indices>
|
||||
auto slc(const std::shared_ptr<Indices>&... inds) const
|
||||
-> SliceContraction<T,typename Indices::RangeType...>;
|
||||
auto slc(const std::shared_ptr<Indices>&... inds)
|
||||
-> SliceContraction<T,OperationClass,Indices...>;
|
||||
|
||||
private:
|
||||
friend OperationClass;
|
||||
|
@ -157,6 +157,9 @@ namespace MultiArrayTools
|
|||
template <class ET>
|
||||
inline T get(ET pos) const;
|
||||
|
||||
template <class ET>
|
||||
inline ConstOperationRoot& set(ET pos);
|
||||
|
||||
MExt<void> rootSteps(std::intptr_t iPtrNum = 0) const; // nullptr for simple usage with decltype
|
||||
|
||||
template <class Expr>
|
||||
|
@ -169,6 +172,7 @@ namespace MultiArrayTools
|
|||
//MultiArrayBase<T,Ranges...> const& mArrayRef;
|
||||
const T* mDataPtr;
|
||||
mutable IndexType mIndex;
|
||||
size_t mOff = 0;
|
||||
//std::shared_ptr<MultiArrayBase<T,Ranges...> > mMaPtr;
|
||||
};
|
||||
|
||||
|
@ -192,6 +196,9 @@ namespace MultiArrayTools
|
|||
template <class ET>
|
||||
inline T get(ET pos) const;
|
||||
|
||||
template <class ET>
|
||||
inline StaticCast& set(ET pos);
|
||||
|
||||
auto rootSteps(std::intptr_t iPtrNum = 0) const
|
||||
-> decltype(mOp.rootSteps(iPtrNum));
|
||||
|
||||
|
@ -225,6 +232,9 @@ namespace MultiArrayTools
|
|||
template <class ET>
|
||||
inline value_type get(ET pos) const;
|
||||
|
||||
template <class ET>
|
||||
inline MetaOperationRoot& set(ET pos);
|
||||
|
||||
MExt<void> rootSteps(std::intptr_t iPtrNum = 0) const; // nullptr for simple usage with decltype
|
||||
|
||||
template <class Expr>
|
||||
|
@ -260,6 +270,9 @@ namespace MultiArrayTools
|
|||
template <class ET>
|
||||
inline T get(ET pos) const;
|
||||
|
||||
template <class ET>
|
||||
inline OperationRoot& set(ET pos);
|
||||
|
||||
MExt<void> rootSteps(std::intptr_t iPtrNum = 0) const; // nullptr for simple usage with decltype
|
||||
|
||||
template <class Expr>
|
||||
|
@ -276,6 +289,7 @@ namespace MultiArrayTools
|
|||
//MutableMultiArrayBase<T,Ranges...>& mArrayRef;
|
||||
T* mDataPtr;
|
||||
mutable IndexType mIndex;
|
||||
size_t mOff = 0;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
@ -294,6 +308,9 @@ namespace MultiArrayTools
|
|||
template <class ET>
|
||||
inline T get(ET pos) const;
|
||||
|
||||
template <class ET>
|
||||
inline OperationValue& set(ET pos);
|
||||
|
||||
MExt<void> rootSteps(std::intptr_t iPtrNum = 0) const; // nullptr for simple usage with decltype
|
||||
|
||||
template <class Expr>
|
||||
|
@ -368,6 +385,9 @@ namespace MultiArrayTools
|
|||
template <class ET>
|
||||
inline T get(ET pos) const;
|
||||
|
||||
template <class ET>
|
||||
inline Operation& set(ET pos);
|
||||
|
||||
auto rootSteps(std::intptr_t iPtrNum = 0) const // nullptr for simple usage with decltype
|
||||
-> decltype(PackNum<sizeof...(Ops)-1>::mkSteps(iPtrNum, mOps));
|
||||
|
||||
|
@ -435,6 +455,9 @@ namespace MultiArrayTools
|
|||
template <class ET>
|
||||
inline T get(ET pos) const;
|
||||
|
||||
template <class ET>
|
||||
inline Contraction& set(ET pos);
|
||||
|
||||
auto rootSteps(std::intptr_t iPtrNum = 0) const // nullptr for simple usage with decltype
|
||||
-> decltype(mOp.rootSteps(iPtrNum));
|
||||
|
||||
|
@ -443,30 +466,34 @@ namespace MultiArrayTools
|
|||
};
|
||||
|
||||
template <typename T, class Op, class... Indices>
|
||||
class SliceContraction : public OperationTemplate<MultiArray<T,Indices...>,
|
||||
SliceContraction<MultiArray<T,Indices...>,Op,Indices...> >
|
||||
// class SliceContraction : public OperationTemplate
|
||||
//<MultiArray<T,typename Indices::RangeType...>,
|
||||
//SliceContraction<MultiArray<T,typename Indices::RangeType...>,Op,Indices...> >
|
||||
class SliceContraction : public OperationTemplate<T,SliceContraction<T,Op,Indices...> >
|
||||
{
|
||||
public:
|
||||
typedef MultiArray<T,Indices...> value_type;
|
||||
typedef OperationTemplate<ConstSlice<T,Indices...>,
|
||||
SliceContraction<ConstSlice<T,Indices...>,Op,Indices...> > OT;
|
||||
typedef MultiArray<T,typename Indices::RangeType...> value_type;
|
||||
typedef OperationTemplate<T,SliceContraction<T,Op,Indices...> > OT;
|
||||
|
||||
static constexpr size_t SIZE = Op::SIZE;
|
||||
|
||||
private:
|
||||
|
||||
const Op& mOp;
|
||||
MultiArray<T,Indices...> mCont;
|
||||
OperationRoot<T,Indices...> mTarOp;
|
||||
Op& mOp;
|
||||
mutable MultiArray<T,typename Indices::RangeType...> mCont;
|
||||
mutable OperationRoot<T,typename Indices::RangeType...> mTarOp;
|
||||
|
||||
public:
|
||||
typedef decltype(mOp.rootSteps(0)) ETuple;
|
||||
|
||||
SliceContraction(const Op& op, const std::shared_ptr<Indices>&... ind);
|
||||
SliceContraction(Op& op, std::shared_ptr<Indices>... ind);
|
||||
|
||||
template <class ET>
|
||||
inline const value_type& get(ET pos) const;
|
||||
|
||||
template <class ET>
|
||||
inline SliceContraction& set(ET pos);
|
||||
|
||||
auto rootSteps(std::intptr_t iPtrNum = 0) const // nullptr for simple usage with decltype
|
||||
-> decltype(mOp.rootSteps(iPtrNum));
|
||||
|
||||
|
@ -542,6 +569,15 @@ namespace MultiArrayTools
|
|||
return out;
|
||||
}
|
||||
|
||||
template <typename T, class OperationClass>
|
||||
template <class... Indices>
|
||||
auto OperationBase<T,OperationClass>::slc(const std::shared_ptr<Indices>&... inds)
|
||||
-> SliceContraction<T,OperationClass,Indices...>
|
||||
{
|
||||
return SliceContraction<T,OperationClass,Indices...>
|
||||
(THIS(), inds...);
|
||||
}
|
||||
|
||||
/*****************************************
|
||||
* OperationMaster::AssignmentExpr *
|
||||
*****************************************/
|
||||
|
@ -643,10 +679,16 @@ namespace MultiArrayTools
|
|||
template <class ET>
|
||||
inline T ConstOperationRoot<T,Ranges...>::get(ET pos) const
|
||||
{
|
||||
//VCHECK(pos.val());
|
||||
//VCHECK(mDataPtr);
|
||||
//VCHECK(mDataPtr[pos.val()])
|
||||
return mDataPtr[pos.val()];
|
||||
return mDataPtr[pos.val()+mOff];
|
||||
}
|
||||
|
||||
template <typename T, class... Ranges>
|
||||
template <class ET>
|
||||
inline ConstOperationRoot<T,Ranges...>& ConstOperationRoot<T,Ranges...>::set(ET pos)
|
||||
{
|
||||
mIndex = pos.val();
|
||||
mOff = mIndex.pos();
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T, class... Ranges>
|
||||
|
@ -684,6 +726,14 @@ namespace MultiArrayTools
|
|||
return static_cast<T>( mOp.get(pos) );
|
||||
}
|
||||
|
||||
template <typename T, class Op>
|
||||
template <class ET>
|
||||
inline StaticCast<T,Op>& StaticCast<T,Op>::set(ET pos)
|
||||
{
|
||||
mOp.set(pos);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T, class Op>
|
||||
auto StaticCast<T,Op>::rootSteps(std::intptr_t iPtrNum) const
|
||||
-> decltype(mOp.rootSteps(iPtrNum))
|
||||
|
@ -719,6 +769,14 @@ namespace MultiArrayTools
|
|||
return mIndex.meta(pos.val());
|
||||
}
|
||||
|
||||
template <class... Ranges>
|
||||
template <class ET>
|
||||
inline MetaOperationRoot<Ranges...>& MetaOperationRoot<Ranges...>::set(ET pos)
|
||||
{
|
||||
mIndex = pos.val();
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class... Ranges>
|
||||
MExt<void> MetaOperationRoot<Ranges...>::rootSteps(std::intptr_t iPtrNum) const
|
||||
{
|
||||
|
@ -765,7 +823,16 @@ namespace MultiArrayTools
|
|||
template <class ET>
|
||||
inline T OperationRoot<T,Ranges...>::get(ET pos) const
|
||||
{
|
||||
return mDataPtr[pos.val()];
|
||||
return mDataPtr[pos.val()+mOff];
|
||||
}
|
||||
|
||||
template <typename T, class... Ranges>
|
||||
template <class ET>
|
||||
inline OperationRoot<T,Ranges...>& OperationRoot<T,Ranges...>::set(ET pos)
|
||||
{
|
||||
mIndex = pos.val();
|
||||
mOff = mIndex.pos();
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T, class... Ranges>
|
||||
|
@ -813,6 +880,13 @@ namespace MultiArrayTools
|
|||
return mVal;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <class ET>
|
||||
inline OperationValue<T>& OperationValue<T>::set(ET pos)
|
||||
{
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
MExt<void> OperationValue<T>::rootSteps(std::intptr_t iPtrNum) const
|
||||
{
|
||||
|
@ -856,6 +930,14 @@ namespace MultiArrayTools
|
|||
template mkOpExpr<SIZE,T,ET,OpTuple,OpFunction>(mF, pos, mOps);
|
||||
}
|
||||
|
||||
template <typename T, class OpFunction, class... Ops>
|
||||
template <class ET>
|
||||
inline Operation<T,OpFunction,Ops...>& Operation<T,OpFunction,Ops...>::set(ET pos)
|
||||
{
|
||||
PackNum<sizeof...(Ops)-1>::setOpPos(mOps,pos);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T, class OpFunction, class... Ops>
|
||||
auto Operation<T,OpFunction,Ops...>::rootSteps(std::intptr_t iPtrNum) const
|
||||
-> decltype(PackNum<sizeof...(Ops)-1>::mkSteps(iPtrNum, mOps))
|
||||
|
@ -889,6 +971,14 @@ namespace MultiArrayTools
|
|||
return mOp.template get<ET>(pos);
|
||||
}
|
||||
|
||||
template <typename T, class Op, class IndexType>
|
||||
template <class ET>
|
||||
inline Contraction<T,Op,IndexType>& Contraction<T,Op,IndexType>::set(ET pos)
|
||||
{
|
||||
mOp.set(pos);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T, class Op, class IndexType>
|
||||
auto Contraction<T,Op,IndexType>::rootSteps(std::intptr_t iPtrNum) const
|
||||
-> decltype(mOp.rootSteps(iPtrNum))
|
||||
|
@ -908,19 +998,30 @@ namespace MultiArrayTools
|
|||
**************************/
|
||||
|
||||
template <typename T, class Op, class... Indices>
|
||||
SliceContraction<T,Op,Indices...>::SliceContraction(const Op& op, const std::shared_ptr<Indices>&... ind) :
|
||||
SliceContraction<T,Op,Indices...>::SliceContraction(Op& op, std::shared_ptr<Indices>... ind) :
|
||||
mOp(op),
|
||||
mInds(ind...) {}
|
||||
mCont(ind->range()...),
|
||||
mTarOp(mCont,ind...) {}
|
||||
|
||||
// forward loop !!!!
|
||||
template <typename T, class Op, class... Indices>
|
||||
template <class ET>
|
||||
inline const MultiArray<T,Indices...>& SliceContraction<T,Op,Indices...>::get(ET pos) const
|
||||
inline const MultiArray<T,typename Indices::RangeType...>&
|
||||
SliceContraction<T,Op,Indices...>::get(ET pos) const
|
||||
{
|
||||
mTarOp(mInds) = mOp.set(pos); // SET FUNCTION!!
|
||||
mCont *= 0; // grrr
|
||||
mTarOp = mOp.set(pos); // SET FUNCTION!!
|
||||
return mCont;
|
||||
}
|
||||
|
||||
template <typename T, class Op, class... Indices>
|
||||
template <class ET>
|
||||
inline SliceContraction<T,Op,Indices...>& SliceContraction<T,Op,Indices...>::set(ET pos)
|
||||
{
|
||||
mOp.set(pos);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T, class Op, class... Indices>
|
||||
auto SliceContraction<T,Op,Indices...>::rootSteps(std::intptr_t iPtrNum) const
|
||||
-> decltype(mOp.rootSteps(iPtrNum))
|
||||
|
|
|
@ -114,6 +114,16 @@ namespace MultiArrayHelper
|
|||
{
|
||||
return PackNum<N-1>::mkMapOp(ma, itp, std::get<N>(itp), inds...);
|
||||
}
|
||||
|
||||
template <size_t LAST,class OpTuple, class ETuple>
|
||||
static inline void setOpPos(const OpTuple& ot, const ETuple& et)
|
||||
{
|
||||
typedef typename std::remove_reference<decltype(std::get<N>(ot))>::type NextOpType;
|
||||
static_assert(LAST > NextOpType::SIZE, "inconsistent array positions");
|
||||
static constexpr size_t NEXT = LAST - NextOpType::SIZE;
|
||||
std::get<N>( ot ).set( Getter<NEXT>::template getX<ETuple>( et ) );
|
||||
PackNum<N-1>::template setOpPos<NEXT>(ot, et);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
|
@ -192,6 +202,15 @@ namespace MultiArrayHelper
|
|||
return ma.exec(std::get<0>(itp), inds...);
|
||||
}
|
||||
|
||||
template <size_t LAST,class OpTuple, class ETuple>
|
||||
static inline void setOpPos(const OpTuple& ot, const ETuple& et)
|
||||
{
|
||||
typedef typename std::remove_reference<decltype(std::get<0>(et))>::type NextOpType;
|
||||
static_assert(LAST > NextOpType::SIZE, "inconsistent array positions");
|
||||
static constexpr size_t NEXT = LAST - NextOpType::SIZE;
|
||||
std::get<0>( ot ).set( Getter<NEXT>::template getX<ETuple>( et ) );
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -668,7 +668,11 @@ namespace {
|
|||
MultiArray<double,MRange> ma1(mr1ptr, v3);
|
||||
MultiArray<double,SRange> ma2(sr2ptr, v1);
|
||||
MultiArray<double,SRange> ma3(sr4ptr, v4);
|
||||
MultiArray<double,SRange,SRange,SRange,SRange> ma4(sr1ptr,sr2ptr,sr3ptr,sr4ptr);
|
||||
MultiArray<double,SRange> ma5(sr1ptr, v3);
|
||||
MultiArray<double,SRange> ma6(sr3ptr, v2);
|
||||
|
||||
auto si0 = MAT::getIndex( sr1ptr );
|
||||
auto si1 = MAT::getIndex( sr2ptr );
|
||||
auto si2 = MAT::getIndex( sr3ptr );
|
||||
auto si3 = MAT::getIndex( sr4ptr );
|
||||
|
@ -676,6 +680,7 @@ namespace {
|
|||
mi->operator()(si1,si2);
|
||||
|
||||
res(mi,si3) = ma1(mi) + ma2(si1) + ma3(si3);
|
||||
ma4(si0,si1,si2,si3) = ma5(si0)*ma2(si1)*ma6(si2)*ma3(si3);
|
||||
|
||||
EXPECT_EQ( xround( res.at(mkt(mkt('1','a'),'A')) ), xround(0.353 + 2.917 + 1.470) );
|
||||
EXPECT_EQ( xround( res.at(mkt(mkt('1','a'),'B')) ), xround(0.353 + 2.917 + 2.210) );
|
||||
|
@ -691,7 +696,43 @@ namespace {
|
|||
EXPECT_EQ( xround( res.at(mkt(mkt('3','a'),'B')) ), xround(9.243 + 0.373 + 2.210) );
|
||||
EXPECT_EQ( xround( res.at(mkt(mkt('3','b'),'A')) ), xround(2.911 + 0.373 + 1.470) );
|
||||
EXPECT_EQ( xround( res.at(mkt(mkt('3','b'),'B')) ), xround(2.911 + 0.373 + 2.210) );
|
||||
}
|
||||
|
||||
MultiArray<MultiArray<double,SRange,SRange>,SRange,SRange> ma7(sr2ptr,sr4ptr);
|
||||
ma7(si1,si3) = ma4(si0,si1,si2,si3).slc(si0,si2);
|
||||
|
||||
si1->at('1');
|
||||
si3->at('A');
|
||||
Slice<double,SRange,SRange> sl(sr1ptr,sr3ptr);
|
||||
sl.define(si2,si3) = ma4(si0,si1,si2,si3);
|
||||
|
||||
EXPECT_EQ( xround( ma7.at(mkt('1','A')).at(mkt('x','a')) ),
|
||||
xround( ma4.at(mkt('x','1','a','A')) ) );
|
||||
EXPECT_EQ( xround( ma7.at(mkt('2','A')).at(mkt('x','a')) ),
|
||||
xround( ma4.at(mkt('x','2','a','A')) ) );
|
||||
EXPECT_EQ( xround( ma7.at(mkt('3','A')).at(mkt('x','a')) ),
|
||||
xround( ma4.at(mkt('x','3','a','A')) ) );
|
||||
|
||||
EXPECT_EQ( xround( ma7.at(mkt('1','A')).at(mkt('x','b')) ),
|
||||
xround( ma4.at(mkt('x','1','b','A')) ) );
|
||||
EXPECT_EQ( xround( ma7.at(mkt('2','A')).at(mkt('x','b')) ),
|
||||
xround( ma4.at(mkt('x','2','b','A')) ) );
|
||||
EXPECT_EQ( xround( ma7.at(mkt('3','A')).at(mkt('x','b')) ),
|
||||
xround( ma4.at(mkt('x','3','b','A')) ) );
|
||||
|
||||
EXPECT_EQ( xround( ma7.at(mkt('1','A')).at(mkt('l','b')) ),
|
||||
xround( ma4.at(mkt('l','1','b','A')) ) );
|
||||
EXPECT_EQ( xround( ma7.at(mkt('2','A')).at(mkt('l','b')) ),
|
||||
xround( ma4.at(mkt('l','2','b','A')) ) );
|
||||
EXPECT_EQ( xround( ma7.at(mkt('3','A')).at(mkt('l','b')) ),
|
||||
xround( ma4.at(mkt('l','3','b','A')) ) );
|
||||
|
||||
EXPECT_EQ( xround( ma7.at(mkt('1','B')).at(mkt('l','b')) ),
|
||||
xround( ma4.at(mkt('l','1','b','B')) ) );
|
||||
EXPECT_EQ( xround( ma7.at(mkt('2','B')).at(mkt('l','b')) ),
|
||||
xround( ma4.at(mkt('l','2','b','B')) ) );
|
||||
EXPECT_EQ( xround( ma7.at(mkt('3','B')).at(mkt('l','b')) ),
|
||||
xround( ma4.at(mkt('l','3','b','B')) ) );
|
||||
}
|
||||
/*
|
||||
TEST_F(OpTest_MDim, ExecAnonOp1)
|
||||
{
|
||||
|
|
Loading…
Reference in a new issue