mpi: completely remove the rank shift when doing global operations + fixes

This commit is contained in:
Christian Zimmermann 2024-10-15 23:30:50 -07:00
parent 0ee2e0fca2
commit 31e892005b
13 changed files with 158 additions and 154 deletions

View file

@ -0,0 +1,70 @@
#ifndef __cxz_acc_xpr_cc_h__
#define __cxz_acc_xpr_cc_h__
#include "acc_xpr.h"
namespace CNORXZ
{
template <SizeT L, class Xpr, class F>
constexpr AccXpr<L,Xpr,F>::AccXpr(SizeT n, const IndexId<L>& id,
const Xpr& xpr, F&& f) :
mN(n),
mId(id),
mXpr(xpr),
mExt(mXpr.rootSteps(mId)),
mF(std::forward<F>(f))
{}
template <SizeT L, class Xpr, class F>
template <class PosT>
inline decltype(auto) AccXpr<L,Xpr,F>::operator()(const PosT& last) const
{
if constexpr(std::is_same<typename std::remove_reference<F>::type,NoF>::value){
const auto pos = last + mExt( UPos(mN) );
mXpr(pos);
return None {};
}
else {
typedef typename
std::remove_reference<decltype(mXpr(last + mExt( UPos(0) )))>::type OutT;
auto o = OutT();
const auto pos = last + mExt( UPos(mN) );
mF(o, mXpr(pos));
return o;
}
}
template <SizeT L, class Xpr, class F>
inline decltype(auto) AccXpr<L,Xpr,F>::operator()() const
{
if constexpr(std::is_same<typename std::remove_reference<F>::type,NoF>::value){
const auto pos = mExt( UPos(mN) );
mXpr(pos);
return None {};
}
else {
typedef typename
std::remove_reference<decltype(mXpr( mExt( UPos(0) )))>::type OutT;
auto o = OutT();
const auto pos = mExt( UPos(mN) );
mF(o, mXpr(pos));
return o;
}
}
template <SizeT L, class Xpr, class F>
template <SizeT I>
inline decltype(auto) AccXpr<L,Xpr,F>::rootSteps(const IndexId<I>& id) const
{
return mXpr.rootSteps(id);
}
template <SizeT L, class Xpr, class F>
constexpr decltype(auto) accxpr(SizeT n, const IndexId<L>& id, const Xpr& xpr, F&& f)
{
return AccXpr<L,Xpr,F>(n, id, xpr, std::forward<F>(f));
}
}
#endif

41
src/include/xpr/acc_xpr.h Normal file
View file

@ -0,0 +1,41 @@
// rank access expression, fix rank position to current rank
#ifndef __cxz_acc_xpr_h__
#define __cxz_acc_xpr_h__
//#include "base/base.h"
#include "xpr_base.h"
namespace CNORXZ
{
template <SizeT L, class Xpr, class F = NoF>
class AccXpr : public XprInterface<AccXpr<L,Xpr,F>>
{
public:
DEFAULT_MEMBERS(AccXpr);
constexpr AccXpr(SizeT n, const IndexId<L>& id, const Xpr& xpr, F&& f);
template <class PosT>
inline decltype(auto) operator()(const PosT& last) const;
inline decltype(auto) operator()() const;
template <SizeT I>
inline decltype(auto) rootSteps(const IndexId<I>& id) const;
private:
SizeT mN = 0;
IndexId<L> mId;
Xpr mXpr;
typedef decltype(mXpr.rootSteps(mId)) XPosT;
XPosT mExt;
F mF;
};
template <SizeT L, class Xpr, class F = NoF>
constexpr decltype(auto) accxpr(SizeT n, const IndexId<L>& id, const Xpr& xpr, F&& f);
}
#endif

View file

@ -1,73 +0,0 @@
#ifndef __cxz_racc_xpr_cc_h__
#define __cxz_racc_xpr_cc_h__
#include "racc_xpr.h"
namespace CNOXRZ
{
namespace mpi
{
template <SizeT L, class Xpr, class F = NoF>
constexpr AccXpr<L,Xpr,F>::AccXpr(SizeT n, const IndexId<L>& id,
const Xpr& xpr, F&& f) :
mN(n),
mId(id),
mXpr(xpr),
mExt(mXpr.rootSteps(mId)),
mF(std::forward<F>(f))
{}
template <SizeT L, class Xpr, class F = NoF>
template <class PosT>
inline decltype(auto) AccXpr<L,Xpr,F>::operator()(const PosT& last) const
{
if constexpr(std::is_same<typename std::remove_reference<F>::type,NoF>::value){
const auto pos = last + mExt( UPos(mN) );
mXpr(pos);
return None {};
}
else {
typedef typename
std::remove_reference<decltype(mXpr(last + mExt( UPos(0) )))>::type OutT;
auto o = OutT();
const auto pos = last + mExt( UPos(mN) );
mF(o, mXpr(pos));
return o;
}
}
template <SizeT L, class Xpr, class F = NoF>
inline decltype(auto) AccXpr<L,Xpr,F>::operator()() const
{
if constexpr(std::is_same<typename std::remove_reference<F>::type,NoF>::value){
const auto pos = mExt( UPos(mN) );
mXpr(pos);
return None {};
}
else {
typedef typename
std::remove_reference<decltype(mXpr(last + mExt( UPos(0) )))>::type OutT;
auto o = OutT();
const auto pos = mExt( UPos(mN) );
mF(o, mXpr(pos));
return o;
}
}
template <SizeT L, class Xpr, class F = NoF>
template <SizeT I>
inline decltype(auto) AccXpr<L,Xpr,F>::rootSteps(const IndexId<I>& id) const
{
return mXpr.rootSteps(id);
}
template <SizeT L, class Xpr, class F = NoF>
constexpr decltype(auto) accxpr(SizeT n, const IndexId<L>& id, const Xpr& xpr, F&& f)
{
return AccXpr<L,Xpr,F>(size, id, xpr, std::forward<F>(f));
}
}
}
#endif

View file

@ -1,44 +0,0 @@
// rank access expression, fix rank position to current rank
#ifndef __cxz_racc_xpr_h__
#define __cxz_racc_xpr_h__
#include "mpi_base.h"
namespace CNORXZ
{
namespace mpi
{
template <SizeT L, class Xpr, class F = NoF>
class AccXpr : public XprInterface<AccXpr<L,Xpr,F>>
{
public:
DEFAULT_MEMBERS(AccXpr);
constexpr AccXpr(SizeT n, const IndexId<L>& id, const Xpr& xpr, F&& f);
template <class PosT>
inline decltype(auto) operator()(const PosT& last) const;
inline decltype(auto) operator()() const;
template <SizeT I>
inline decltype(auto) rootSteps(const IndexId<I>& id) const;
private:
SizeT mN = 0;
IndexId<L> mId;
Xpr mXpr;
typedef decltype(mXpr.rootSteps(mId)) XPosT;
XPosT mExt;
F mF;
};
template <SizeT L, class Xpr, class F = NoF>
constexpr decltype(auto) accxpr(SizeT n, const IndexId<L>& id, const Xpr& xpr, F&& f);
}
}
#endif

View file

@ -16,3 +16,4 @@
#include "index_id.cc.h" #include "index_id.cc.h"
#include "func.cc.h" #include "func.cc.h"
#include "map_xpr.cc.h" #include "map_xpr.cc.h"
#include "acc_xpr.cc.h"

View file

@ -16,5 +16,6 @@
#include "index_id.h" #include "index_id.h"
#include "func.h" #include "func.h"
#include "map_xpr.h" #include "map_xpr.h"
#include "acc_xpr.h"
#include "xpr.cc.h" #include "xpr.cc.h"

View file

@ -169,6 +169,7 @@ namespace CNORXZ
inline decltype(auto) RCArray<T>::operator()(const DPack& pack) const inline decltype(auto) RCArray<T>::operator()(const DPack& pack) const
{ {
// TODO: assert that none of the indices is rank index // TODO: assert that none of the indices is rank index
CXZ_ERROR("not implemented");
return (*mA)(pack); return (*mA)(pack);
} }
@ -361,6 +362,7 @@ namespace CNORXZ
inline decltype(auto) RArray<T>::operator()(const DPack& pack) const inline decltype(auto) RArray<T>::operator()(const DPack& pack) const
{ {
// TODO: assert that none of the indices is rank index // TODO: assert that none of the indices is rank index
CXZ_ERROR("not implemented");
return (*mB)(pack); return (*mB)(pack);
} }
@ -507,14 +509,13 @@ namespace CNORXZ
} }
// Third loop: Assign map to target buffer positions: // Third loop: Assign map to target buffer positions:
const SizeT myrankoff = myrank*locsz;
assert(mapsize == Nranks*locsz); assert(mapsize == Nranks*locsz);
Vector<SizeT> cnt(Nranks); Vector<SizeT> cnt(Nranks);
mi->ifor( operation mi->ifor( operation
( [&](SizeT p) { ( [&](SizeT p) {
const SizeT r = p / locsz; const SizeT r = p / locsz;
const SizeT l = p % locsz; const SizeT l = p % locsz;
const SizeT mpidx = (p - myrankoff + mapsize) % mapsize; const SizeT mpidx = p;
if(myrank != r and required[p]){ if(myrank != r and required[p]){
SizeT off = 0; SizeT off = 0;
for(SizeT s = 0; s != r; ++s){ for(SizeT s = 0; s != r; ++s){
@ -524,7 +525,7 @@ namespace CNORXZ
++cnt[r]; ++cnt[r];
} }
if(myrank == r){ if(myrank == r){
assert(mpidx < locsz); assert(mpidx < (myrank+1)*locsz);
map[mpidx] = data.data() + l*blocks; map[mpidx] = data.data() + l*blocks;
} }
} , posop(mi) ), NoF {} )(); } , posop(mi) ), NoF {} )();

View file

@ -23,23 +23,26 @@ namespace CNORXZ
const Sptr<SrcIndex>& si, const Sptr<SrcIndex>& si,
const F& f, const Sptr<Vector<SizeT>>& m) const F& f, const Sptr<Vector<SizeT>>& m)
{ {
// This was the old shift, keep it here as comment if we want to introduce other shifts
// in order to reduce memory consumption by the maps;
// remember to invert the shift in the map xpr BEFORE calling the map!
//const SizeT locsz = tix.local()->pmax().val();
//const SizeT tarsize = locsz*mpi::getNumRanks();
//const SizeT idx = (tix.pos() - locsz*myrank + tarsize) % tarsize;
auto six = *si; auto six = *si;
auto sie = si->range()->end(); auto sie = si->range()->end();
auto tix = *ti; auto tix = *ti;
const SizeT locsz = tix.local()->pmax().val();
const SizeT tarsize = locsz*mpi::getNumRanks();
const SizeT mapsize = m->size(); const SizeT mapsize = m->size();
const SizeT myrank = mpi::getRankNumber(); const SizeT myrank = mpi::getRankNumber();
if constexpr(mpi::is_rank_index<SrcIndex>::value){ if constexpr(mpi::is_rank_index<SrcIndex>::value){
CXZ_ASSERT(mapsize == six.local()->pmax().val(), "map not well-formatted: size = " CXZ_ASSERT(mapsize == six.pmax().val(), "map not well-formatted: size = "
<< mapsize << ", expected " << six.local()->pmax().val()); << mapsize << ", expected " << six.local()->pmax().val());
for(six = 0; six != sie; ++six){ for(six = 0; six != sie; ++six){
tix.at( f(*six) ); tix.at( f(*six) );
if(six.rank() == myrank){ if(six.rank() == myrank){
//const SizeT idx = (tix.pos() - locsz*tix.rank() + tarsize) % tarsize; const SizeT idx = tix.pos();
const SizeT idx = (tix.pos() - locsz*myrank + tarsize) % tarsize; (*m)[six.pos()] = idx;
//const SizeT idx = tix.pos();
(*m)[six.local()->pos()] = idx;
} }
} }
} }
@ -48,9 +51,7 @@ namespace CNORXZ
<< mapsize << ", expected " << six.pmax().val()); << mapsize << ", expected " << six.pmax().val());
for(six = 0; six != sie; ++six){ for(six = 0; six != sie; ++six){
tix.at( f(*six) ); tix.at( f(*six) );
//const SizeT idx = (tix.pos() - locsz*tix.rank() + tarsize) % tarsize; const SizeT idx = tix.pos()
const SizeT idx = (tix.pos() - locsz*myrank + tarsize) % tarsize;
//const SizeT idx = tix.pos()
(*m)[six.pos()] = idx; (*m)[six.pos()] = idx;
} }
} }
@ -62,13 +63,7 @@ namespace CNORXZ
const Sptr<SrcIndex>& si, const Sptr<SrcIndex>& si,
const F& f) const F& f)
{ {
SizeT mapsize = 0; const SizeT mapsize = si->pmax().val();
if constexpr(mpi::is_rank_index<SrcIndex>::value){
mapsize = si->local()->lmax().val();
}
else {
mapsize = si->lmax().val();
}
auto o = std::make_shared<Vector<SizeT>>(mapsize); auto o = std::make_shared<Vector<SizeT>>(mapsize);
setup(ti,si,f,o); setup(ti,si,f,o);
return o; return o;

View file

@ -29,12 +29,16 @@ namespace CNORXZ
template <class PosT> template <class PosT>
constexpr decltype(auto) CROpRoot<T,RIndexT,IndexT>::operator()(const PosT& pos) const constexpr decltype(auto) CROpRoot<T,RIndexT,IndexT>::operator()(const PosT& pos) const
{ {
//CXZ_ASSERT(pos.val() < mRIndex->pmax().val(), pos.val() << ">=" << mRIndex->pmax().val());
//CXZ_ASSERT(mData[pos.val()] != nullptr, "data[" << pos.val() << "] == null");
//CXZ_ASSERT(pos.next().val() < mIndex->pmax().val(), pos.val() << ">=" << mIndex->pmax().val());
return (mData[pos.val()])[pos.next().val()]; return (mData[pos.val()])[pos.next().val()];
} }
template <typename T, class RIndexT, class IndexT> template <typename T, class RIndexT, class IndexT>
constexpr decltype(auto) CROpRoot<T,RIndexT,IndexT>::operator()() const constexpr decltype(auto) CROpRoot<T,RIndexT,IndexT>::operator()() const
{ {
//CXZ_ASSERT(mData[0] != nullptr, "data[" << 0 << "] == null");
return (mData[0])[0]; return (mData[0])[0];
} }
@ -61,11 +65,9 @@ namespace CNORXZ
const Sptr<IndexT>& li) : const Sptr<IndexT>& li) :
mLocal(&a.local()), mLocal(&a.local()),
mData(a.buffermap().data()), mData(a.buffermap().data()),
//mData(a.data()),
mRIndex(ri), mRIndex(ri),
mIndex(li) mIndex(li)
{ {
//CXZ_ERROR("nope");
CXZ_ASSERT(a.buffermap().size() == ri->lmax().val(), CXZ_ASSERT(a.buffermap().size() == ri->lmax().val(),
"data map not properly initialized: map size = " << a.buffermap().size() "data map not properly initialized: map size = " << a.buffermap().size()
<< ", rank index range size = " << ri->lmax().val()); << ", rank index range size = " << ri->lmax().val());
@ -75,8 +77,8 @@ namespace CNORXZ
template <class Op> template <class Op>
constexpr ROpRoot<T,RIndexT,IndexT>& ROpRoot<T,RIndexT,IndexT>::operator=(const Op& in) constexpr ROpRoot<T,RIndexT,IndexT>& ROpRoot<T,RIndexT,IndexT>::operator=(const Op& in)
{ {
(*mLocal)(mindexPtr(mRIndex->local()*mIndex)) = in; (*mLocal)(mindexPtr(mRIndex->local()*mIndex)).a
//OI::a(mIndex, [](auto& a, const auto& b) { a = b; }, in); (mindexPtr(mRIndex*mIndex),[](auto& a, const auto& b) { a = b; }, in);
return *this; return *this;
} }
@ -84,16 +86,16 @@ namespace CNORXZ
template <class Op> template <class Op>
constexpr ROpRoot<T,RIndexT,IndexT>& ROpRoot<T,RIndexT,IndexT>::operator+=(const Op& in) constexpr ROpRoot<T,RIndexT,IndexT>& ROpRoot<T,RIndexT,IndexT>::operator+=(const Op& in)
{ {
(*mLocal)(mindexPtr(mRIndex->local()*mIndex)) += in; (*mLocal)(mindexPtr(mRIndex->local()*mIndex)).a
//OI::a(mIndex, [](auto& a, const auto& b) { a += b; }, in); (mindexPtr(mRIndex*mIndex),[](auto& a, const auto& b) { a += b; }, in);
return *this; return *this;
} }
template <typename T, class RIndexT, class IndexT> template <typename T, class RIndexT, class IndexT>
constexpr ROpRoot<T,RIndexT,IndexT>& ROpRoot<T,RIndexT,IndexT>::operator=(const ROpRoot& in) constexpr ROpRoot<T,RIndexT,IndexT>& ROpRoot<T,RIndexT,IndexT>::operator=(const ROpRoot& in)
{ {
(*mLocal)(mindexPtr(mRIndex->local()*mIndex)) = in; (*mLocal)(mindexPtr(mRIndex->local()*mIndex)).a
//OI::a(mIndex, [](auto& a, const auto& b) { a = b; }, in); (mindexPtr(mRIndex*mIndex),[](auto& a, const auto& b) { a = b; }, in);
return *this; return *this;
} }
@ -101,12 +103,16 @@ namespace CNORXZ
template <class PosT> template <class PosT>
constexpr decltype(auto) ROpRoot<T,RIndexT,IndexT>::operator()(const PosT& pos) const constexpr decltype(auto) ROpRoot<T,RIndexT,IndexT>::operator()(const PosT& pos) const
{ {
//CXZ_ASSERT(pos.val() < mRIndex->pmax().val(), pos.val() << ">=" << mRIndex->pmax().val());
//CXZ_ASSERT(mData[pos.val()] != nullptr, "data[" << pos.val() << "] == null");
//CXZ_ASSERT(pos.next().val() < mIndex->pmax().val(), pos.val() << ">=" << mIndex->pmax().val());
return (mData[pos.val()])[pos.next().val()]; return (mData[pos.val()])[pos.next().val()];
} }
template <typename T, class RIndexT, class IndexT> template <typename T, class RIndexT, class IndexT>
constexpr decltype(auto) ROpRoot<T,RIndexT,IndexT>::operator()() const constexpr decltype(auto) ROpRoot<T,RIndexT,IndexT>::operator()() const
{ {
//CXZ_ASSERT(mData[0] != nullptr, "data[" << 0 << "] == null");
return (mData[0])[0]; return (mData[0])[0];
} }

View file

@ -116,6 +116,12 @@ namespace CNORXZ
static constexpr SizeT value = 2; static constexpr SizeT value = 2;
}; };
template <typename T, class RIndexT, class IndexT>
struct op_size<mpi::ROpRoot<T,RIndexT,IndexT>>
{
static constexpr SizeT value = 2;
};
} // namespace CNORXZ } // namespace CNORXZ
#endif #endif

View file

@ -231,8 +231,7 @@ namespace CNORXZ
if constexpr(I != 0){ return SPos<0> {}; } if constexpr(I != 0){ return SPos<0> {}; }
else { return UPos(id == this->id() ? 1 : 0); } else { return UPos(id == this->id() ? 1 : 0); }
}; };
return mI->stepSize(id) + own(); return mK->stepSize(id) * mI->pmax() * UPos(mRankFormat) + mI->stepSize(id) + own();
//return getRankStepSize(id);
} }
template <class IndexI, class IndexK> template <class IndexI, class IndexK>
@ -330,7 +329,8 @@ namespace CNORXZ
template <class Xpr, class F> template <class Xpr, class F>
constexpr decltype(auto) RIndex<IndexI,IndexK>::ifor(const Xpr& xpr, F&& f) const constexpr decltype(auto) RIndex<IndexI,IndexK>::ifor(const Xpr& xpr, F&& f) const
{ {
return mI->ifor(xpr, std::forward<F>(f)); return accxpr( mpi::getRankNumber(), mK->id(), mI->ifor(xpr, std::forward<F>(f)),
NoF {});
} }
template <class IndexI, class IndexK> template <class IndexI, class IndexK>

View file

@ -138,11 +138,10 @@ namespace
(std::get<2>(vec)+1)%L, (std::get<3>(vec)+1)%L); } ); (std::get<2>(vec)+1)%L, (std::get<3>(vec)+1)%L); } );
Vector<bool> req(xp->range()->size(), false); Vector<bool> req(xp->range()->size(), false);
for(const auto& r: *imap1){ for(const auto& r: *imap1){
req[(r+mpi::getRankNumber()*16*12*12*12/4)%req.size()] = true; req[r] = true;
} }
res.load(x, AB, req); // DUMMY, not used... res.load(x, AB, req); // DUMMY, not used...
mM1.load(xp, AB, req); mM1.load(xp, AB, req);
//res.rop(x*A*B) = mapXpr(xp,x,imap1, mM1(xp*A*B) - mM1(x*A*B) );
res(x*A*B) = mapXpr(xp,x,imap1, mM1(xp*A*B) - mM1(x*A*B) ); res(x*A*B) = mapXpr(xp,x,imap1, mM1(xp*A*B) - mM1(x*A*B) );
for(SizeT x0 = 0; x0 != T; ++x0) { for(SizeT x0 = 0; x0 != T; ++x0) {

View file

@ -119,9 +119,9 @@ namespace
setupBuffer(rgi, req, data, buf, map, mSRange->size()); setupBuffer(rgi, req, data, buf, map, mSRange->size());
EXPECT_EQ(mRRange->sub(1)->size(), 16*12*12*12/4); EXPECT_EQ(mRRange->sub(1)->size(), 16*12*12*12/4);
const SizeT locsz = rgj->local()->lmax().val(); //const SizeT locsz = rgj->local()->lmax().val();
const SizeT myrankoff = myrank*locsz; //const SizeT myrankoff = myrank*locsz;
const SizeT mapsize = map.size(); //const SizeT mapsize = map.size();
// Fourth loop: Check: // Fourth loop: Check:
for(*rgi = 0, gi = 0; rgi->lex() != rgi->lmax().val(); ++*rgi, ++gi){ for(*rgi = 0, gi = 0; rgi->lex() != rgi->lmax().val(); ++*rgi, ++gi){
gj = gi.lex(); gj = gi.lex();
@ -132,7 +132,8 @@ namespace
*rgj = gj.lex(); *rgj = gj.lex();
if(rgi->rank() == myrank){ if(rgi->rank() == myrank){
const SizeT mpidx = (rgj->pos() - myrankoff + mapsize) % mapsize; const SizeT mpidx = rgj->pos();
//const SizeT mpidx = (rgj->pos() - myrankoff + mapsize) % mapsize;
EXPECT_TRUE(map.data()[mpidx] != nullptr); EXPECT_TRUE(map.data()[mpidx] != nullptr);
const Double vn = *map[mpidx]/blocks; const Double vn = *map[mpidx]/blocks;