hl op: resolve lowest indices at runtime
This commit is contained in:
parent
65ecc27c3e
commit
dcce9a5eea
5 changed files with 87 additions and 21 deletions
|
@ -3,6 +3,7 @@
|
|||
|
||||
namespace MultiArrayTools
|
||||
{
|
||||
|
||||
template <typename T, class Op>
|
||||
DynamicO<T> mkDynOp1(const Op& op)
|
||||
{
|
||||
|
@ -238,12 +239,56 @@ namespace MultiArrayTools
|
|||
( std::array<std::shared_ptr<HighLevelOpBase<ROP>>,2>({mOp, in.mOp}) ) );
|
||||
}
|
||||
|
||||
template <class ROP>
|
||||
template <class... Indices>
|
||||
HighLevelOpHolder<ROP>& HighLevelOpHolder<ROP>::xassign(const HighLevelOpHolder& in,
|
||||
const std::shared_ptr<DynamicIndex>& di,
|
||||
const std::shared_ptr<Indices>&... is)
|
||||
{
|
||||
const size_t dim = di->dim();
|
||||
if(dim >= 2){
|
||||
auto ci1 = std::dynamic_pointer_cast<IndexWrapper<CI>>(di->getP(dim-2));
|
||||
auto ci2 = std::dynamic_pointer_cast<IndexWrapper<CI>>(di->getP(dim-1));
|
||||
assert(ci1 != nullptr);
|
||||
assert(ci2 != nullptr);
|
||||
auto odi = mkSubSpaceX(di, dim-2);
|
||||
auto mi = mkMIndex(is..., odi);
|
||||
this->assign(in, mi, ci1->getIndex(), ci2->getIndex());
|
||||
}
|
||||
else {
|
||||
assert(dim == 1);
|
||||
auto ci1 = std::dynamic_pointer_cast<IndexWrapper<CI>>(di->getP(dim-2));
|
||||
assert(ci1 != nullptr);
|
||||
auto odi = mkSubSpaceX(di, dim-1);
|
||||
auto mi = mkMIndex(is..., odi);
|
||||
this->assign(in, mi, ci1->getIndex());
|
||||
}
|
||||
//INDS<ROP,Indices...>::template CallHLOp<> call;
|
||||
//call.assign(*this, in, is..., di);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class Ind1, class Ind2, class... Inds>
|
||||
std::string printInd(const std::shared_ptr<Ind1>& ind1, const std::shared_ptr<Ind2>& ind2,
|
||||
const std::shared_ptr<Inds>&... inds)
|
||||
{
|
||||
return std::to_string(reinterpret_cast<std::intptr_t>(ind1.get())) + "(" +
|
||||
std::to_string(ind1->max()) + "), " + printInd(ind2, inds...);
|
||||
}
|
||||
|
||||
template <class Ind1>
|
||||
std::string printInd(const std::shared_ptr<Ind1>& ind1)
|
||||
{
|
||||
return std::to_string(reinterpret_cast<std::intptr_t>(ind1.get())) + "(" + std::to_string(ind1->max()) + ")";
|
||||
}
|
||||
|
||||
template <class ROP>
|
||||
template <class MIndex, class... Indices>
|
||||
HighLevelOpHolder<ROP>& HighLevelOpHolder<ROP>::assign(const HighLevelOpHolder& in,
|
||||
const std::shared_ptr<MIndex>& mi,
|
||||
const std::shared_ptr<Indices>&... inds)
|
||||
{
|
||||
//VCHECK(printInd(inds...));
|
||||
auto xx = mkArrayPtr<double>(nullr());
|
||||
auto& opr = *mOp->get();
|
||||
auto loop = mkPILoop
|
||||
|
@ -302,7 +347,7 @@ namespace MultiArrayTools
|
|||
#undef regFunc1
|
||||
#undef SP
|
||||
|
||||
|
||||
/*
|
||||
template <size_t N>
|
||||
template <class ITuple>
|
||||
inline void SetLInds<N>::mkLIT(const ITuple& itp, const std::shared_ptr<DynamicIndex>& di)
|
||||
|
@ -368,22 +413,6 @@ namespace MultiArrayTools
|
|||
return mDepth;
|
||||
}
|
||||
|
||||
std::shared_ptr<DynamicIndex> mkSubSpaceX(const std::shared_ptr<DynamicIndex>& di,
|
||||
size_t max)
|
||||
{
|
||||
auto& o = di->range()->orig();
|
||||
vector<std::shared_ptr<RangeBase>> ox(o.begin(),o.begin()+max);
|
||||
DynamicRangeFactory drf(ox);
|
||||
auto dr = createExplicit(drf);
|
||||
auto odi = getIndex(dr);
|
||||
vector<std::shared_ptr<IndexW>> iv;
|
||||
iv.reserve(max);
|
||||
for(size_t i = 0; i != max; ++i){
|
||||
iv.push_back(di->getP(i));
|
||||
}
|
||||
(*odi)(iv);
|
||||
return odi;
|
||||
}
|
||||
|
||||
template <class ROP, class... Indices>
|
||||
template <class... LIndices>
|
||||
|
@ -428,4 +457,6 @@ namespace MultiArrayTools
|
|||
plus(target, source, mi, itp);
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,8 @@ namespace MultiArrayTools
|
|||
template <typename T, class Op>
|
||||
DynamicO<T> mkDynOp1(const Op& op);
|
||||
|
||||
std::shared_ptr<DynamicIndex> mkSubSpaceX(const std::shared_ptr<DynamicIndex>& di, size_t max);
|
||||
|
||||
template <class ROP>
|
||||
class HighLevelOpBase
|
||||
{
|
||||
|
@ -153,6 +155,11 @@ namespace MultiArrayTools
|
|||
HighLevelOpHolder operator-(const HighLevelOpHolder& in) const;
|
||||
HighLevelOpHolder operator/(const HighLevelOpHolder& in) const;
|
||||
|
||||
template <class... Indices>
|
||||
HighLevelOpHolder& xassign(const HighLevelOpHolder& in,
|
||||
const std::shared_ptr<DynamicIndex>& di,
|
||||
const std::shared_ptr<Indices>&... is);
|
||||
|
||||
template <class MIndex, class... Indices>
|
||||
HighLevelOpHolder& assign(const HighLevelOpHolder& in,
|
||||
const std::shared_ptr<MIndex>& mi,
|
||||
|
@ -174,7 +181,7 @@ namespace MultiArrayTools
|
|||
#include "extensions/math.h"
|
||||
#undef regFunc1
|
||||
#undef SP
|
||||
|
||||
/*
|
||||
template <size_t N>
|
||||
struct SetLInds
|
||||
{
|
||||
|
@ -248,7 +255,7 @@ namespace MultiArrayTools
|
|||
const std::shared_ptr<DynamicIndex>& di) const override final;
|
||||
};
|
||||
};
|
||||
|
||||
*/
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
|
@ -7,6 +7,7 @@ set(libmultiarray_a_SOURCES
|
|||
${CMAKE_SOURCE_DIR}/src/lib/ranges/type_map.cc
|
||||
${CMAKE_SOURCE_DIR}/src/lib/ranges/multi_range_factory_product_map.cc
|
||||
${CMAKE_SOURCE_DIR}/src/lib/map_range_factory_product_map.cc
|
||||
${CMAKE_SOURCE_DIR}/src/lib/high_level_operation.cc
|
||||
)
|
||||
|
||||
file(GLOB cc_files "${CMAKE_SOURCE_DIR}/src/lib/ranges/range_types/*.cc")
|
||||
|
|
24
src/lib/high_level_operation.cc
Normal file
24
src/lib/high_level_operation.cc
Normal file
|
@ -0,0 +1,24 @@
|
|||
|
||||
#include "multi_array_header.h"
|
||||
#include "high_level_operation.h"
|
||||
|
||||
namespace MultiArrayTools
|
||||
{
|
||||
std::shared_ptr<DynamicIndex> mkSubSpaceX(const std::shared_ptr<DynamicIndex>& di,
|
||||
size_t max)
|
||||
{
|
||||
auto& o = di->range()->orig();
|
||||
vector<std::shared_ptr<RangeBase>> ox(o.begin(),o.begin()+max);
|
||||
DynamicRangeFactory drf(ox);
|
||||
auto dr = createExplicit(drf);
|
||||
auto odi = getIndex(dr);
|
||||
vector<std::shared_ptr<IndexW>> iv;
|
||||
iv.reserve(max);
|
||||
for(size_t i = 0; i != max; ++i){
|
||||
iv.push_back(di->getP(i));
|
||||
}
|
||||
(*odi)(iv);
|
||||
return odi;
|
||||
}
|
||||
|
||||
}
|
|
@ -287,6 +287,8 @@ namespace
|
|||
(*di4a)(svec({"ia_1","ib_1"}));
|
||||
auto ic_1 = DynamicIndex::getIndexFromMap<CI>("ic_1");
|
||||
auto ic_2 = DynamicIndex::getIndexFromMap<CI>("ic_2");
|
||||
//VCHECK(reinterpret_cast<std::intptr_t>(ic_1.get()));
|
||||
//VCHECK(reinterpret_cast<std::intptr_t>(ic_2.get()));
|
||||
|
||||
auto resx1 = res1;
|
||||
auto resx2 = res1;
|
||||
|
@ -305,7 +307,8 @@ namespace
|
|||
auto hop2 = hl_exp(hop1);
|
||||
auto hop4 = hop3 * hop2;
|
||||
auto hopr = mkHLO(resx4(i1,di4));
|
||||
hopr.assign( hop4, mi, ic_1, ic_2 );
|
||||
//hopr.assign( hop4, mi, ic_1, ic_2 );
|
||||
hopr.xassign( hop4, di4, i1 );
|
||||
|
||||
auto i2_1 = imap.at("i2_1");
|
||||
auto i2_2 = imap.at("i2_2");
|
||||
|
|
Loading…
Reference in a new issue