enable array cat

This commit is contained in:
Christian Zimmermann 2018-05-20 20:03:44 +02:00
parent 3367ab684a
commit 814665a6de
4 changed files with 100 additions and 5 deletions

View file

@ -7,6 +7,20 @@
namespace MultiArrayTools namespace MultiArrayTools
{ {
template <typename T>
struct ArrayCatter;
template <typename T>
struct ArrayCatter
{
template <class... Ranges>
static auto cat(const MultiArray<T,Ranges...>& ma)
-> MultiArray<T,Ranges...>
{
return ma;
}
};
template <typename T, class... SRanges> template <typename T, class... SRanges>
@ -23,6 +37,7 @@ namespace MultiArrayTools
MultiArray(const std::shared_ptr<SRanges>&... ranges, const std::vector<T>& vec); MultiArray(const std::shared_ptr<SRanges>&... ranges, const std::vector<T>& vec);
MultiArray(const std::shared_ptr<SRanges>&... ranges, std::vector<T>&& vec); MultiArray(const std::shared_ptr<SRanges>&... ranges, std::vector<T>&& vec);
MultiArray(const typename CRange::SpaceType& space); MultiArray(const typename CRange::SpaceType& space);
MultiArray(const typename CRange::SpaceType& space, std::vector<T>&& vec);
// Only if ALL ranges have default extensions: // Only if ALL ranges have default extensions:
//MultiArray(const std::vector<T>& vec); //MultiArray(const std::vector<T>& vec);
@ -51,6 +66,10 @@ namespace MultiArrayTools
virtual const T* data() const override; virtual const T* data() const override;
virtual T* data() override; virtual T* data() override;
virtual std::vector<T>& vdata() { return mCont; } virtual std::vector<T>& vdata() { return mCont; }
virtual const std::vector<T>& vdata() const { return mCont; }
auto cat() const
-> decltype(ArrayCatter<T>::cat(*this));
operator T() const; operator T() const;
@ -58,12 +77,36 @@ namespace MultiArrayTools
friend class MultiArray; friend class MultiArray;
private: private:
std::vector<T> mCont; std::vector<T> mCont;
}; };
template <typename T> template <typename T>
using Scalar = MultiArray<T>; using Scalar = MultiArray<T>;
template <typename T, class... ERanges>
struct ArrayCatter<MultiArray<T,ERanges...> >
{
template <class... Ranges>
static auto cat(const MultiArray<MultiArray<T,ERanges...>,Ranges...>& ma)
-> MultiArray<T,Ranges...,ERanges...>
{
auto sma = *ma.begin();
const size_t smas = sma.size();
const size_t mas = ma.size();
auto cr = ma.range()->cat(sma.range());
std::vector<T> ov;
ov.reserve(mas * smas);
for(auto& x: ma){
assert(x.size() == smas);
ov.insert(ov.end(), x.vdata().begin(), x.vdata().end());
}
return MultiArray<T,Ranges...,ERanges...>(cr->space(), std::move(ov));
}
};
} }
/* ========================= * /* ========================= *
@ -85,6 +128,19 @@ namespace MultiArrayTools
MAB::mInit = true; MAB::mInit = true;
} }
template <typename T, class... SRanges>
MultiArray<T,SRanges...>::MultiArray(const typename CRange::SpaceType& space,
std::vector<T>&& vec) :
MutableMultiArrayBase<T,SRanges...>(space),
mCont(vec)
{
MAB::mInit = true;
if(mCont.size() > MAB::mRange->size()){
mCont.erase(mCont.begin() + MAB::mRange->size(), mCont.end());
}
}
template <typename T, class... SRanges> template <typename T, class... SRanges>
MultiArray<T,SRanges...>::MultiArray(const std::shared_ptr<SRanges>&... ranges) : MultiArray<T,SRanges...>::MultiArray(const std::shared_ptr<SRanges>&... ranges) :
MutableMultiArrayBase<T,SRanges...>(ranges...), MutableMultiArrayBase<T,SRanges...>(ranges...),
@ -206,6 +262,13 @@ namespace MultiArrayTools
static_assert( sizeof...(SRanges) == 0, "try to cast non-scalar type into scalar" ); static_assert( sizeof...(SRanges) == 0, "try to cast non-scalar type into scalar" );
return mCont[0]; return mCont[0];
} }
template <typename T, class... SRanges>
auto MultiArray<T,SRanges...>::cat() const
-> decltype(ArrayCatter<T>::cat(*this))
{
return ArrayCatter<T>::cat(*this);
}
} }
#endif #endif

View file

@ -26,6 +26,13 @@ namespace MultiArrayTools
typedef ContainerRange<T,SRanges...> CRange; typedef ContainerRange<T,SRanges...> CRange;
typedef ContainerIndex<T,typename SRanges::IndexType...> IndexType; typedef ContainerIndex<T,typename SRanges::IndexType...> IndexType;
protected:
bool mInit = false;
std::shared_ptr<CRange> mRange;
std::shared_ptr<IndexType> mProtoI;
public:
DEFAULT_MEMBERS(MultiArrayBase); DEFAULT_MEMBERS(MultiArrayBase);
MultiArrayBase(const std::shared_ptr<SRanges>&... ranges); MultiArrayBase(const std::shared_ptr<SRanges>&... ranges);
MultiArrayBase(const typename CRange::SpaceType& space); MultiArrayBase(const typename CRange::SpaceType& space);
@ -55,10 +62,10 @@ namespace MultiArrayTools
virtual bool isInit() const; virtual bool isInit() const;
protected: template <size_t N>
bool mInit = false; auto getRangePtr() const
std::shared_ptr<CRange> mRange; -> decltype(mRange->template getPtr<N>());
std::shared_ptr<IndexType> mProtoI;
}; };
template <typename T, class... SRanges> template <typename T, class... SRanges>
@ -193,6 +200,14 @@ namespace MultiArrayTools
return mInit; return mInit;
} }
template <typename T, class... SRanges>
template <size_t N>
auto MultiArrayBase<T,SRanges...>::getRangePtr() const
-> decltype(mRange->template getPtr<N>())
{
return mRange->template getPtr<N>();
}
/****************************** /******************************
* MutableMultiArrayBase * * MutableMultiArrayBase *

View file

@ -494,6 +494,7 @@ namespace MultiArrayTools
(*this)++; (*this)++;
} }
} }
return *this;
} }
template <typename T, class... Indices> template <typename T, class... Indices>
@ -509,6 +510,7 @@ namespace MultiArrayTools
(*this)--; (*this)--;
} }
} }
return *this;
} }
template <typename T, class... Indices> template <typename T, class... Indices>

View file

@ -176,6 +176,10 @@ namespace MultiArrayTools
virtual IndexType begin() const final; virtual IndexType begin() const final;
virtual IndexType end() const final; virtual IndexType end() const final;
template <class... ERanges>
auto cat(const std::shared_ptr<MultiRange<ERanges...> >& erange)
-> std::shared_ptr<MultiRange<Ranges...,ERanges...> >;
friend MultiRangeFactory<Ranges...>; friend MultiRangeFactory<Ranges...>;
static constexpr bool defaultable = false; static constexpr bool defaultable = false;
@ -510,6 +514,17 @@ namespace MultiArrayTools
i = size(); i = size();
return i; return i;
} }
template <class... Ranges>
template <class... ERanges>
auto MultiRange<Ranges...>::cat(const std::shared_ptr<MultiRange<ERanges...> >& erange)
-> std::shared_ptr<MultiRange<Ranges...,ERanges...> >
{
auto crange = std::tuple_cat(mSpace, erange->space());
MultiRangeFactory<Ranges...,ERanges...> rf(crange);
return std::dynamic_pointer_cast<MultiRange<Ranges...,ERanges...> >(rf.create());
}
} }
#endif #endif