From 269ff69ec31b3ca27096aa390924be05e55924bd Mon Sep 17 00:00:00 2001 From: Christian Zimmermann Date: Sun, 24 Jan 2021 00:10:14 +0100 Subject: [PATCH] automatic vectorization: unit tests work again --- src/include/multi_array_operation.cc.h | 9 ++++---- src/include/multi_array_operation.h | 2 +- src/include/xfor/xfor.h | 32 +++++++++++++++++++++----- src/tests/op3_unit_test.cc | 1 + 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/src/include/multi_array_operation.cc.h b/src/include/multi_array_operation.cc.h index 64fdecd..dea6150 100644 --- a/src/include/multi_array_operation.cc.h +++ b/src/include/multi_array_operation.cc.h @@ -347,7 +347,7 @@ namespace MultiArrayTools inline const V& ConstOperationRoot::vget(ET pos) const { VCHECK(pos.val()); - return *(reinterpret_cast(mDataPtr)+pos.val()); + return *(reinterpret_cast(mDataPtr+pos.val())); } template @@ -614,7 +614,8 @@ namespace MultiArrayTools CHECK; typedef typename TarOp::value_type T; auto x = th.template asx>>(in); - if(x.divResid() == 0){ + if(x.rootSteps(x.vI()) == 1){ + //if(0){ CHECK; x(); } @@ -665,7 +666,7 @@ namespace MultiArrayTools inline V& OperationRoot::vget(ET pos) const { VCHECK(pos.val()); - return *(reinterpret_cast(mDataPtr)+pos.val()); + return *(reinterpret_cast(mDataPtr+pos.val())); } template @@ -843,7 +844,7 @@ namespace MultiArrayTools template inline V& ParallelOperationRoot::vget(ET pos) const { - return *(reinterpret_cast(mDataPtr)+pos.val()); + return *(reinterpret_cast(mDataPtr+pos.val())); } template diff --git a/src/include/multi_array_operation.h b/src/include/multi_array_operation.h index de33740..3ac063e 100644 --- a/src/include/multi_array_operation.h +++ b/src/include/multi_array_operation.h @@ -260,7 +260,7 @@ namespace MultiArrayTools static inline void f(T*& t, size_t pos, const Op& op, ExtType e) { VCHECK(pos); - VFunc::selfApply(reinterpret_cast(t)[pos],op.template vget(e)); + VFunc::selfApply(*reinterpret_cast(t+pos),op.template vget(e)); } }; diff --git a/src/include/xfor/xfor.h b/src/include/xfor/xfor.h index c9f6584..d7b7e1e 100644 --- a/src/include/xfor/xfor.h +++ b/src/include/xfor/xfor.h @@ -187,8 +187,9 @@ namespace MultiArrayHelper ExpressionBase& operator=(const ExpressionBase& in) = default; ExpressionBase& operator=(ExpressionBase&& in) = default; - virtual size_t divResid() const { return 0; } - + //virtual size_t divResid() const { return 1; } + virtual std::intptr_t vI() const { return 0; } + virtual std::shared_ptr deepCopy() const = 0; virtual void operator()(size_t mlast, DExt last) = 0; @@ -450,7 +451,16 @@ namespace MultiArrayHelper return std::make_shared>(*this); } - virtual size_t divResid() const override final { return mMax % DIV + MkVExpr::divResid(mExpr); } + //virtual size_t divResid() const override final { return mMax % DIV + MkVExpr::divResid(mExpr); } + + virtual std::intptr_t vI() const override final + { + if(mStep == 1 and LAYER == 1 and mMax % DIV == 0){ + VCHECK(LAYER); + return reinterpret_cast(mIndPtr); + } + return mExpr.vI(); + } template auto vec() const @@ -511,7 +521,15 @@ namespace MultiArrayHelper PFor(const IndexClass* indPtr, size_t step, Expr expr); - virtual size_t divResid() const override final { return mMax % DIV + MkVExpr::divResid(mExpr); } + //virtual size_t divResid() const override final { return mMax % DIV + MkVExpr::divResid(mExpr); } + virtual std::intptr_t vI() const override final + { + if(mStep == 1 and LAYER == 1 and mMax % DIV == 0){ + VCHECK(LAYER); + return reinterpret_cast(mIndPtr); + } + return mExpr.vI(); + } template auto vec() const @@ -692,7 +710,9 @@ namespace MultiArrayHelper mIndPtr(indPtr), mSPos(mIndPtr->pos()), mMax(mIndPtr->max()), mStep(step), mExpr(expr), mExt(mExpr.rootSteps( reinterpret_cast( mIndPtr ))) { - assert(mMax % DIV == 0); + //VCHECK(mMax); + //VCHECK(DIV); + //assert(mMax % DIV == 0); assert(mIndPtr != nullptr); } @@ -777,7 +797,7 @@ namespace MultiArrayHelper mIndPtr(indPtr.get()), mSPos(mIndPtr->pos()), mMax(mIndPtr->max()), mStep(step), mExpr(expr), mExt(mExpr.rootSteps( reinterpret_cast( mIndPtr ))) { - assert(mMax % DIV == 0); + //assert(mMax % DIV == 0); assert(mIndPtr != nullptr); } diff --git a/src/tests/op3_unit_test.cc b/src/tests/op3_unit_test.cc index 876f9fd..e211634 100644 --- a/src/tests/op3_unit_test.cc +++ b/src/tests/op3_unit_test.cc @@ -45,6 +45,7 @@ namespace std::clock_t begin = std::clock(); res1(delta, deltap) += ma(delta, alpha, alpha, beta, beta, gamma, gamma, deltap).c(mix); + //res1(delta, deltap) += ma(delta, alpha, alpha, beta, beta, gamma, gamma, deltap); std::clock_t end = std::clock(); std::cout << "MultiArray time: " << static_cast( end - begin ) / CLOCKS_PER_SEC << std::endl;