diff --git a/src/include/ranges/crange.cc.h b/src/include/ranges/crange.cc.h index c078f22..1c603d4 100644 --- a/src/include/ranges/crange.cc.h +++ b/src/include/ranges/crange.cc.h @@ -3,6 +3,7 @@ #define __cxz_crange_cc_h__ #include "crange.h" +#include "index_mul.h" namespace CNORXZ { @@ -11,6 +12,12 @@ namespace CNORXZ { return For<0,Xpr,F>(this->pmax().val(), this->id(), xpr, std::forward(f)); } + + template + decltype(auto) operator*(const Sptr& a, const Sptr& b) + { + return iptrMul(a, b); + } } #endif diff --git a/src/include/ranges/crange.h b/src/include/ranges/crange.h index b74cf6a..33baa36 100644 --- a/src/include/ranges/crange.h +++ b/src/include/ranges/crange.h @@ -49,6 +49,8 @@ namespace CNORXZ Sptr mRangePtr; }; + template + decltype(auto) operator*(const Sptr& a, const Sptr& b); class CRangeFactory : public RangeFactoryBase { diff --git a/src/include/ranges/index_mul.cc.h b/src/include/ranges/index_mul.cc.h new file mode 100644 index 0000000..3bc69b8 --- /dev/null +++ b/src/include/ranges/index_mul.cc.h @@ -0,0 +1,132 @@ + +#ifndef __cxz_index_mul_cc_h__ +#define __cxz_index_mul_cc_h__ + +#include "index_mul.h" + +namespace CNORXZ +{ + /***************** + * MIndexMul * + *****************/ + + template + constexpr decltype(auto) MIndexMul::evalMX(const GMIndex& a, + const IndexInterface& b, + Isq is) + { + static_assert(sizeof...(Is) == sizeof...(Indices), "inconsistent index sequence"); + return MIndex( std::get(a.pack())..., b.THIS() ); + } + + template + constexpr decltype(auto) MIndexMul::evalXM(const IndexInterface& a, + const GMIndex& b, + Isq js) + { + static_assert(sizeof...(Js) == sizeof...(Indices), "inconsistent index sequence"); + return MIndex( a.THIS(), std::get(b.pack())... ); + } + + template + constexpr decltype(auto) MIndexMul::evalMM(const GMIndex& a, + const GMIndex& b, + Isq is, Isq js) + { + static_assert(sizeof...(Is) == sizeof...(Indices1), "inconsistent index sequence"); + static_assert(sizeof...(Js) == sizeof...(Indices2), "inconsistent index sequence"); + return MIndex( std::get(a.pack())..., + std::get(b.pack())... ); + } + + /********************* + * MIndexSptrMul * + *********************/ + + template + decltype(auto) MIndexSptrMul::evalMX(const Sptr>& a, + const Sptr& b, Isq is) + { + static_assert(sizeof...(Is) == sizeof...(Indices), "inconsistent index sequence"); + return std::make_shared>( std::get(a->pack())..., b ); + } + + template + decltype(auto) MIndexSptrMul::evalXM(const Sptr& a, + const Sptr>& b, + Isq js) + { + static_assert(sizeof...(Js) == sizeof...(Indices), "inconsistent index sequence"); + return std::make_shared>( a, std::get(b->pack())... ); + } + + template + decltype(auto) MIndexSptrMul::evalMM(const Sptr>& a, + const Sptr>& b, + Isq is, Isq js) + { + static_assert(sizeof...(Is) == sizeof...(Indices1), "inconsistent index sequence"); + static_assert(sizeof...(Js) == sizeof...(Indices2), "inconsistent index sequence"); + return MIndex( std::get(a->pack())..., + std::get(b->pack())... ); + } + + + /***************** + * operator* * + *****************/ + + template + constexpr decltype(auto) operator*(const IndexInterface& a, + const IndexInterface& b) + { + // special operations for DIndex / YIndex + if constexpr(index_dim::value == 1){ + if constexpr(index_dim::value == 1){ + return MIndex(a.THIS(),b.THIS()); + } + else { + return MIndexMul::evalXM(a, b.THIS()); + } + } + else { + if constexpr(index_dim::value == 1){ + return MIndexMul::evalMX(a.THIS(), b); + } + else { + return MIndexMul::evalMM(a.THIS(), b.THIS()); + } + } + } + + /*************** + * iptrMul * + ***************/ + + template + decltype(auto) iptrMul(const Sptr& a, const Sptr& b) + { + // special operations for DIndex / YIndex + if constexpr(index_dim::value == 1){ + if constexpr(index_dim::value == 1){ + return std::make_shared>(a, b); + } + else { + return MIndexSptrMul::evalXM(a, b); + } + } + else { + if constexpr(index_dim::value == 1){ + return MIndexSptrMul::evalMX(a, b); + } + else { + return MIndexSptrMul::evalMM(a, b); + } + } + } + +} + +#endif diff --git a/src/include/ranges/index_mul.h b/src/include/ranges/index_mul.h new file mode 100644 index 0000000..5a53595 --- /dev/null +++ b/src/include/ranges/index_mul.h @@ -0,0 +1,54 @@ + +#ifndef __cxz_index_mul_h__ +#define __cxz_index_mul_h__ + +#include "base/base.h" +#include "index_base.h" + +namespace CNORXZ +{ + struct MIndexMul + { + template + static constexpr decltype(auto) evalMX(const GMIndex& a, + const IndexInterface& b, + Isq is); + + template + static constexpr decltype(auto) evalXM(const IndexInterface& a, + const GMIndex& b, + Isq js); + + template + static constexpr decltype(auto) evalMM(const GMIndex& a, + const GMIndex& b, + Isq is, Isq js); + }; + + struct MIndexSptrMul + { + template + static decltype(auto) evalMX(const Sptr>& a, + const Sptr& b, Isq is); + + template + static decltype(auto) evalXM(const Sptr& a, const Sptr>& b, + Isq js); + + template + static decltype(auto) evalMM(const Sptr>& a, + const Sptr>& b, + Isq is, Isq js); + }; + + template + constexpr decltype(auto) operator*(const IndexInterface& a, + const IndexInterface& b); + + template + decltype(auto) iptrMul(const Sptr& a, const Sptr& b); +} + +#endif diff --git a/src/include/ranges/mrange.cc.h b/src/include/ranges/mrange.cc.h index a79fce8..10df5fc 100644 --- a/src/include/ranges/mrange.cc.h +++ b/src/include/ranges/mrange.cc.h @@ -449,69 +449,18 @@ namespace CNORXZ return mLexBlockSizes; } + template + decltype(auto) operator*(const Sptr>& a, const Sptr>& b) + { + return iptrMul(a, b); + } + template constexpr decltype(auto) mindex(const Sptr&... is) { return MIndex(is...); } - /***************** - * MIndexMul * - *****************/ - - template - constexpr decltype(auto) MIndexMul::evalMX(const GMIndex& a, - const IndexInterface& b, - Isq is) - { - static_assert(sizeof...(Is) == sizeof...(Indices), "inconsistent index sequence"); - return MIndex( std::get(a.pack())..., b.THIS() ); - } - - template - constexpr decltype(auto) MIndexMul::evalXM(const IndexInterface& a, - const GMIndex& b, - Isq js) - { - static_assert(sizeof...(Js) == sizeof...(Indices), "inconsistent index sequence"); - return MIndex( a.THIS(), std::get(b.pack())... ); - } - - template - constexpr decltype(auto) MIndexMul::evalMM(const GMIndex& a, - const GMIndex& b, - Isq is, Isq js) - { - static_assert(sizeof...(Is) == sizeof...(Indices1), "inconsistent index sequence"); - static_assert(sizeof...(Js) == sizeof...(Indices2), "inconsistent index sequence"); - return MIndex( std::get(a.pack())..., - std::get(b.pack())... ); - } - - // move to separate file!!! - template - constexpr decltype(auto) operator*(const IndexInterface& a, - const IndexInterface& b) - { - // special operations for DIndex / YIndex - if constexpr(index_dim::value == 1){ - if constexpr(index_dim::value == 1){ - return MIndex(a.THIS(),b.THIS()); - } - else { - return MIndexMul::evalXM(a, b.THIS()); - } - } - else { - if constexpr(index_dim::value == 1){ - return MIndexMul::evalMX(a.THIS(), b); - } - else { - return MIndexMul::evalMM(a.THIS(), b.THIS()); - } - } - } - /********************* * MRangeFactory * diff --git a/src/include/ranges/mrange.h b/src/include/ranges/mrange.h index d1ba6dd..6b09ade 100644 --- a/src/include/ranges/mrange.h +++ b/src/include/ranges/mrange.h @@ -111,6 +111,9 @@ namespace CNORXZ PMaxT mPMax; }; + template + decltype(auto) operator*(const Sptr>& a, const Sptr>& b); + //template //using MIndex = GMIndex; template @@ -128,33 +131,6 @@ namespace CNORXZ template constexpr decltype(auto) mindex(const Sptr&... is); - struct MIndexMul - { - template - static constexpr decltype(auto) evalMX(const GMIndex& a, - const IndexInterface& b, - Isq is); - - template - static constexpr decltype(auto) evalXM(const IndexInterface& a, - const GMIndex& b, - Isq js); - - template - static constexpr decltype(auto) evalMM(const GMIndex& a, - const GMIndex& b, - Isq is, Isq js); - }; - - // move to separate file!!! - template - constexpr decltype(auto) operator*(const IndexInterface& a, - const IndexInterface& b); - - //template - //constexpr decltype(auto) mindex(const GMIndex& a, - // Isq is, Isq js) const; - template class MRangeFactory : public RangeFactoryBase diff --git a/src/include/ranges/ranges.cc.h b/src/include/ranges/ranges.cc.h index fc79abd..b7beae8 100644 --- a/src/include/ranges/ranges.cc.h +++ b/src/include/ranges/ranges.cc.h @@ -6,3 +6,4 @@ #include "urange.cc.h" #include "crange.cc.h" #include "dindex.cc.h" +#include "index_mul.cc.h" diff --git a/src/include/ranges/ranges.h b/src/include/ranges/ranges.h index b9a799f..ea942c3 100644 --- a/src/include/ranges/ranges.h +++ b/src/include/ranges/ranges.h @@ -2,12 +2,12 @@ #include "range_base.h" #include "index_base.h" #include "mrange.h" -//#include "range_helper.h" #include "crange.h" //#include "subrange.h" //#include "value_range.h" #include "xindex.h" #include "yrange.h" #include "dindex.h" +#include "index_mul.h" #include "ranges.cc.h" diff --git a/src/include/ranges/urange.cc.h b/src/include/ranges/urange.cc.h index ae8405f..c0923b1 100644 --- a/src/include/ranges/urange.cc.h +++ b/src/include/ranges/urange.cc.h @@ -6,6 +6,7 @@ #include #include "urange.h" +#include "index_mul.h" #include "xpr/for.h" namespace CNORXZ @@ -143,6 +144,12 @@ namespace CNORXZ return For<0,Xpr,F>(this->pmax().val(), this->id(), xpr, std::forward(f)); } + template + decltype(auto) operator*(const Sptr>& a, const Sptr& b) + { + return iptrMul(a, b); + } + /********************** * URangeFactory * **********************/ diff --git a/src/include/ranges/urange.h b/src/include/ranges/urange.h index 904d87e..5ee64fc 100644 --- a/src/include/ranges/urange.h +++ b/src/include/ranges/urange.h @@ -55,6 +55,9 @@ namespace CNORXZ const MetaT* mMetaPtr; }; + template + decltype(auto) operator*(const Sptr>& a, const Sptr& b); + template class URangeFactory : public RangeFactoryBase { diff --git a/src/tests/operation_unit_test.cc b/src/tests/operation_unit_test.cc index b37b93d..fcb348d 100644 --- a/src/tests/operation_unit_test.cc +++ b/src/tests/operation_unit_test.cc @@ -63,10 +63,8 @@ namespace mCIa2 = std::make_shared(cra); mCIb1 = std::make_shared(crb); mCIb2 = std::make_shared(crb); - mCCa1a2 = std::make_shared(mCIa1,mCIa2); - mCCa2a1 = std::make_shared(mCIa2,mCIa1); - //mCCa1a2 = mCIa1*mCIa2; - //mCCa2a1 = mCIa2*mCIa1; + mCCa1a2 = mCIa1*mCIa2; + mCCa2a1 = mCIa2*mCIa1; mOCa1a2.init(mCCa1a2); mORa2a1.init(mData12.data(), mCCa2a1); }