From 6edd39de372864d954cf3fef0b8f54d90619e393 Mon Sep 17 00:00:00 2001 From: Christian Zimmermann Date: Tue, 12 Mar 2024 00:44:23 +0100 Subject: [PATCH] WIP: RRange --- src/opt/mpi/include/rrange.cc.h | 149 +++++++++++++++++++++++++------- src/opt/mpi/include/rrange.h | 56 ++++++++---- 2 files changed, 157 insertions(+), 48 deletions(-) diff --git a/src/opt/mpi/include/rrange.cc.h b/src/opt/mpi/include/rrange.cc.h index d3152fa..e2663d4 100644 --- a/src/opt/mpi/include/rrange.cc.h +++ b/src/opt/mpi/include/rrange.cc.h @@ -19,6 +19,10 @@ namespace CNORXZ namespace mpi { + /*==============+ + | RIndex | + +==============*/ + template RIndex::RIndex(const RIndex& in) : mRange(in.mRange), @@ -50,193 +54,274 @@ namespace CNORXZ template RIndex::RIndex(const Sptr& local) : { + CXZ_ERROR("not implemented"); //!!! } template RIndex& RIndex::operator=(SizeT pos) { + IB::mPos = pos; // = lex + if(lexpos >= lmax().val()){ + IB::mPos = pmax().val(); + return *this; + } // pos is the lexicographic position of the global range. // Hence, have to consider the rank geometry. - if constexpr(has_static_sub::value or has_static_sub::value){ - + const auto& i = mI->pack(); + const auto& k = mK->pack(); + const auto& ilf = mI->lexFormat(); + const auto& klf = mK->lexFormat(); + SizeT r = 0; + SizeT l = 0; + if constexpr(has_static_sub::value){ + constexpr SizeT NI = index_dim::value; + iter<0,NI>( [&](auto mu) { + const SizeT jmu = (IB::mPos / ilf[mu].val()*klf[mu].val()) % + i[mu]->lmax().val()*k[mu]->lmax().val(); + r += ( jmu / i[mu]->lmax().val() ) * klf[mu].val(); + l += ( jmu % i[mu]->lmax().val() ) * ilf[mu].val(); + }, NoF{} ); + } + else if constexpr( has_static_sub::value){ + constexpr SizeT NI = index_dim::value; + iter<0,NI>( [&](auto mu) { + const SizeT jmu = (IB::mPos / ilf[mu].val()*klf[mu].val()) % + i[mu]->lmax().val()*k[mu]->lmax().val(); + r += ( jmu / i[mu]->lmax().val() ) * klf[mu].val(); + l += ( jmu % i[mu]->lmax().val() ) * ilf[mu].val(); + }, NoF{} ); } else { - + const SizeT NI = mI->dim(); + for(SizeT mu = 0; mu != NI; ++mu){ + const SizeT jmu = (IB::mPos / ilf[mu].val()*klf[mu].val()) % + i[mu]->lmax().val()*k[mu]->lmax().val(); + r += ( jmu / i[mu]->lmax().val() ) * klf[mu].val(); + l += ( jmu % i[mu]->lmax().val() ) * ilf[mu].val(); + } } + *mI = l; + *mK = r; return *this; } template RIndex& RIndex::operator++() { - + *this = lex() + 1; // room for optimization return *this; } template RIndex& RIndex::operator--() { - + *this = lex() - 1; // room for optimization return *this; } template RIndex RIndex::operator+(Int n) const { - + RIndex o(*this); + return o += n; } template RIndex RIndex::operator-(Int n) const { - + RIndex o(*this); + return o -= n; } template SizeT RIndex::operator-(const RIndex& i) const { - + return lex() - i.lex(); } template RIndex& RIndex::operator+=(Int n) { - + *this = lex() + n; return *this; } template RIndex& RIndex::operator-=(Int n) { - + *this = lex() - n; return *this; } template SizeT RIndex::lex() const { - + return IB::mPos; } template constexpr RIndex::decltype(auto) pmax() const { - + return mK->lmax().val() * mI->lmax().val(); } - template constexpr RIndex::decltype(auto) lmax() const { - + return mK->lmax().val() * mI->lmax().val(); } template IndexId<0> RIndex::id() const { - + return IndexId<0>(this->ptrId()); } template MetaType RIndex::operator*() const { - + return meta(); } template constexpr SizeT RIndex::dim() const { - + return mI->dim(); } template Sptr RIndex::range() const { - + return mRange; } template template decltype(auto) RIndex::stepSize(const IndexId& id) const { - + return mK->stepSize(id) * mI->lmax().val() + mI->stepSize(id); } template String RIndex::stringMeta() const { - + const SizeT r = mK->lex(); + String o; + broadcast(r, mI->stringMeta(), &o); + return o; } template MetaType RIndex::meta() const { - + const SizeT r = mK->lex(); + MetaType o; + broadcast(r, mI->meta(), &o); + return o; } template RIndex& RIndex::at(const MetaType& metaPos) { - + CXZ_ERROR("not implemented"); + return *this; } template RangePtr RIndex::prange(const RIndex& last) const { - + CXZ_ERROR("not implemented"); + return nullptr; } template auto RIndex::deepFormat() const { - + return concat( mul(mK->deepFormat(), mI->lmax().val() ), mI->deepFormat() ); } template auto RIndex::deepMax() const { - + return concat( mK->deepMax(), mI->deepMax() ); } template RIndex& RIndex::reformat(const Vector& f, const Vector& s) { - + CXZ_ERROR("not implemented"); + return *this; } template template constexpr decltype(auto) RIndex::ifor(const Xpr& xpr, F&& f) const { - + CXZ_ERROR("not implemented"); + return 0; } template bool RIndex::formatIsTrivial() const { - + return mI->formatIsTrivial(); } template decltype(auto) RIndex::xpr(const Sptr>& _this) const { - + CXZ_ERROR("not implemented"); + return 0; } template - int RIndex::rank() const + SizeT RIndex::rank() const { - + return mK->lex(); } template Sptr RIndex::local() const { - + return mI; } + /*=====================+ + | RRangeFactory | + +=====================*/ + + template + RRangeFactory::RRangeFactory(const Sptr& ri, + const Sptr& rk): + mRI(ri), + mRK(rk) + { + if constexpr(has_static_sub::value and + has_static_sub::value) { + static_assert(typename RangeI::NR == typename RangeK::NR, + "ranges have to be of same dimension"); + } + else { + CXZ_ASSERT(ri->dim() == rk->dim(), "ranges have to be of same dimension, got " + << ri->dim() << " and " << rk->dim()); + } + } + + template + void RRangeFactory::make() + { + Vector key = { mRI->key(), mRK->key() }; + const auto& info = typeid(RRange); + mProd = this->fromCreated(info, key); + if(mProd == nullptr) { + mProd = std::shared_ptr> + ( new RRange(mRI, mRK) ); + this->addToCreated(info, key, mProd); + } + } + } // namespace mpi } // namespace CNORXZ diff --git a/src/opt/mpi/include/rrange.h b/src/opt/mpi/include/rrange.h index 7b0caf0..8503ccf 100644 --- a/src/opt/mpi/include/rrange.h +++ b/src/opt/mpi/include/rrange.h @@ -26,12 +26,12 @@ namespace CNORXZ @tparam IndexK Index type used to indicate the rank. */ template - class RIndex : public IndexInterface,typename Index::MetaType> + class RIndex : public IndexInterface,typename IndexI::MetaType> { public: - typedef IndexInterface,typename Index::MetaType> IB; - typedef typename Index::MetaType MetaType; - typedef RRange RangeType; + typedef IndexInterface,typename IndexI::MetaType> IB; + typedef typename IndexI::MetaType MetaType; + typedef RRange RangeType; INDEX_RANDOM_ACCESS_ITERATOR_DEFS(MetaType); @@ -141,7 +141,7 @@ namespace CNORXZ decltype(auto) xpr(const Sptr>& _this) const; /** Get the current rank. */ - int rank() const; + SizeT rank() const; /** Get the local index on THIS rank. */ Sptr local() const; @@ -149,24 +149,48 @@ namespace CNORXZ private: Sptr mRange; /**< RRange. */ - Sptr mJ; /**< Index on the local range of the THIS rank. */ + Sptr mI; /**< Index on the local range of the THIS rank. */ Sptr mK; /**< Multi-index indicating the current rank. */ //!!! }; - // Factory!!! + // Traits!!! + + /** **** + Specific factory for RRange. + @tparam RangeI Local range type. + @tparam RangeK Geometry range type. + */ + template + class RRangeFactory : public RangeFactoryBase + { + public: + /** Construct and setup factory. + @param ri Local range. + @param rk Geometry range. + */ + RRangeFactory(const Sptr& ri, const Sptr& rk); + + private: + RRangeFactory() = default; + virtual void make() override final; + + Sptr mRI; + Sptr mRK; + }; /** **** Range-Wrapper for ranges that are distributed on MPI ranks. - @tparam Range Local range type. + @tparam RangeI Local range type. + @tparam RangeK Geometry range type. */ - template - class RRange : public RangeInterface> + template + class RRange : public RangeInterface> { public: typedef RangeBase RB; - typedef RIndex IndexType; - typedef typename Range::MetaType MetaType; + typedef RIndex IndexType; + typedef typename RangeI::MetaType MetaType; friend RRangeFactory; @@ -180,10 +204,10 @@ namespace CNORXZ virtual RangePtr extend(const RangePtr& r) const override final; /** Get local range. */ - Sptr local() const; + Sptr local() const; /** Get range of the rank geometry. */ - Sptr geom() const; + Sptr geom() const; /** Get meta data for given lexicographic position. @param pos Lexicographic position. @@ -214,8 +238,8 @@ namespace CNORXZ */ RRange(const Sptr& loc, const Sptr& geom); - Sptr mLocal; /**< Local range of THIS rank. */ - Sptr mGeom; /**< Rank geometry range. */ + Sptr mLocal; /**< Local range of THIS rank. */ + Sptr mGeom; /**< Rank geometry range. */ };