diff --git a/src/include/conversions.h b/src/include/conversions.h new file mode 100644 index 0000000..87d1a1e --- /dev/null +++ b/src/include/conversions.h @@ -0,0 +1,64 @@ + +#ifndef __ma_conversions_h__ +#define __ma_conversions_h__ + +#include "multi_array.h" +#include "slice.h" + +namespace MultiArrayTools +{ + + namespace ConversionSizes + { + template + struct OrigSize + { + template + struct FromTo + { + static void check() { static_assert( not N % (sizeof(T) / sizeof(C)), "conversion does not fit" ); } + static constexpr size_t SIZE = N * sizeof(T) / sizeof(C); + }; + }; + + template <> + struct OrigSize<-1> + { + template + struct FromTo + { + static void check() {} + static constexpr size_t SIZE = -1; + }; + }; + } + + namespace + { + template + using SC = ConversionSizes::OrigSize::template FromTo; + + template + using SCR = SC; + + template + using SCRR = GenSingleRange; + } + + template + Slice> tcast(MultiArray& ma) + { + return Slice> + ( ma.range()->space(), reinterpret_cast( ma.data() ) ); + } + + template + ConstSlice> tcast(const MultiArray& ma) + { + return ConstSlice> + ( ma.range()->space(), reinterpret_cast( ma.data() ) ); + } + +} + +#endif diff --git a/src/include/type_operations.h b/src/include/type_operations.h index 6870dac..efa35c6 100644 --- a/src/include/type_operations.h +++ b/src/include/type_operations.h @@ -82,7 +82,60 @@ namespace MultiArrayTools std::transform(a.begin(), a.end(), b.begin(), a.begin(), std::plus()); return a; } - + + template + std::vector& operator-=(std::vector& a, const std::vector& b) + { + std::transform(a.begin(), a.end(), b.begin(), a.begin(), std::minus()); + return a; + } + + template + std::vector& operator*=(std::vector& a, const std::vector& b) + { + std::transform(a.begin(), a.end(), b.begin(), a.begin(), std::multiplies()); + return a; + } + + template + std::vector& operator/=(std::vector& a, const std::vector& b) + { + std::transform(a.begin(), a.end(), b.begin(), a.begin(), std::divides()); + return a; + } + + template + std::vector operator+(std::vector& a, const std::vector& b) + { + std::vector out(a.size()); + std::transform(a.begin(), a.end(), b.begin(), out.begin(), std::plus()); + return out; + } + + template + std::vector operator-(std::vector& a, const std::vector& b) + { + std::vector out(a.size()); + std::transform(a.begin(), a.end(), b.begin(), out.begin(), std::minus()); + return out; + } + + template + std::vector operator*(std::vector& a, const std::vector& b) + { + std::vector out(a.size()); + std::transform(a.begin(), a.end(), b.begin(), out.begin(), std::multiplies()); + return out; + } + + template + std::vector operator/(std::vector& a, const std::vector& b) + { + std::vector out(a.size()); + std::transform(a.begin(), a.end(), b.begin(), out.begin(), std::divides()); + return out; + } + template class OperationTemplate,OperationClass> : public OperationBase,OperationClass> { @@ -101,28 +154,47 @@ namespace MultiArrayTools friend OperationClass; }; - inline std::array& operator+=(std::array& a, const std::array& b) + template + inline std::array operator+(std::array& a, const std::array& b) { - std::get<0>(a) += std::get<0>(b); - std::get<1>(a) += std::get<1>(b); - return a; + std::array out; + for(size_t i = 0; i != N; ++i){ + out[i] = a[i] + b[i]; + } + return out; } + + template + inline std::array operator-(std::array& a, const std::array& b) + { + std::array out; + for(size_t i = 0; i != N; ++i){ + out[i] = a[i] - b[i]; + } + return out; + } + + template + inline std::array operator*(std::array& a, const std::array& b) + { + std::array out; + for(size_t i = 0; i != N; ++i){ + out[i] = a[i] * b[i]; + } + return out; + } + + template + inline std::array operator/(std::array& a, const std::array& b) + { + std::array out; + for(size_t i = 0; i != N; ++i){ + out[i] = a[i] / b[i]; + } + return out; + } + - inline std::array& operator+=(std::array& a, const std::array& b) - { - std::get<0>(a) += std::get<0>(b); - std::get<1>(a) += std::get<1>(b); - std::get<2>(a) += std::get<2>(b); - return a; - } - - inline std::tuple& operator+=(std::tuple& a, const std::tuple& b) - { - std::get<0>(a) += std::get<0>(b); - std::get<1>(a) += std::get<1>(b); - std::get<2>(a) += std::get<2>(b); - return a; - } } // namespace MultiArrayTools