mpi rarray: improve setupBuffer

This commit is contained in:
Christian Zimmermann 2024-04-28 01:29:37 +02:00
parent 1d77098670
commit af32689c00
5 changed files with 52 additions and 28 deletions

View file

@ -393,8 +393,8 @@ namespace CNORXZ
| non-member functions | | non-member functions |
+============================*/ +============================*/
template <class TarI, class RTarI, class SrcI, class RSrcI, typename T> template <class LoopI, class SrcI, class RSrcI, typename T>
void setupBuffer(const Sptr<RIndex<TarI,RTarI>>& rgi, const Sptr<RIndex<SrcI,RSrcI>>& rgj, void setupBuffer(const Sptr<LoopI>& lpi, const Sptr<RIndex<SrcI,RSrcI>>& rgj,
const Sptr<Vector<SizeT>>& imap, const CArrayBase<T>& data, const Sptr<Vector<SizeT>>& imap, const CArrayBase<T>& data,
Vector<T>& buf, Vector<const T*>& map, const SizeT blocks) Vector<T>& buf, Vector<const T*>& map, const SizeT blocks)
{ {
@ -408,19 +408,33 @@ namespace CNORXZ
sb.reserve(data.size()); sb.reserve(data.size());
} }
Vector<Vector<SizeT>> request(Nranks); Vector<Vector<SizeT>> request(Nranks);
const SizeT locsz = rgi->local()->lmax().val(); const SizeT locsz = rgj->local()->lmax().val();
// First loop: setup send buffer // ==== new ====
rgi->ifor( mapXpr(rgj, rgi, imap,
// First loop: Find out what's needed
Vector<bool> required(rgj->lmax().val(),false);
lpi->ifor( mapXpr(rgj, lpi, imap,
operation operation
( [&](SizeT p, SizeT q) { ( [&](SizeT j) {
const SizeT r = p / locsz; const SizeT r = j / locsz;
if(myrank != r){ if(myrank != r){
request[r].push_back(p % locsz); required[j] = true;
} }
} , posop(rgj), posop(rgi) ) ) , } , posop(rgj) ) ),
NoF {} )(); NoF {} )();
// Second loop: setup send buffer
auto mi = mindexPtr(rgj->rankI(), rgj->local());
mi->ifor( operation
( [&](SizeT p) {
const SizeT r = p / locsz;
if(myrank != r and required[p]){
request[r].push_back(p % locsz);
}
} , posop(mi) ) ,
NoF {} )();
// transfer: // transfer:
Vector<SizeT> reqsizes(Nranks); Vector<SizeT> reqsizes(Nranks);
SizeT bufsize = 0; SizeT bufsize = 0;
@ -470,22 +484,24 @@ namespace CNORXZ
} }
// Second loop: Assign map to target buffer positions: // Third loop: Assign map to target buffer positions:
Vector<SizeT> cnt(Nranks); Vector<SizeT> cnt(Nranks);
rgi->ifor( mapXpr(rgj, rgi, imap, mi->ifor( operation
operation ( [&](SizeT p) {
( [&](SizeT p, SizeT q) { const SizeT r = p / locsz;
const SizeT r = p / locsz; const SizeT l = p % locsz;
if(myrank != r){ if(myrank != r and required[p]){
SizeT off = 0; SizeT off = 0;
for(SizeT s = 0; s != r; ++s){ for(SizeT s = 0; s != r; ++s){
off += ext[myrank][s]; off += ext[myrank][s];
} }
map[p] = buf.data() + off*blocks + cnt[r]*blocks; map[p] = buf.data() + off*blocks + cnt[r]*blocks;
++cnt[r]; ++cnt[r];
} }
map[q + myrank*locsz] = data.data() + q*blocks; if(myrank == r){
} , posop(rgj), posop(rgi) ) ), NoF {} )(); map[p] = data.data() + l*blocks;
}
} , posop(mi) ), NoF {} )();
} }

View file

@ -244,8 +244,8 @@ namespace CNORXZ
}; };
template <class TarI, class RTarI, class SrcI, class RSrcI, typename T> template <class LoopI, class SrcI, class RSrcI, typename T>
void setupBuffer(const Sptr<RIndex<TarI,RTarI>>& rgj, const Sptr<RIndex<SrcI,RSrcI>>& rgi, void setupBuffer(const Sptr<LoopI>& lpi, const Sptr<RIndex<SrcI,RSrcI>>& rgj,
const Sptr<Vector<SizeT>>& imap, const CArrayBase<T>& data, const Sptr<Vector<SizeT>>& imap, const CArrayBase<T>& data,
Vector<T>& buf, Vector<const T*>& map, const SizeT blocks); Vector<T>& buf, Vector<const T*>& map, const SizeT blocks);

View file

@ -376,6 +376,11 @@ namespace CNORXZ
return mI; return mI;
} }
template <class IndexI, class IndexK>
Sptr<IndexK> RIndex<IndexI,IndexK>::rankI() const
{
return mK;
}
/*=====================+ /*=====================+
| RRangeFactory | | RRangeFactory |

View file

@ -156,6 +156,9 @@ namespace CNORXZ
/** Get the local index on THIS rank. */ /** Get the local index on THIS rank. */
Sptr<IndexI> local() const; Sptr<IndexI> local() const;
/** Get index indicating the current rank this index points to. */
Sptr<IndexK> rankI() const;
private: private:
SizeT mLex = 0; SizeT mLex = 0;
Sptr<RangeType> mRange; /**< RRange. */ Sptr<RangeType> mRange; /**< RRange. */

View file

@ -114,7 +114,7 @@ namespace
setupBuffer(rgj, rgi, fmap, data, buf, map, mSRange->size()); setupBuffer(rgj, rgi, fmap, data, buf, map, mSRange->size());
EXPECT_EQ(mRRange->sub(1)->size(), 16*12*12*12/4); EXPECT_EQ(mRRange->sub(1)->size(), 16*12*12*12/4);
// Third loop: Check: // Fourth loop: Check:
for(*rgi = 0, gi = 0; rgi->lex() != rgi->lmax().val(); ++*rgi, ++gi){ for(*rgi = 0, gi = 0; rgi->lex() != rgi->lmax().val(); ++*rgi, ++gi){
gj = gi.lex(); gj = gi.lex();
*gj.pack()[C0] = (gj.pack()[C0]->lex() + 1) % gj.pack()[C0]->lmax().val(); *gj.pack()[C0] = (gj.pack()[C0]->lex() + 1) % gj.pack()[C0]->lmax().val();