diff --git a/src/opt/mpi/include/rarray.cc.h b/src/opt/mpi/include/rarray.cc.h index d404198..d5bfe2a 100644 --- a/src/opt/mpi/include/rarray.cc.h +++ b/src/opt/mpi/include/rarray.cc.h @@ -393,8 +393,8 @@ namespace CNORXZ | non-member functions | +============================*/ - template - void setupBuffer(const Sptr>& rgi, const Sptr>& rgj, + template + void setupBuffer(const Sptr& lpi, const Sptr>& rgj, const Sptr>& imap, const CArrayBase& data, Vector& buf, Vector& map, const SizeT blocks) { @@ -408,18 +408,32 @@ namespace CNORXZ sb.reserve(data.size()); } Vector> request(Nranks); - const SizeT locsz = rgi->local()->lmax().val(); + const SizeT locsz = rgj->local()->lmax().val(); - // First loop: setup send buffer - rgi->ifor( mapXpr(rgj, rgi, imap, + // ==== new ==== + + // First loop: Find out what's needed + Vector required(rgj->lmax().val(),false); + lpi->ifor( mapXpr(rgj, lpi, imap, operation - ( [&](SizeT p, SizeT q) { - const SizeT r = p / locsz; + ( [&](SizeT j) { + const SizeT r = j / locsz; if(myrank != r){ - request[r].push_back(p % locsz); + required[j] = true; } - } , posop(rgj), posop(rgi) ) ) , + } , posop(rgj) ) ), NoF {} )(); + + // Second loop: setup send buffer + auto mi = mindexPtr(rgj->rankI(), rgj->local()); + mi->ifor( operation + ( [&](SizeT p) { + const SizeT r = p / locsz; + if(myrank != r and required[p]){ + request[r].push_back(p % locsz); + } + } , posop(mi) ) , + NoF {} )(); // transfer: Vector reqsizes(Nranks); @@ -470,22 +484,24 @@ namespace CNORXZ } - // Second loop: Assign map to target buffer positions: + // Third loop: Assign map to target buffer positions: Vector cnt(Nranks); - rgi->ifor( mapXpr(rgj, rgi, imap, - operation - ( [&](SizeT p, SizeT q) { - const SizeT r = p / locsz; - if(myrank != r){ - SizeT off = 0; - for(SizeT s = 0; s != r; ++s){ - off += ext[myrank][s]; - } - map[p] = buf.data() + off*blocks + cnt[r]*blocks; - ++cnt[r]; - } - map[q + myrank*locsz] = data.data() + q*blocks; - } , posop(rgj), posop(rgi) ) ), NoF {} )(); + mi->ifor( operation + ( [&](SizeT p) { + const SizeT r = p / locsz; + const SizeT l = p % locsz; + if(myrank != r and required[p]){ + SizeT off = 0; + for(SizeT s = 0; s != r; ++s){ + off += ext[myrank][s]; + } + map[p] = buf.data() + off*blocks + cnt[r]*blocks; + ++cnt[r]; + } + if(myrank == r){ + map[p] = data.data() + l*blocks; + } + } , posop(mi) ), NoF {} )(); } diff --git a/src/opt/mpi/include/rarray.h b/src/opt/mpi/include/rarray.h index 0f2013e..2fc7203 100644 --- a/src/opt/mpi/include/rarray.h +++ b/src/opt/mpi/include/rarray.h @@ -244,8 +244,8 @@ namespace CNORXZ }; - template - void setupBuffer(const Sptr>& rgj, const Sptr>& rgi, + template + void setupBuffer(const Sptr& lpi, const Sptr>& rgj, const Sptr>& imap, const CArrayBase& data, Vector& buf, Vector& map, const SizeT blocks); diff --git a/src/opt/mpi/include/rrange.cc.h b/src/opt/mpi/include/rrange.cc.h index 277a471..878034a 100644 --- a/src/opt/mpi/include/rrange.cc.h +++ b/src/opt/mpi/include/rrange.cc.h @@ -376,6 +376,11 @@ namespace CNORXZ return mI; } + template + Sptr RIndex::rankI() const + { + return mK; + } /*=====================+ | RRangeFactory | diff --git a/src/opt/mpi/include/rrange.h b/src/opt/mpi/include/rrange.h index df92ef4..f52e40f 100644 --- a/src/opt/mpi/include/rrange.h +++ b/src/opt/mpi/include/rrange.h @@ -155,7 +155,10 @@ namespace CNORXZ /** Get the local index on THIS rank. */ Sptr local() const; - + + /** Get index indicating the current rank this index points to. */ + Sptr rankI() const; + private: SizeT mLex = 0; Sptr mRange; /**< RRange. */ diff --git a/src/opt/mpi/tests/setbuf_unit_test.cc b/src/opt/mpi/tests/setbuf_unit_test.cc index 3eafab1..b8fc327 100644 --- a/src/opt/mpi/tests/setbuf_unit_test.cc +++ b/src/opt/mpi/tests/setbuf_unit_test.cc @@ -114,7 +114,7 @@ namespace setupBuffer(rgj, rgi, fmap, data, buf, map, mSRange->size()); EXPECT_EQ(mRRange->sub(1)->size(), 16*12*12*12/4); - // Third loop: Check: + // Fourth loop: Check: for(*rgi = 0, gi = 0; rgi->lex() != rgi->lmax().val(); ++*rgi, ++gi){ gj = gi.lex(); *gj.pack()[C0] = (gj.pack()[C0]->lex() + 1) % gj.pack()[C0]->lmax().val();