dynamic index multiplication

This commit is contained in:
Christian Zimmermann 2022-12-05 00:14:00 +01:00
parent 13467b800a
commit 9bfb5f4707
2 changed files with 75 additions and 22 deletions

View file

@ -84,25 +84,48 @@ namespace CNORXZ
constexpr decltype(auto) operator*(const IndexInterface<I1,Meta1>& a, constexpr decltype(auto) operator*(const IndexInterface<I1,Meta1>& a,
const IndexInterface<I2,Meta2>& b) const IndexInterface<I2,Meta2>& b)
{ {
// special operations for DIndex / YIndex if constexpr(std::is_same<I1,DIndex>::value){
constexpr SizeT I1D = index_dim<I1>::value; if constexpr(std::is_same<I2,DIndex>::value){
constexpr SizeT I2D = index_dim<I2>::value; return YIndex({ a.THIS().xptr(), b.THIS().xptr() });
if constexpr(I1D == 1){
if constexpr(I2D == 1){
return MIndex<I1,I2>(a.THIS(),b.THIS());
} }
else { else if constexpr(std::is_same<I2,YIndex>::value){
return MIndexMul::evalXM(a, b.THIS(), std::make_index_sequence<I2D>{}); auto p = b.THIS().pack();
p.insert(0, a.THIS().xptr());
return YIndex(p);
}
}
else if constexpr(std::is_same<I1,YIndex>::value){
if constexpr(std::is_same<I2,DIndex>::value){
auto p = a.THIS().pack();
p.push_back(b.THIS().xptr());
return YIndex(p);
}
else if constexpr(std::is_same<I2,YIndex>::value){
auto p = a.THIS().pack();
p.insert(p.end(), b.THIS().pack().begin(), b.THIS().pack().end());
return YIndex(p);
} }
} }
else { else {
if constexpr(I2D == 1){ constexpr SizeT I1D = index_dim<I1>::value;
return MIndexMul::evalMX(a.THIS(), b, std::make_index_sequence<I1D>{}); constexpr SizeT I2D = index_dim<I2>::value;
if constexpr(I1D == 1){
if constexpr(I2D == 1){
return MIndex<I1,I2>(a.THIS(),b.THIS());
}
else {
return MIndexMul::evalXM(a, b.THIS(), std::make_index_sequence<I2D>{});
}
} }
else { else {
return MIndexMul::evalMM(a.THIS(), b.THIS(), if constexpr(I2D == 1){
std::make_index_sequence<I1D>{}, return MIndexMul::evalMX(a.THIS(), b, std::make_index_sequence<I1D>{});
std::make_index_sequence<I2D>{}); }
else {
return MIndexMul::evalMM(a.THIS(), b.THIS(),
std::make_index_sequence<I1D>{},
std::make_index_sequence<I2D>{});
}
} }
} }
} }
@ -114,21 +137,47 @@ namespace CNORXZ
template <class I1, class I2> template <class I1, class I2>
decltype(auto) iptrMul(const Sptr<I1>& a, const Sptr<I2>& b) decltype(auto) iptrMul(const Sptr<I1>& a, const Sptr<I2>& b)
{ {
// special operations for DIndex / YIndex if constexpr(std::is_same<I1,DIndex>::value){
if constexpr(index_dim<I1>::value == 1){ if constexpr(std::is_same<I2,DIndex>::value){
if constexpr(index_dim<I2>::value == 1){ return std::make_shared<YIndex>({ a->xptr(), b->xptr() });
return std::make_shared<MIndex<I1,I2>>(a, b);
} }
else { else if constexpr(std::is_same<I2,YIndex>::value){
return MIndexSptrMul::evalXM(a, b); auto p = b->pack();
p.insert(0, a->xptr());
return std::make_shared<YIndex>(p);
}
}
else if constexpr(std::is_same<I1,YIndex>::value){
if constexpr(std::is_same<I2,DIndex>::value){
auto p = a->pack();
p.push_back(b->xptr());
return std::make_shared<YIndex>(p);
}
else if constexpr(std::is_same<I2,YIndex>::value){
auto p = a->pack();
p.insert(p.end(), b->pack().begin(), b->pack().end());
return std::make_shared<YIndex>(p);
} }
} }
else { else {
if constexpr(index_dim<I2>::value == 1){ constexpr SizeT I1D = index_dim<I1>::value;
return MIndexSptrMul::evalMX(a, b); constexpr SizeT I2D = index_dim<I2>::value;
if constexpr(I1D == 1){
if constexpr(index_dim<I2>::value == 1){
return std::make_shared<MIndex<I1,I2>>(a, b);
}
else {
return MIndexSptrMul::evalXM(a, b, std::make_index_sequence<I1D>{});
}
} }
else { else {
return MIndexSptrMul::evalMM(a, b); if constexpr(index_dim<I2>::value == 1){
return MIndexSptrMul::evalMX(a, b, std::make_index_sequence<I2D>{});
}
else {
return MIndexSptrMul::evalMM(a, b, std::make_index_sequence<I1D>{},
std::make_index_sequence<I2D>{});
}
} }
} }
} }

View file

@ -430,6 +430,10 @@ namespace
for(auto ui = ur->begin(); ui != ur->end(); ++ui){ for(auto ui = ur->begin(); ui != ur->end(); ++ui){
const SizeT p = ci.lex()*s1 + ui.lex(); const SizeT p = ci.lex()*s1 + ui.lex();
EXPECT_EQ((ci*ui).lex(), p); EXPECT_EQ((ci*ui).lex(), p);
for(auto ci2 = cr->begin(); ci2 != cr->end(); ++ci2){
const SizeT p2 = ci.lex()*s1*s2 + ui.lex()*s2 + ci2.lex();
EXPECT_EQ((ci*ui*ci2).lex(), p2);
}
} }
} }
} }