From 3a7bd9c9e2d918ed6a7f3a38c63c411aa13677c6 Mon Sep 17 00:00:00 2001 From: Christian Zimmermann Date: Mon, 30 Jan 2023 01:13:14 +0100 Subject: [PATCH] fix urange slice() + prange for and pfor expression --- src/include/ranges/prange.cc.h | 11 +++-- src/include/ranges/urange.cc.h | 4 +- src/include/xpr/for.cc.h | 88 ++++++++++++++++++++++++++++++++++ src/include/xpr/for.h | 37 ++++++++++++++ 4 files changed, 135 insertions(+), 5 deletions(-) diff --git a/src/include/ranges/prange.cc.h b/src/include/ranges/prange.cc.h index 6637919..c283682 100644 --- a/src/include/ranges/prange.cc.h +++ b/src/include/ranges/prange.cc.h @@ -153,21 +153,26 @@ namespace CNORXZ template decltype(auto) PIndex::format(const Sptr& ind) const { - /*!!!*/ + return ind; } template template decltype(auto) PIndex::slice(const Sptr& ind) const { - /*!!!*/ + if(ind != nullptr){ + if(ind->dim() != 0){ + return Sptr>(); + } + } + return std::make_shared>(*this); } template template decltype(auto) PIndex::ifor(const Xpr& xpr, F&& f) const { - return For<0,Xpr,F>(this->pmax().val(), this->id(), xpr, std::forward(f)); + return PFor<0,0,Xpr,F>(this->lmax().val(), this->id(), mOrig->id(), xpr, std::forward(f)); } template diff --git a/src/include/ranges/urange.cc.h b/src/include/ranges/urange.cc.h index 4964619..87b877d 100644 --- a/src/include/ranges/urange.cc.h +++ b/src/include/ranges/urange.cc.h @@ -158,10 +158,10 @@ namespace CNORXZ { if(ind != nullptr){ if(ind->dim() != 0) { - return Sptr(); + return Sptr>(); } } - return std::make_shared(*this); + return std::make_shared>(*this); } template diff --git a/src/include/xpr/for.cc.h b/src/include/xpr/for.cc.h index de2cdb3..90d861c 100644 --- a/src/include/xpr/for.cc.h +++ b/src/include/xpr/for.cc.h @@ -207,6 +207,94 @@ namespace CNORXZ return SFor(id, xpr, NoF {}); } + /************ + * PFor * + ************/ + + template + constexpr PFor::PFor(SizeT size, const IndexId& id1, const IndexId& id2, + const SizeT* map, const Xpr& xpr, F&& f) : + mSize(size), + mId1(id1), + mId2(id2), + mXpr(xpr), + mExt1(mXpr.rootSteps(mId1)), + mExt2(mXpr.rootSteps(mId2)), + mPart(1, map), + mF(f) + {} + + template + template + inline decltype(auto) PFor::operator()(const PosT& last) const + { + if constexpr(std::is_same::type,NoF>::value){ + for(SizeT i = 0; i != mSize; ++i){ + const auto pos1 = last + mExt1( UPos(i) ); + const auto pos2 = pos1 + mExt2( mPart( UPos(i) ) ); + mXpr(pos2); + } + } + else { + typedef typename + std::remove_reference::type OutT; + auto o = OutT(); + for(SizeT i = 0; i != mSize; ++i){ + const auto pos1 = last + mExt1( UPos(i) ); + const auto pos2 = pos1 + mExt2( mPart( UPos(i) ) ); + mF(o, mXpr(pos2)); + } + return o; + } + } + + template + inline decltype(auto) PFor::operator()() const + { + if constexpr(std::is_same::type,NoF>::value){ + for(SizeT i = 0; i != mSize; ++i){ + const auto pos1 = mExt1( UPos(i) ); + const auto pos2 = pos1 + mExt2( mPart( UPos(i) ) ); + mXpr(pos2); + } + } + else { + typedef typename std::remove_reference::type OutT; + auto o = OutT(); + for(SizeT i = 0; i != mSize; ++i){ + const auto pos1 = mExt1( UPos(i) ); + const auto pos2 = pos1 + mExt2( mPart( UPos(i) ) ); + mF(o, mXpr(pos2)); + } + return o; + } + } + + template + template + inline decltype(auto) PFor::rootSteps(const IndexId& id) const + { + return mXpr.rootSteps(id); + } + + /************************* + * PFor (non-member) * + *************************/ + + template + constexpr decltype(auto) mkPFor(SizeT size, const IndexId& id1, const IndexId& id2, + const Xpr& xpr, F&& f) + { + return PFor(size, id1, id2, xpr, std::forward(f)); + } + + template + constexpr decltype(auto) mkPFor(SizeT size, const IndexId& id1, const IndexId& id2, + const Xpr& xpr) + { + return PFor(size, id1, id2, xpr, NoF {}); + } + /************ * TFor * ************/ diff --git a/src/include/xpr/for.h b/src/include/xpr/for.h index 06709c2..1ef46fd 100644 --- a/src/include/xpr/for.h +++ b/src/include/xpr/for.h @@ -85,6 +85,43 @@ namespace CNORXZ template constexpr decltype(auto) mkSFor(const IndexId& id, const Xpr& xpr); + // partial for: + template + class PFor : public XprInterface> + { + public: + DEFAULT_MEMBERS(PFor); + + constexpr PFor(SizeT size, const IndexId& id1, const IndexId& id2, + const SizeT* map, const Xpr& xpr, F&& f); + + template + inline decltype(auto) operator()(const PosT& last) const; + + inline decltype(auto) operator()() const; + + template + inline decltype(auto) rootSteps(const IndexId& id) const; + + private: + SizeT mSize = 0; + IndexId mId1; + IndexId mId2; + Xpr mXpr; + typedef decltype(mXpr.rootSteps(mId1)) XPosT1; + typedef decltype(mXpr.rootSteps(mId2)) XPosT2; + XPosT1 mExt1; + XPosT2 mExt2; + FPos mPart; + F mF; + }; + + template + constexpr decltype(auto) mkFor(SizeT size, const IndexId& id, const Xpr& xpr, F&& f); + + template + constexpr decltype(auto) mkFor(SizeT size, const IndexId& id, const Xpr& xpr); + // multi-threading template class TFor : public XprInterface>