reduced range in index wrapper (to be called in high level op)

This commit is contained in:
Christian Zimmermann 2020-09-11 23:29:58 +02:00
parent dcce9a5eea
commit f0354455fd
3 changed files with 43 additions and 11 deletions

View file

@ -247,21 +247,25 @@ namespace MultiArrayTools
{ {
const size_t dim = di->dim(); const size_t dim = di->dim();
if(dim >= 2){ if(dim >= 2){
auto ci1 = std::dynamic_pointer_cast<IndexWrapper<CI>>(di->getP(dim-2)); auto ci1 = di->getP(dim-2)->reduced();
auto ci2 = std::dynamic_pointer_cast<IndexWrapper<CI>>(di->getP(dim-1)); auto ci2 = di->getP(dim-1)->reduced();
//auto ci1 = std::dynamic_pointer_cast<IndexWrapper<CI>>(di->getP(dim-2));
//auto ci2 = std::dynamic_pointer_cast<IndexWrapper<CI>>(di->getP(dim-1));
assert(ci1 != nullptr); assert(ci1 != nullptr);
assert(ci2 != nullptr); assert(ci2 != nullptr);
auto odi = mkSubSpaceX(di, dim-2); auto odi = mkSubSpaceX(di, dim-2);
auto mi = mkMIndex(is..., odi); auto mi = mkMIndex(is..., odi);
this->assign(in, mi, ci1->getIndex(), ci2->getIndex()); //this->assign(in, mi, ci1->getIndex(), ci2->getIndex());
this->assign(in, mi, ci1, ci2);
} }
else { else {
assert(dim == 1); assert(dim == 1);
auto ci1 = std::dynamic_pointer_cast<IndexWrapper<CI>>(di->getP(dim-2)); //auto ci1 = std::dynamic_pointer_cast<IndexWrapper<CI>>(di->getP(dim-2));
auto ci1 = di->getP(dim-2)->reduced();
assert(ci1 != nullptr); assert(ci1 != nullptr);
auto odi = mkSubSpaceX(di, dim-1); auto odi = mkSubSpaceX(di, dim-1);
auto mi = mkMIndex(is..., odi); auto mi = mkMIndex(is..., odi);
this->assign(in, mi, ci1->getIndex()); this->assign(in, mi, ci1);
} }
//INDS<ROP,Indices...>::template CallHLOp<> call; //INDS<ROP,Indices...>::template CallHLOp<> call;
//call.assign(*this, in, is..., di); //call.assign(*this, in, is..., di);

View file

@ -6,7 +6,13 @@ namespace MultiArrayTools
{ {
template <class Index> template <class Index>
IndexWrapper<Index>::IndexWrapper(const std::shared_ptr<Index>& i) : mI(i) {} IndexWrapper<Index>::IndexWrapper(const std::shared_ptr<Index>& i) : mI(i)
{
ClassicRF crf(mI->max());
mCI = std::make_shared<ClassicIndex>
( std::dynamic_pointer_cast<ClassicRange>( crf.create() ) );
(*mCI) = mI->pos();
}
template <class Index> template <class Index>
IndexType IndexWrapper<Index>::type() const IndexType IndexWrapper<Index>::type() const
@ -111,7 +117,11 @@ namespace MultiArrayTools
template <class Index> template <class Index>
size_t IndexWrapper<Index>::getStepSizeComp(std::intptr_t j) const size_t IndexWrapper<Index>::getStepSizeComp(std::intptr_t j) const
{ {
return MultiArrayHelper::getStepSize(*mI, j); size_t out = MultiArrayHelper::getStepSize(*mI, j);
if(out == 0){
out = MultiArrayHelper::getStepSize(*mCI, j);
}
return out;
} }
template <class Index> template <class Index>
@ -143,17 +153,31 @@ namespace MultiArrayTools
{ {
return std::make_shared<IndexWrapper>( std::make_shared<Index>( *mI ) ); return std::make_shared<IndexWrapper>( std::make_shared<Index>( *mI ) );
} }
/*
template <class Index> template <class Index>
RegIndInfo IndexWrapper<Index>::regN() const RegIndInfo IndexWrapper<Index>::regN() const
{ {
RegIndInfo out; RegIndInfo out;
return out.set(mI); return out.set(mI);
} }
*/
template <class Index>
std::shared_ptr<Index> IndexWrapper<Index>::getIndex() const
{
return mI;
}
template <class Index>
std::shared_ptr<ClassicIndex> IndexWrapper<Index>::reduced() const
{
(*mCI) = mI->pos();
return mCI;
}
template <class Index> template <class Index>
inline std::shared_ptr<IndexWrapperBase> mkIndexWrapper(const Index& i) inline std::shared_ptr<IndexWrapperBase> mkIndexWrapper(const Index& i)
{ {
return std::make_shared<IndexWrapper<Index>>(std::make_shared<Index>(i)); return std::make_shared<IndexWrapper<Index>>(std::make_shared<Index>(i));
} }
} }

View file

@ -50,7 +50,7 @@ namespace MultiArrayTools
virtual std::shared_ptr<IndexWrapperBase> duplicate() const = 0; virtual std::shared_ptr<IndexWrapperBase> duplicate() const = 0;
virtual RegIndInfo regN() const = 0; //virtual RegIndInfo regN() const = 0;
//virtual DynamicMetaT meta() const = 0; //virtual DynamicMetaT meta() const = 0;
//virtual const DynamicMetaT* metaPtr() const = 0; //virtual const DynamicMetaT* metaPtr() const = 0;
//virtual AbstractIW& at(const U& metaPos) = 0; //virtual AbstractIW& at(const U& metaPos) = 0;
@ -73,6 +73,8 @@ namespace MultiArrayTools
std::shared_ptr<IndexWrapperBase> duplicateI() const std::shared_ptr<IndexWrapperBase> duplicateI() const
{ return std::dynamic_pointer_cast<IndexWrapperBase>( this->duplicate() ); } { return std::dynamic_pointer_cast<IndexWrapperBase>( this->duplicate() ); }
*/ */
virtual std::shared_ptr<ClassicIndex> reduced() const = 0;
}; };
typedef IndexWrapperBase IndexW; typedef IndexWrapperBase IndexW;
@ -91,6 +93,7 @@ namespace MultiArrayTools
private: private:
std::shared_ptr<Index> mI; std::shared_ptr<Index> mI;
std::shared_ptr<ClassicIndex> mCI; // reduced;
public: public:
IndexWrapper(const IndexWrapper& in) = default; IndexWrapper(const IndexWrapper& in) = default;
@ -136,9 +139,10 @@ namespace MultiArrayTools
virtual DynamicExpression iforh(size_t step, DynamicExpression ex) const override final; virtual DynamicExpression iforh(size_t step, DynamicExpression ex) const override final;
virtual std::shared_ptr<IndexWrapperBase> duplicate() const override final; virtual std::shared_ptr<IndexWrapperBase> duplicate() const override final;
virtual RegIndInfo regN() const override final; //virtual RegIndInfo regN() const override final;
std::shared_ptr<Index> getIndex() const { return mI; } std::shared_ptr<Index> getIndex() const;
virtual std::shared_ptr<ClassicIndex> reduced() const override final;
}; };
/* /*