diff --git a/src/include/multi_array.h b/src/include/multi_array.h index f975586..663d96f 100644 --- a/src/include/multi_array.h +++ b/src/include/multi_array.h @@ -7,6 +7,20 @@ namespace MultiArrayTools { + template + struct ArrayCatter; + + + template + struct ArrayCatter + { + template + static auto cat(const MultiArray& ma) + -> MultiArray + { + return ma; + } + }; template @@ -23,6 +37,7 @@ namespace MultiArrayTools MultiArray(const std::shared_ptr&... ranges, const std::vector& vec); MultiArray(const std::shared_ptr&... ranges, std::vector&& vec); MultiArray(const typename CRange::SpaceType& space); + MultiArray(const typename CRange::SpaceType& space, std::vector&& vec); // Only if ALL ranges have default extensions: //MultiArray(const std::vector& vec); @@ -51,18 +66,46 @@ namespace MultiArrayTools virtual const T* data() const override; virtual T* data() override; virtual std::vector& vdata() { return mCont; } + virtual const std::vector& vdata() const { return mCont; } + auto cat() const + -> decltype(ArrayCatter::cat(*this)); + operator T() const; template friend class MultiArray; private: + std::vector mCont; }; template using Scalar = MultiArray; + + template + struct ArrayCatter > + { + template + static auto cat(const MultiArray,Ranges...>& ma) + -> MultiArray + { + auto sma = *ma.begin(); + const size_t smas = sma.size(); + const size_t mas = ma.size(); + auto cr = ma.range()->cat(sma.range()); + std::vector ov; + ov.reserve(mas * smas); + + for(auto& x: ma){ + assert(x.size() == smas); + ov.insert(ov.end(), x.vdata().begin(), x.vdata().end()); + } + return MultiArray(cr->space(), std::move(ov)); + } + }; + } @@ -84,7 +127,20 @@ namespace MultiArrayTools { MAB::mInit = true; } - + + template + MultiArray::MultiArray(const typename CRange::SpaceType& space, + std::vector&& vec) : + MutableMultiArrayBase(space), + mCont(vec) + { + MAB::mInit = true; + if(mCont.size() > MAB::mRange->size()){ + mCont.erase(mCont.begin() + MAB::mRange->size(), mCont.end()); + } + } + + template MultiArray::MultiArray(const std::shared_ptr&... ranges) : MutableMultiArrayBase(ranges...), @@ -206,6 +262,13 @@ namespace MultiArrayTools static_assert( sizeof...(SRanges) == 0, "try to cast non-scalar type into scalar" ); return mCont[0]; } + + template + auto MultiArray::cat() const + -> decltype(ArrayCatter::cat(*this)) + { + return ArrayCatter::cat(*this); + } } #endif diff --git a/src/include/multi_array_base.h b/src/include/multi_array_base.h index b0fc699..7ead45e 100644 --- a/src/include/multi_array_base.h +++ b/src/include/multi_array_base.h @@ -26,6 +26,13 @@ namespace MultiArrayTools typedef ContainerRange CRange; typedef ContainerIndex IndexType; + protected: + bool mInit = false; + std::shared_ptr mRange; + std::shared_ptr mProtoI; + + public: + DEFAULT_MEMBERS(MultiArrayBase); MultiArrayBase(const std::shared_ptr&... ranges); MultiArrayBase(const typename CRange::SpaceType& space); @@ -54,11 +61,11 @@ namespace MultiArrayTools operator()(std::shared_ptr&... inds) const; virtual bool isInit() const; + + template + auto getRangePtr() const + -> decltype(mRange->template getPtr()); - protected: - bool mInit = false; - std::shared_ptr mRange; - std::shared_ptr mProtoI; }; template @@ -193,6 +200,14 @@ namespace MultiArrayTools return mInit; } + template + template + auto MultiArrayBase::getRangePtr() const + -> decltype(mRange->template getPtr()) + { + return mRange->template getPtr(); + } + /****************************** * MutableMultiArrayBase * diff --git a/src/include/ranges/container_range.h b/src/include/ranges/container_range.h index 3a97024..844bcec 100644 --- a/src/include/ranges/container_range.h +++ b/src/include/ranges/container_range.h @@ -494,6 +494,7 @@ namespace MultiArrayTools (*this)++; } } + return *this; } template @@ -509,6 +510,7 @@ namespace MultiArrayTools (*this)--; } } + return *this; } template diff --git a/src/include/ranges/multi_range.h b/src/include/ranges/multi_range.h index 4fded25..8f0b964 100644 --- a/src/include/ranges/multi_range.h +++ b/src/include/ranges/multi_range.h @@ -176,6 +176,10 @@ namespace MultiArrayTools virtual IndexType begin() const final; virtual IndexType end() const final; + template + auto cat(const std::shared_ptr >& erange) + -> std::shared_ptr >; + friend MultiRangeFactory; static constexpr bool defaultable = false; @@ -510,6 +514,17 @@ namespace MultiArrayTools i = size(); return i; } + + template + template + auto MultiRange::cat(const std::shared_ptr >& erange) + -> std::shared_ptr > + { + auto crange = std::tuple_cat(mSpace, erange->space()); + MultiRangeFactory rf(crange); + return std::dynamic_pointer_cast >(rf.create()); + } + } #endif