diff --git a/src/include/xpr/for.cc.h b/src/include/xpr/for.cc.h index 488bc2d..1d2edf8 100644 --- a/src/include/xpr/for.cc.h +++ b/src/include/xpr/for.cc.h @@ -26,25 +26,42 @@ namespace CNORXZ template inline decltype(auto) For::operator()(const PosT& last) const { - typedef typename std::remove_reference::type OutT; - auto o = OutT(); - for(SizeT i = 0; i != mSize; ++i){ - const auto pos = last + mExt * UPos(i); - mF(o, mXpr(pos)); + if constexpr(std::is_same::value){ + for(SizeT i = 0; i != mSize; ++i){ + const auto pos = last + mExt * UPos(i); + mXpr(pos); + } + } + else { + typedef typename + std::remove_reference::type OutT; + auto o = OutT(); + for(SizeT i = 0; i != mSize; ++i){ + const auto pos = last + mExt * UPos(i); + mF(o, mXpr(pos)); + } + return o; } - return o; } template inline decltype(auto) For::operator()() const { - typedef typename std::remove_reference::type OutT; - auto o = OutT(); - for(SizeT i = 0; i != mSize; ++i){ - const auto pos = mExt * UPos(i); - mF(o, mXpr(pos)); + if constexpr(std::is_same::value){ + for(SizeT i = 0; i != mSize; ++i){ + const auto pos = mExt * UPos(i); + mXpr(pos); + } + } + else { + typedef typename std::remove_reference::type OutT; + auto o = OutT(); + for(SizeT i = 0; i != mSize; ++i){ + const auto pos = mExt * UPos(i); + mF(o, mXpr(pos)); + } + return o; } - return o; } template @@ -54,6 +71,21 @@ namespace CNORXZ return mXpr.rootSteps(id); } + /************************ + * For (non-member) * + ************************/ + + template + constexpr decltype(auto) mkFor(SizeT size, const IndexId& id, const Xpr& xpr, F&& f) + { + return For(size, id, xpr, std::forward(f)); + } + + template + constexpr decltype(auto) mkFor(SizeT size, const IndexId& id, const Xpr& xpr) + { + return For(size, id, xpr, NoF {}); + } /************ * SFor * @@ -63,7 +95,7 @@ namespace CNORXZ constexpr SFor::SFor(const IndexId& id, const Xpr& xpr, F&& f) : mId(id), mXpr(xpr), - mExt(mXpr.RootSteps(mId)), + mExt(mXpr.rootSteps(mId)), mF(f) {} @@ -71,13 +103,25 @@ namespace CNORXZ template constexpr decltype(auto) SFor::operator()(const PosT& last) const { - return exec<0>(last); + if constexpr(std::is_same::value){ + exec2<0>(last); + return; + } + else { + return exec<0>(last); + } } template constexpr decltype(auto) SFor::operator()() const { - return exec<0>(); + if constexpr(std::is_same::value){ + exec2<0>(); + return; + } + else { + return exec<0>(); + } } template @@ -93,12 +137,11 @@ namespace CNORXZ { constexpr SPos i; const auto pos = last + mExt * i; - auto o = mXpr(pos); if constexpr(I < N-1){ - return mF(o,exec(last)); + return mF(mXpr(pos),exec(last)); } else { - return o; + return mXpr(pos); } } @@ -108,15 +151,62 @@ namespace CNORXZ { constexpr SPos i; const auto pos = mExt * i; - auto o = mXpr(pos); if constexpr(I < N-1){ - return mF(o,exec()); + return mF(mXpr(pos),exec()); } else { - return o; + return mXpr(pos); } } + template + template + inline void SFor::exec2(const PosT& last) const + { + constexpr SPos i; + const auto pos = last + mExt * i; + if constexpr(I < N-1){ + mXpr(pos); + exec2(last); + } + else { + mXpr(pos); + } + return; + } + + template + template + inline void SFor::exec2() const + { + constexpr SPos i; + const auto pos = mExt * i; + if constexpr(I < N-1){ + mXpr(pos); + exec2(); + } + else { + mXpr(pos); + } + return; + } + + /************************* + * SFor (non-member) * + *************************/ + + template + constexpr decltype(auto) mkSFor(const IndexId& id, const Xpr& xpr, F&& f) + { + return SFor(id, xpr, std::forward(f)); + } + + template + constexpr decltype(auto) mkSFor(const IndexId& id, const Xpr& xpr) + { + return SFor(id, xpr, NoF {}); + } + /************ * TFor * ************/ diff --git a/src/include/xpr/for.h b/src/include/xpr/for.h index d4ea36c..06709c2 100644 --- a/src/include/xpr/for.h +++ b/src/include/xpr/for.h @@ -34,7 +34,12 @@ namespace CNORXZ F mF; }; - + template + constexpr decltype(auto) mkFor(SizeT size, const IndexId& id, const Xpr& xpr, F&& f); + + template + constexpr decltype(auto) mkFor(SizeT size, const IndexId& id, const Xpr& xpr); + // unrolled loop: template class SFor : public XprInterface> @@ -60,14 +65,25 @@ namespace CNORXZ template constexpr decltype(auto) exec() const; + template + inline void exec2(const PosT& last) const; + + template + inline void exec2() const; + IndexId mId; Xpr mXpr; - typedef decltype(mXpr.RootSteps(mId)) XPosT; + typedef decltype(mXpr.rootSteps(mId)) XPosT; XPosT mExt; F mF; }; + template + constexpr decltype(auto) mkSFor(const IndexId& id, const Xpr& xpr, F&& f); + + template + constexpr decltype(auto) mkSFor(const IndexId& id, const Xpr& xpr); // multi-threading template diff --git a/src/tests/xpr_unit_test.cc b/src/tests/xpr_unit_test.cc index bb2280d..8e72fec 100644 --- a/src/tests/xpr_unit_test.cc +++ b/src/tests/xpr_unit_test.cc @@ -34,6 +34,73 @@ namespace SPos mS2p; }; + class For_Test : public ::testing::Test + { + protected: + + class TestXpr1 + { + public: + constexpr TestXpr1(const IndexId<0>& id) : mId(id) {} + + template + inline SizeT operator()(const PosT& last) const + { + const SizeT o = 1u; + return o << last.val(); + } + + template + inline UPos rootSteps(const IndexId& id) const + { + return UPos( mId == id ? 1u : 0u ); + } + + private: + IndexId<0> mId; + }; + + class TestXpr2 + { + public: + constexpr TestXpr2(const IndexId<0>& id, SizeT size) : + mId(id), mSize(size), mCnt(size) {} + + template + inline void operator()(const PosT& last) const + { + --mCnt; + EXPECT_EQ(mCnt, mSize-last.val()-1); + } + + template + inline UPos rootSteps(const IndexId& id) const + { + return UPos( mId == id ? 1u : 0u ); + } + + private: + IndexId<0> mId; + SizeT mSize; + mutable SizeT mCnt; + }; + + static constexpr SizeT sSize = 7u; + + For_Test() + { + mSize = sSize; + mId1 = 10u; + mId2 = 11u; + mId3 = 12u; + } + + SizeT mSize; + PtrId mId1; + PtrId mId2; + PtrId mId3; + }; + TEST_F(Pos_Test, Basics) { EXPECT_EQ( mUp1.size(), 1u ); @@ -157,6 +224,37 @@ namespace EXPECT_EQ(dp5.sub().val(), mS4p.val() * mUp1.val()); } + TEST_F(For_Test, For) + { + auto loop = mkFor(mSize, IndexId<0>(mId1), TestXpr1( IndexId<0>(mId1) ), + [](auto& o, const auto& e) { o += e; }); + + const UPos rs = loop.rootSteps(IndexId<0>(mId1)); + EXPECT_EQ(rs.val(), 1u); + const UPos rs2 = loop.rootSteps(IndexId<0>(mId2)); + EXPECT_EQ(rs2.val(), 0u); + const SizeT res = loop(); + EXPECT_EQ(res, (1u << mSize) - 1u); + + auto loop2 = mkFor(mSize, IndexId<0>(mId1), TestXpr2( IndexId<0>(mId1), mSize )); + loop2(); + } + + TEST_F(For_Test, SFor) + { + auto loop = mkSFor(IndexId<0>(mId1), TestXpr1( IndexId<0>(mId1) ), + [](const auto& a, const auto& b) { return a + b; }); + + const UPos rs = loop.rootSteps(IndexId<0>(mId1)); + EXPECT_EQ(rs.val(), 1u); + const UPos rs2 = loop.rootSteps(IndexId<0>(mId2)); + EXPECT_EQ(rs2.val(), 0u); + const SizeT res = loop(); + EXPECT_EQ(res, (1u << mSize) - 1u); + + auto loop2 = mkSFor(IndexId<0>(mId1), TestXpr2( IndexId<0>(mId1), mSize )); + loop2(); + } } int main(int argc, char** argv)