From 9bfb5f4707fc0574c013c4137ae0d2736a554a57 Mon Sep 17 00:00:00 2001 From: Christian Zimmermann Date: Mon, 5 Dec 2022 00:14:00 +0100 Subject: [PATCH] dynamic index multiplication --- src/include/ranges/index_mul.cc.h | 93 +++++++++++++++++++++++-------- src/tests/range_unit_test.cc | 4 ++ 2 files changed, 75 insertions(+), 22 deletions(-) diff --git a/src/include/ranges/index_mul.cc.h b/src/include/ranges/index_mul.cc.h index df5f16e..d7f61b1 100644 --- a/src/include/ranges/index_mul.cc.h +++ b/src/include/ranges/index_mul.cc.h @@ -84,25 +84,48 @@ namespace CNORXZ constexpr decltype(auto) operator*(const IndexInterface& a, const IndexInterface& b) { - // special operations for DIndex / YIndex - constexpr SizeT I1D = index_dim::value; - constexpr SizeT I2D = index_dim::value; - if constexpr(I1D == 1){ - if constexpr(I2D == 1){ - return MIndex(a.THIS(),b.THIS()); + if constexpr(std::is_same::value){ + if constexpr(std::is_same::value){ + return YIndex({ a.THIS().xptr(), b.THIS().xptr() }); } - else { - return MIndexMul::evalXM(a, b.THIS(), std::make_index_sequence{}); + else if constexpr(std::is_same::value){ + auto p = b.THIS().pack(); + p.insert(0, a.THIS().xptr()); + return YIndex(p); + } + } + else if constexpr(std::is_same::value){ + if constexpr(std::is_same::value){ + auto p = a.THIS().pack(); + p.push_back(b.THIS().xptr()); + return YIndex(p); + } + else if constexpr(std::is_same::value){ + auto p = a.THIS().pack(); + p.insert(p.end(), b.THIS().pack().begin(), b.THIS().pack().end()); + return YIndex(p); } } else { - if constexpr(I2D == 1){ - return MIndexMul::evalMX(a.THIS(), b, std::make_index_sequence{}); + constexpr SizeT I1D = index_dim::value; + constexpr SizeT I2D = index_dim::value; + if constexpr(I1D == 1){ + if constexpr(I2D == 1){ + return MIndex(a.THIS(),b.THIS()); + } + else { + return MIndexMul::evalXM(a, b.THIS(), std::make_index_sequence{}); + } } else { - return MIndexMul::evalMM(a.THIS(), b.THIS(), - std::make_index_sequence{}, - std::make_index_sequence{}); + if constexpr(I2D == 1){ + return MIndexMul::evalMX(a.THIS(), b, std::make_index_sequence{}); + } + else { + return MIndexMul::evalMM(a.THIS(), b.THIS(), + std::make_index_sequence{}, + std::make_index_sequence{}); + } } } } @@ -114,21 +137,47 @@ namespace CNORXZ template decltype(auto) iptrMul(const Sptr& a, const Sptr& b) { - // special operations for DIndex / YIndex - if constexpr(index_dim::value == 1){ - if constexpr(index_dim::value == 1){ - return std::make_shared>(a, b); + if constexpr(std::is_same::value){ + if constexpr(std::is_same::value){ + return std::make_shared({ a->xptr(), b->xptr() }); } - else { - return MIndexSptrMul::evalXM(a, b); + else if constexpr(std::is_same::value){ + auto p = b->pack(); + p.insert(0, a->xptr()); + return std::make_shared(p); + } + } + else if constexpr(std::is_same::value){ + if constexpr(std::is_same::value){ + auto p = a->pack(); + p.push_back(b->xptr()); + return std::make_shared(p); + } + else if constexpr(std::is_same::value){ + auto p = a->pack(); + p.insert(p.end(), b->pack().begin(), b->pack().end()); + return std::make_shared(p); } } else { - if constexpr(index_dim::value == 1){ - return MIndexSptrMul::evalMX(a, b); + constexpr SizeT I1D = index_dim::value; + constexpr SizeT I2D = index_dim::value; + if constexpr(I1D == 1){ + if constexpr(index_dim::value == 1){ + return std::make_shared>(a, b); + } + else { + return MIndexSptrMul::evalXM(a, b, std::make_index_sequence{}); + } } else { - return MIndexSptrMul::evalMM(a, b); + if constexpr(index_dim::value == 1){ + return MIndexSptrMul::evalMX(a, b, std::make_index_sequence{}); + } + else { + return MIndexSptrMul::evalMM(a, b, std::make_index_sequence{}, + std::make_index_sequence{}); + } } } } diff --git a/src/tests/range_unit_test.cc b/src/tests/range_unit_test.cc index fe068c9..74127b2 100644 --- a/src/tests/range_unit_test.cc +++ b/src/tests/range_unit_test.cc @@ -430,6 +430,10 @@ namespace for(auto ui = ur->begin(); ui != ur->end(); ++ui){ const SizeT p = ci.lex()*s1 + ui.lex(); EXPECT_EQ((ci*ui).lex(), p); + for(auto ci2 = cr->begin(); ci2 != cr->end(); ++ci2){ + const SizeT p2 = ci.lex()*s1*s2 + ui.lex()*s2 + ci2.lex(); + EXPECT_EQ((ci*ui*ci2).lex(), p2); + } } } }