From bd8970ae33baef51c56391e0a555ac69b8c73f6f Mon Sep 17 00:00:00 2001 From: Christian Zimmermann Date: Sat, 16 Mar 2024 17:25:37 +0100 Subject: [PATCH] URange cast for PRanges + fix mpi test --- src/include/ranges/urange.cc.h | 5 ++++- src/opt/mpi/lib/rrange.cc | 6 ++++-- src/opt/mpi/tests/rrange_unit_test.cc | 11 +++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/include/ranges/urange.cc.h b/src/include/ranges/urange.cc.h index 7b868b5..9e93c0c 100644 --- a/src/include/ranges/urange.cc.h +++ b/src/include/ranges/urange.cc.h @@ -359,7 +359,10 @@ namespace CNORXZ template static inline Sptr> transform(const RangePtr& r) { - if(r->type() == typeid(URange)){ + if(r->type() == typeid(PRange>)){ + return transform( std::dynamic_pointer_cast>>(r)->derive() ); + } + else if(r->type() == typeid(URange)){ auto rr = std::dynamic_pointer_cast>(r); Vector v(rr->size()); std::transform(rr->begin(), rr->end(), v.begin(), diff --git a/src/opt/mpi/lib/rrange.cc b/src/opt/mpi/lib/rrange.cc index 31f2c99..b3e3976 100644 --- a/src/opt/mpi/lib/rrange.cc +++ b/src/opt/mpi/lib/rrange.cc @@ -51,7 +51,9 @@ namespace CNORXZ auto jb = global->begin(); auto je = global->begin(); MArray o(geom); - o(k) = operation( [&](const SizeT x){ jb = n*x; je = n*(x+1)-1; return jb.prange(je); } , xpr(k) ); + o(k) = operation( [&](const SizeT x){ + jb = n*x; je = n*(x+1)-1; return jb.prange(je); + } , xpr(k) ); return o; } } @@ -66,7 +68,7 @@ namespace CNORXZ } } assert(o); - auto loc = rangeCast(global); + auto loc = rangeCast(o); auto geo = rangeCast(geom); RRangeFactory xx(loc, geo); return RRangeFactory(loc, geo).create(); diff --git a/src/opt/mpi/tests/rrange_unit_test.cc b/src/opt/mpi/tests/rrange_unit_test.cc index 9c470e4..dc134b8 100644 --- a/src/opt/mpi/tests/rrange_unit_test.cc +++ b/src/opt/mpi/tests/rrange_unit_test.cc @@ -34,6 +34,14 @@ namespace CXZ_ASSERT(getNumRanks() == 4, "exptected 4 ranks"); Vector xs(12); Vector ts(16); + for(SizeT i = 0; i != xs.size(); ++i){ + const Int x = static_cast(i) - static_cast(xs.size()/2); + xs[i] = x; + } + for(SizeT i = 0; i != ts.size(); ++i){ + const Int t = static_cast(i) - static_cast(ts.size()/2); + ts[i] = t; + } mXRange = URangeFactory(xs).create(); mTRange = URangeFactory(ts).create(); Vector rs { mTRange, mXRange, mXRange, mXRange }; @@ -55,6 +63,9 @@ namespace TEST_F(RRange_Test, Basics) { EXPECT_EQ(mRRange->size(), mGRange->size()); + typedef UIndex UI; + MIndex mi(mRRange->sub(1)); + EXPECT_EQ(mi.lmax().val(), mGRange->size()/mRRange->sub(0)->size()); } }