hl op: resolve lowest indices at runtime

This commit is contained in:
Christian Zimmermann 2020-09-11 13:12:37 +02:00
parent 65ecc27c3e
commit dcce9a5eea
5 changed files with 87 additions and 21 deletions

View file

@ -3,6 +3,7 @@
namespace MultiArrayTools namespace MultiArrayTools
{ {
template <typename T, class Op> template <typename T, class Op>
DynamicO<T> mkDynOp1(const Op& op) DynamicO<T> mkDynOp1(const Op& op)
{ {
@ -238,12 +239,56 @@ namespace MultiArrayTools
( std::array<std::shared_ptr<HighLevelOpBase<ROP>>,2>({mOp, in.mOp}) ) ); ( std::array<std::shared_ptr<HighLevelOpBase<ROP>>,2>({mOp, in.mOp}) ) );
} }
template <class ROP>
template <class... Indices>
HighLevelOpHolder<ROP>& HighLevelOpHolder<ROP>::xassign(const HighLevelOpHolder& in,
const std::shared_ptr<DynamicIndex>& di,
const std::shared_ptr<Indices>&... is)
{
const size_t dim = di->dim();
if(dim >= 2){
auto ci1 = std::dynamic_pointer_cast<IndexWrapper<CI>>(di->getP(dim-2));
auto ci2 = std::dynamic_pointer_cast<IndexWrapper<CI>>(di->getP(dim-1));
assert(ci1 != nullptr);
assert(ci2 != nullptr);
auto odi = mkSubSpaceX(di, dim-2);
auto mi = mkMIndex(is..., odi);
this->assign(in, mi, ci1->getIndex(), ci2->getIndex());
}
else {
assert(dim == 1);
auto ci1 = std::dynamic_pointer_cast<IndexWrapper<CI>>(di->getP(dim-2));
assert(ci1 != nullptr);
auto odi = mkSubSpaceX(di, dim-1);
auto mi = mkMIndex(is..., odi);
this->assign(in, mi, ci1->getIndex());
}
//INDS<ROP,Indices...>::template CallHLOp<> call;
//call.assign(*this, in, is..., di);
return *this;
}
template <class Ind1, class Ind2, class... Inds>
std::string printInd(const std::shared_ptr<Ind1>& ind1, const std::shared_ptr<Ind2>& ind2,
const std::shared_ptr<Inds>&... inds)
{
return std::to_string(reinterpret_cast<std::intptr_t>(ind1.get())) + "(" +
std::to_string(ind1->max()) + "), " + printInd(ind2, inds...);
}
template <class Ind1>
std::string printInd(const std::shared_ptr<Ind1>& ind1)
{
return std::to_string(reinterpret_cast<std::intptr_t>(ind1.get())) + "(" + std::to_string(ind1->max()) + ")";
}
template <class ROP> template <class ROP>
template <class MIndex, class... Indices> template <class MIndex, class... Indices>
HighLevelOpHolder<ROP>& HighLevelOpHolder<ROP>::assign(const HighLevelOpHolder& in, HighLevelOpHolder<ROP>& HighLevelOpHolder<ROP>::assign(const HighLevelOpHolder& in,
const std::shared_ptr<MIndex>& mi, const std::shared_ptr<MIndex>& mi,
const std::shared_ptr<Indices>&... inds) const std::shared_ptr<Indices>&... inds)
{ {
//VCHECK(printInd(inds...));
auto xx = mkArrayPtr<double>(nullr()); auto xx = mkArrayPtr<double>(nullr());
auto& opr = *mOp->get(); auto& opr = *mOp->get();
auto loop = mkPILoop auto loop = mkPILoop
@ -302,7 +347,7 @@ namespace MultiArrayTools
#undef regFunc1 #undef regFunc1
#undef SP #undef SP
/*
template <size_t N> template <size_t N>
template <class ITuple> template <class ITuple>
inline void SetLInds<N>::mkLIT(const ITuple& itp, const std::shared_ptr<DynamicIndex>& di) inline void SetLInds<N>::mkLIT(const ITuple& itp, const std::shared_ptr<DynamicIndex>& di)
@ -368,22 +413,6 @@ namespace MultiArrayTools
return mDepth; return mDepth;
} }
std::shared_ptr<DynamicIndex> mkSubSpaceX(const std::shared_ptr<DynamicIndex>& di,
size_t max)
{
auto& o = di->range()->orig();
vector<std::shared_ptr<RangeBase>> ox(o.begin(),o.begin()+max);
DynamicRangeFactory drf(ox);
auto dr = createExplicit(drf);
auto odi = getIndex(dr);
vector<std::shared_ptr<IndexW>> iv;
iv.reserve(max);
for(size_t i = 0; i != max; ++i){
iv.push_back(di->getP(i));
}
(*odi)(iv);
return odi;
}
template <class ROP, class... Indices> template <class ROP, class... Indices>
template <class... LIndices> template <class... LIndices>
@ -428,4 +457,6 @@ namespace MultiArrayTools
plus(target, source, mi, itp); plus(target, source, mi, itp);
} }
} }
*/
} }

View file

@ -16,6 +16,8 @@ namespace MultiArrayTools
template <typename T, class Op> template <typename T, class Op>
DynamicO<T> mkDynOp1(const Op& op); DynamicO<T> mkDynOp1(const Op& op);
std::shared_ptr<DynamicIndex> mkSubSpaceX(const std::shared_ptr<DynamicIndex>& di, size_t max);
template <class ROP> template <class ROP>
class HighLevelOpBase class HighLevelOpBase
{ {
@ -153,6 +155,11 @@ namespace MultiArrayTools
HighLevelOpHolder operator-(const HighLevelOpHolder& in) const; HighLevelOpHolder operator-(const HighLevelOpHolder& in) const;
HighLevelOpHolder operator/(const HighLevelOpHolder& in) const; HighLevelOpHolder operator/(const HighLevelOpHolder& in) const;
template <class... Indices>
HighLevelOpHolder& xassign(const HighLevelOpHolder& in,
const std::shared_ptr<DynamicIndex>& di,
const std::shared_ptr<Indices>&... is);
template <class MIndex, class... Indices> template <class MIndex, class... Indices>
HighLevelOpHolder& assign(const HighLevelOpHolder& in, HighLevelOpHolder& assign(const HighLevelOpHolder& in,
const std::shared_ptr<MIndex>& mi, const std::shared_ptr<MIndex>& mi,
@ -174,7 +181,7 @@ namespace MultiArrayTools
#include "extensions/math.h" #include "extensions/math.h"
#undef regFunc1 #undef regFunc1
#undef SP #undef SP
/*
template <size_t N> template <size_t N>
struct SetLInds struct SetLInds
{ {
@ -248,7 +255,7 @@ namespace MultiArrayTools
const std::shared_ptr<DynamicIndex>& di) const override final; const std::shared_ptr<DynamicIndex>& di) const override final;
}; };
}; };
*/
} }
#endif #endif

View file

@ -7,6 +7,7 @@ set(libmultiarray_a_SOURCES
${CMAKE_SOURCE_DIR}/src/lib/ranges/type_map.cc ${CMAKE_SOURCE_DIR}/src/lib/ranges/type_map.cc
${CMAKE_SOURCE_DIR}/src/lib/ranges/multi_range_factory_product_map.cc ${CMAKE_SOURCE_DIR}/src/lib/ranges/multi_range_factory_product_map.cc
${CMAKE_SOURCE_DIR}/src/lib/map_range_factory_product_map.cc ${CMAKE_SOURCE_DIR}/src/lib/map_range_factory_product_map.cc
${CMAKE_SOURCE_DIR}/src/lib/high_level_operation.cc
) )
file(GLOB cc_files "${CMAKE_SOURCE_DIR}/src/lib/ranges/range_types/*.cc") file(GLOB cc_files "${CMAKE_SOURCE_DIR}/src/lib/ranges/range_types/*.cc")

View file

@ -0,0 +1,24 @@
#include "multi_array_header.h"
#include "high_level_operation.h"
namespace MultiArrayTools
{
std::shared_ptr<DynamicIndex> mkSubSpaceX(const std::shared_ptr<DynamicIndex>& di,
size_t max)
{
auto& o = di->range()->orig();
vector<std::shared_ptr<RangeBase>> ox(o.begin(),o.begin()+max);
DynamicRangeFactory drf(ox);
auto dr = createExplicit(drf);
auto odi = getIndex(dr);
vector<std::shared_ptr<IndexW>> iv;
iv.reserve(max);
for(size_t i = 0; i != max; ++i){
iv.push_back(di->getP(i));
}
(*odi)(iv);
return odi;
}
}

View file

@ -287,6 +287,8 @@ namespace
(*di4a)(svec({"ia_1","ib_1"})); (*di4a)(svec({"ia_1","ib_1"}));
auto ic_1 = DynamicIndex::getIndexFromMap<CI>("ic_1"); auto ic_1 = DynamicIndex::getIndexFromMap<CI>("ic_1");
auto ic_2 = DynamicIndex::getIndexFromMap<CI>("ic_2"); auto ic_2 = DynamicIndex::getIndexFromMap<CI>("ic_2");
//VCHECK(reinterpret_cast<std::intptr_t>(ic_1.get()));
//VCHECK(reinterpret_cast<std::intptr_t>(ic_2.get()));
auto resx1 = res1; auto resx1 = res1;
auto resx2 = res1; auto resx2 = res1;
@ -305,7 +307,8 @@ namespace
auto hop2 = hl_exp(hop1); auto hop2 = hl_exp(hop1);
auto hop4 = hop3 * hop2; auto hop4 = hop3 * hop2;
auto hopr = mkHLO(resx4(i1,di4)); auto hopr = mkHLO(resx4(i1,di4));
hopr.assign( hop4, mi, ic_1, ic_2 ); //hopr.assign( hop4, mi, ic_1, ic_2 );
hopr.xassign( hop4, di4, i1 );
auto i2_1 = imap.at("i2_1"); auto i2_1 = imap.at("i2_1");
auto i2_2 = imap.at("i2_2"); auto i2_2 = imap.at("i2_2");