enable array cat

This commit is contained in:
Christian Zimmermann 2018-05-20 20:03:44 +02:00
parent 3367ab684a
commit 814665a6de
4 changed files with 100 additions and 5 deletions

View file

@ -7,6 +7,20 @@
namespace MultiArrayTools
{
template <typename T>
struct ArrayCatter;
template <typename T>
struct ArrayCatter
{
template <class... Ranges>
static auto cat(const MultiArray<T,Ranges...>& ma)
-> MultiArray<T,Ranges...>
{
return ma;
}
};
template <typename T, class... SRanges>
@ -23,6 +37,7 @@ namespace MultiArrayTools
MultiArray(const std::shared_ptr<SRanges>&... ranges, const std::vector<T>& vec);
MultiArray(const std::shared_ptr<SRanges>&... ranges, std::vector<T>&& vec);
MultiArray(const typename CRange::SpaceType& space);
MultiArray(const typename CRange::SpaceType& space, std::vector<T>&& vec);
// Only if ALL ranges have default extensions:
//MultiArray(const std::vector<T>& vec);
@ -51,18 +66,46 @@ namespace MultiArrayTools
virtual const T* data() const override;
virtual T* data() override;
virtual std::vector<T>& vdata() { return mCont; }
virtual const std::vector<T>& vdata() const { return mCont; }
auto cat() const
-> decltype(ArrayCatter<T>::cat(*this));
operator T() const;
template <typename U, class... SRanges2>
friend class MultiArray;
private:
std::vector<T> mCont;
};
template <typename T>
using Scalar = MultiArray<T>;
template <typename T, class... ERanges>
struct ArrayCatter<MultiArray<T,ERanges...> >
{
template <class... Ranges>
static auto cat(const MultiArray<MultiArray<T,ERanges...>,Ranges...>& ma)
-> MultiArray<T,Ranges...,ERanges...>
{
auto sma = *ma.begin();
const size_t smas = sma.size();
const size_t mas = ma.size();
auto cr = ma.range()->cat(sma.range());
std::vector<T> ov;
ov.reserve(mas * smas);
for(auto& x: ma){
assert(x.size() == smas);
ov.insert(ov.end(), x.vdata().begin(), x.vdata().end());
}
return MultiArray<T,Ranges...,ERanges...>(cr->space(), std::move(ov));
}
};
}
@ -84,7 +127,20 @@ namespace MultiArrayTools
{
MAB::mInit = true;
}
template <typename T, class... SRanges>
MultiArray<T,SRanges...>::MultiArray(const typename CRange::SpaceType& space,
std::vector<T>&& vec) :
MutableMultiArrayBase<T,SRanges...>(space),
mCont(vec)
{
MAB::mInit = true;
if(mCont.size() > MAB::mRange->size()){
mCont.erase(mCont.begin() + MAB::mRange->size(), mCont.end());
}
}
template <typename T, class... SRanges>
MultiArray<T,SRanges...>::MultiArray(const std::shared_ptr<SRanges>&... ranges) :
MutableMultiArrayBase<T,SRanges...>(ranges...),
@ -206,6 +262,13 @@ namespace MultiArrayTools
static_assert( sizeof...(SRanges) == 0, "try to cast non-scalar type into scalar" );
return mCont[0];
}
template <typename T, class... SRanges>
auto MultiArray<T,SRanges...>::cat() const
-> decltype(ArrayCatter<T>::cat(*this))
{
return ArrayCatter<T>::cat(*this);
}
}
#endif

View file

@ -26,6 +26,13 @@ namespace MultiArrayTools
typedef ContainerRange<T,SRanges...> CRange;
typedef ContainerIndex<T,typename SRanges::IndexType...> IndexType;
protected:
bool mInit = false;
std::shared_ptr<CRange> mRange;
std::shared_ptr<IndexType> mProtoI;
public:
DEFAULT_MEMBERS(MultiArrayBase);
MultiArrayBase(const std::shared_ptr<SRanges>&... ranges);
MultiArrayBase(const typename CRange::SpaceType& space);
@ -54,11 +61,11 @@ namespace MultiArrayTools
operator()(std::shared_ptr<typename SRanges::IndexType>&... inds) const;
virtual bool isInit() const;
template <size_t N>
auto getRangePtr() const
-> decltype(mRange->template getPtr<N>());
protected:
bool mInit = false;
std::shared_ptr<CRange> mRange;
std::shared_ptr<IndexType> mProtoI;
};
template <typename T, class... SRanges>
@ -193,6 +200,14 @@ namespace MultiArrayTools
return mInit;
}
template <typename T, class... SRanges>
template <size_t N>
auto MultiArrayBase<T,SRanges...>::getRangePtr() const
-> decltype(mRange->template getPtr<N>())
{
return mRange->template getPtr<N>();
}
/******************************
* MutableMultiArrayBase *

View file

@ -494,6 +494,7 @@ namespace MultiArrayTools
(*this)++;
}
}
return *this;
}
template <typename T, class... Indices>
@ -509,6 +510,7 @@ namespace MultiArrayTools
(*this)--;
}
}
return *this;
}
template <typename T, class... Indices>

View file

@ -176,6 +176,10 @@ namespace MultiArrayTools
virtual IndexType begin() const final;
virtual IndexType end() const final;
template <class... ERanges>
auto cat(const std::shared_ptr<MultiRange<ERanges...> >& erange)
-> std::shared_ptr<MultiRange<Ranges...,ERanges...> >;
friend MultiRangeFactory<Ranges...>;
static constexpr bool defaultable = false;
@ -510,6 +514,17 @@ namespace MultiArrayTools
i = size();
return i;
}
template <class... Ranges>
template <class... ERanges>
auto MultiRange<Ranges...>::cat(const std::shared_ptr<MultiRange<ERanges...> >& erange)
-> std::shared_ptr<MultiRange<Ranges...,ERanges...> >
{
auto crange = std::tuple_cat(mSpace, erange->space());
MultiRangeFactory<Ranges...,ERanges...> rf(crange);
return std::dynamic_pointer_cast<MultiRange<Ranges...,ERanges...> >(rf.create());
}
}
#endif