xfor: vectrization requirements

This commit is contained in:
Christian Zimmermann 2021-01-14 17:40:08 +01:00
parent 6ecbe5ff27
commit fdb1bb6833
4 changed files with 54 additions and 5 deletions

View file

@ -27,6 +27,18 @@ namespace MultiArrayTools
assert(mIndPtr != nullptr);
}
template <class Op, class Index, class Expr, SpaceType STYPE>
std::shared_ptr<ExpressionBase> OpExpr<Op,Index,Expr,STYPE>::deepCopy() const
{
return std::make_shared<OpExpr<Op,Index,Expr,STYPE>>(*this);
}
template <class Op, class Index, class Expr, SpaceType STYPE>
inline void OpExpr<Op,Index,Expr,STYPE>::operator()(size_t mlast, DExt last)
{
operator()(mlast, std::dynamic_pointer_cast<ExtT<ExtType>>(last)->ext());
}
template <class Op, class Index, class Expr, SpaceType STYPE>
inline void OpExpr<Op,Index,Expr,STYPE>::operator()(size_t mlast,
ExtType last)
@ -63,7 +75,18 @@ namespace MultiArrayTools
//return mExpr.rootSteps(iPtrNum).extend( mOp.rootSteps(iPtrNum) );
}
template <class Op, class Index, class Expr, SpaceType STYPE>
DExt OpExpr<Op,Index,Expr,STYPE>::dRootSteps(std::intptr_t iPtrNum) const
{
return std::make_shared<ExtT<ExtType>>(rootSteps(iPtrNum));
}
template <class Op, class Index, class Expr, SpaceType STYPE>
DExt OpExpr<Op,Index,Expr,STYPE>::dExtension() const
{
return std::make_shared<ExtT<ExtType>>(mExt);
}
// -> define in range_base.cc
//std::shared_ptr<RangeFactoryBase> mkMULTI(const char** dp);

View file

@ -50,7 +50,7 @@ namespace MultiArrayTools
template <class Op, class Index, class Expr, SpaceType STYPE = SpaceType::ANY>
//template <class MapF, class IndexPack, class Expr, SpaceType STYPE = SpaceType::ANY>
class OpExpr
class OpExpr : public ExpressionBase
{
public:
//typedef typename Index::OIType OIType;
@ -80,11 +80,17 @@ namespace MultiArrayTools
OpExpr(const Op& mapf, const Index* ind, size_t step, Expr ex);
virtual std::shared_ptr<ExpressionBase> deepCopy() const override final;
inline void operator()(size_t mlast, DExt last) override final;
inline void operator()(size_t mlast, ExtType last);
inline void operator()(size_t mlast = 0);
inline void operator()(size_t mlast = 0) override final;
auto rootSteps(std::intptr_t iPtrNum = 0) const -> ExtType;
virtual DExt dRootSteps(std::intptr_t iPtrNum = 0) const override final;
virtual DExt dExtension() const override final;
};
template <class OIType, class Op, SpaceType XSTYPE, class... Indices>

View file

@ -401,7 +401,7 @@ namespace MultiArrayTools
template <typename V, class ET>
inline const V& ConstOperationRoot<T,Ranges...>::vget(ET pos) const
{
return *reinterpret_cast<const V*>(mDataPtr+pos.val());
return *(reinterpret_cast<const V*>(mDataPtr)+pos.val());
}
template <typename T, class... Ranges>
@ -668,7 +668,7 @@ namespace MultiArrayTools
template <typename V, class ET>
inline V& OperationRoot<T,Ranges...>::vget(ET pos) const
{
return *reinterpret_cast<V*>(mDataPtr + pos.val());
return *(reinterpret_cast<const V*>(mDataPtr)+pos.val());
}
template <typename T, class... Ranges>
@ -815,7 +815,7 @@ namespace MultiArrayTools
template <typename V, class ET>
inline V& ParallelOperationRoot<T,Ranges...>::vget(ET pos) const
{
return *reinterpret_cast<V*>(mDataPtr+pos.val());
return *(reinterpret_cast<const V*>(mDataPtr)+pos.val());
}
template <typename T, class... Ranges>

View file

@ -173,6 +173,8 @@ namespace MultiArrayHelper
ExpressionBase& operator=(const ExpressionBase& in) = default;
ExpressionBase& operator=(ExpressionBase&& in) = default;
virtual std::intptr_t vec(size_t vs) { return 0; }
virtual std::shared_ptr<ExpressionBase> deepCopy() const = 0;
virtual void operator()(size_t mlast, DExt last) = 0;
@ -377,6 +379,15 @@ namespace MultiArrayHelper
return std::make_shared<For<IndexClass,Expr,FT>>(*this);
}
virtual std::intptr_t vec(size_t vs) override final
{
if(mStep == 1 and mMax % vs == 0){
mMax /= vs;
return reinterpret_cast<std::intptr_t>(mIndPtr);
}
return mExpr.vec(vs);
}
inline void operator()(size_t mlast, DExt last) override final;
inline void operator()(size_t mlast, ExtType last) ;
inline void operator()(size_t mlast = 0) override final;
@ -425,6 +436,15 @@ namespace MultiArrayHelper
PFor(const IndexClass* indPtr,
size_t step, Expr expr);
virtual std::intptr_t vec(size_t vs) override final
{
if(mStep == 1 and mMax % vs == 0){
mMax /= vs;
return reinterpret_cast<std::intptr_t>(mIndPtr);
}
return mExpr.vec(vs);
}
virtual std::shared_ptr<ExpressionBase> deepCopy() const override final
{
return std::make_shared<PFor<IndexClass,Expr>>(*this);