automatic vectorization: unit tests work again
This commit is contained in:
parent
3bacc6c1c4
commit
269ff69ec3
4 changed files with 33 additions and 11 deletions
|
@ -347,7 +347,7 @@ namespace MultiArrayTools
|
|||
inline const V& ConstOperationRoot<T,Ranges...>::vget(ET pos) const
|
||||
{
|
||||
VCHECK(pos.val());
|
||||
return *(reinterpret_cast<const V*>(mDataPtr)+pos.val());
|
||||
return *(reinterpret_cast<const V*>(mDataPtr+pos.val()));
|
||||
}
|
||||
|
||||
template <typename T, class... Ranges>
|
||||
|
@ -614,7 +614,8 @@ namespace MultiArrayTools
|
|||
CHECK;
|
||||
typedef typename TarOp::value_type T;
|
||||
auto x = th.template asx<IVAccess<T,F<T>>>(in);
|
||||
if(x.divResid() == 0){
|
||||
if(x.rootSteps(x.vI()) == 1){
|
||||
//if(0){
|
||||
CHECK;
|
||||
x();
|
||||
}
|
||||
|
@ -665,7 +666,7 @@ namespace MultiArrayTools
|
|||
inline V& OperationRoot<T,Ranges...>::vget(ET pos) const
|
||||
{
|
||||
VCHECK(pos.val());
|
||||
return *(reinterpret_cast<V*>(mDataPtr)+pos.val());
|
||||
return *(reinterpret_cast<V*>(mDataPtr+pos.val()));
|
||||
}
|
||||
|
||||
template <typename T, class... Ranges>
|
||||
|
@ -843,7 +844,7 @@ namespace MultiArrayTools
|
|||
template <typename V, class ET>
|
||||
inline V& ParallelOperationRoot<T,Ranges...>::vget(ET pos) const
|
||||
{
|
||||
return *(reinterpret_cast<const V*>(mDataPtr)+pos.val());
|
||||
return *(reinterpret_cast<V*>(mDataPtr+pos.val()));
|
||||
}
|
||||
|
||||
template <typename T, class... Ranges>
|
||||
|
|
|
@ -260,7 +260,7 @@ namespace MultiArrayTools
|
|||
static inline void f(T*& t, size_t pos, const Op& op, ExtType e)
|
||||
{
|
||||
VCHECK(pos);
|
||||
VFunc<F>::selfApply(reinterpret_cast<value_type*>(t)[pos],op.template vget<value_type>(e));
|
||||
VFunc<F>::selfApply(*reinterpret_cast<value_type*>(t+pos),op.template vget<value_type>(e));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -187,8 +187,9 @@ namespace MultiArrayHelper
|
|||
ExpressionBase& operator=(const ExpressionBase& in) = default;
|
||||
ExpressionBase& operator=(ExpressionBase&& in) = default;
|
||||
|
||||
virtual size_t divResid() const { return 0; }
|
||||
|
||||
//virtual size_t divResid() const { return 1; }
|
||||
virtual std::intptr_t vI() const { return 0; }
|
||||
|
||||
virtual std::shared_ptr<ExpressionBase> deepCopy() const = 0;
|
||||
|
||||
virtual void operator()(size_t mlast, DExt last) = 0;
|
||||
|
@ -450,7 +451,16 @@ namespace MultiArrayHelper
|
|||
return std::make_shared<For<IndexClass,Expr,FT,DIV>>(*this);
|
||||
}
|
||||
|
||||
virtual size_t divResid() const override final { return mMax % DIV + MkVExpr<LAYER>::divResid(mExpr); }
|
||||
//virtual size_t divResid() const override final { return mMax % DIV + MkVExpr<LAYER>::divResid(mExpr); }
|
||||
|
||||
virtual std::intptr_t vI() const override final
|
||||
{
|
||||
if(mStep == 1 and LAYER == 1 and mMax % DIV == 0){
|
||||
VCHECK(LAYER);
|
||||
return reinterpret_cast<std::intptr_t>(mIndPtr);
|
||||
}
|
||||
return mExpr.vI();
|
||||
}
|
||||
|
||||
template <size_t VS>
|
||||
auto vec() const
|
||||
|
@ -511,7 +521,15 @@ namespace MultiArrayHelper
|
|||
PFor(const IndexClass* indPtr,
|
||||
size_t step, Expr expr);
|
||||
|
||||
virtual size_t divResid() const override final { return mMax % DIV + MkVExpr<LAYER>::divResid(mExpr); }
|
||||
//virtual size_t divResid() const override final { return mMax % DIV + MkVExpr<LAYER>::divResid(mExpr); }
|
||||
virtual std::intptr_t vI() const override final
|
||||
{
|
||||
if(mStep == 1 and LAYER == 1 and mMax % DIV == 0){
|
||||
VCHECK(LAYER);
|
||||
return reinterpret_cast<std::intptr_t>(mIndPtr);
|
||||
}
|
||||
return mExpr.vI();
|
||||
}
|
||||
|
||||
template <size_t VS>
|
||||
auto vec() const
|
||||
|
@ -692,7 +710,9 @@ namespace MultiArrayHelper
|
|||
mIndPtr(indPtr), mSPos(mIndPtr->pos()), mMax(mIndPtr->max()), mStep(step),
|
||||
mExpr(expr), mExt(mExpr.rootSteps( reinterpret_cast<std::intptr_t>( mIndPtr )))
|
||||
{
|
||||
assert(mMax % DIV == 0);
|
||||
//VCHECK(mMax);
|
||||
//VCHECK(DIV);
|
||||
//assert(mMax % DIV == 0);
|
||||
assert(mIndPtr != nullptr);
|
||||
}
|
||||
|
||||
|
@ -777,7 +797,7 @@ namespace MultiArrayHelper
|
|||
mIndPtr(indPtr.get()), mSPos(mIndPtr->pos()), mMax(mIndPtr->max()), mStep(step),
|
||||
mExpr(expr), mExt(mExpr.rootSteps( reinterpret_cast<std::intptr_t>( mIndPtr )))
|
||||
{
|
||||
assert(mMax % DIV == 0);
|
||||
//assert(mMax % DIV == 0);
|
||||
assert(mIndPtr != nullptr);
|
||||
}
|
||||
|
||||
|
|
|
@ -45,6 +45,7 @@ namespace
|
|||
|
||||
std::clock_t begin = std::clock();
|
||||
res1(delta, deltap) += ma(delta, alpha, alpha, beta, beta, gamma, gamma, deltap).c(mix);
|
||||
//res1(delta, deltap) += ma(delta, alpha, alpha, beta, beta, gamma, gamma, deltap);
|
||||
std::clock_t end = std::clock();
|
||||
std::cout << "MultiArray time: " << static_cast<double>( end - begin ) / CLOCKS_PER_SEC
|
||||
<< std::endl;
|
||||
|
|
Loading…
Reference in a new issue