From e46abff94ca009b6505f7757331e4f41e73184f7 Mon Sep 17 00:00:00 2001 From: Christian Zimmermann Date: Fri, 15 Dec 2017 14:47:02 +0100 Subject: [PATCH] spin range contraction test + some simplifying tools --- src/container_range.h | 7 +-- src/helper_tools.h | 41 +++++++++++++++++ src/multi_array.h | 5 ++ src/multi_range.h | 5 +- src/op_unit_test.cc | 89 ++++++++++++++++++++++++++++-------- src/pack_num.h | 13 ++++++ src/range_types/header.h | 6 ++- src/range_types/spin_range.h | 51 +++++++++++---------- src/single_range.h | 2 + 9 files changed, 168 insertions(+), 51 deletions(-) diff --git a/src/container_range.h b/src/container_range.h index bd24469..550e698 100644 --- a/src/container_range.h +++ b/src/container_range.h @@ -188,7 +188,7 @@ namespace MultiArrayTools typedef ContainerRange oType; - ContainerRangeFactory() = delete; + ContainerRangeFactory(); ContainerRangeFactory(const std::shared_ptr&... rs); ContainerRangeFactory(const typename ContainerRange::SpaceType& space); @@ -238,7 +238,8 @@ namespace MultiArrayTools virtual std::shared_ptr index() const override; friend ContainerRangeFactory; - + + static const bool defaultable = false; }; } // end namespace MultiArrayTools @@ -318,7 +319,7 @@ namespace MultiArrayTools { mProd = std::shared_ptr >( new ContainerRange( rs... ) ); } - + template ContainerRangeFactory:: ContainerRangeFactory(const typename ContainerRange::SpaceType& space) diff --git a/src/helper_tools.h b/src/helper_tools.h index 15ca256..8f53e4e 100644 --- a/src/helper_tools.h +++ b/src/helper_tools.h @@ -10,6 +10,19 @@ namespace MultiArrayTools template auto getIndex(std::shared_ptr range) -> std::shared_ptr; + + // only if 'RangeType' is defaultable and unique (Singleton) + template + auto getIndex() -> std::shared_ptr; + + template + auto mkMulti(std::shared_ptr... ranges) + -> std::shared_ptr >; + + template + auto mkMIndex(std::shared_ptr... indices) + -> decltype( getIndex( mkMulti( indices.range()... ) ) ); + } @@ -26,6 +39,34 @@ namespace MultiArrayTools return std::dynamic_pointer_cast > ( range->index() )->get(); } + + template + auto getIndex() -> std::shared_ptr + { + static_assert( RangeType::defaultable, + /*typeid(typename RangeType).name() + */" is not defaultable" ); + static auto f = RangeType::factory(); + static auto r = std::dynamic_pointer_cast( f.create() ); + return std::dynamic_pointer_cast > + ( r->index() )->get(); + } + + template + auto mkMulti(std::shared_ptr... ranges) + -> std::shared_ptr > + { + MultiRangeFactory mrf( ranges... ); + return std::dynamic_pointer_cast >( mrf.create() ); + } + + template + auto mkMIndex(std::shared_ptr... indices) + -> decltype( getIndex( mkMulti( indices->range()... ) ) ) + { + auto mi = getIndex( mkMulti( indices->range()... ) ); + (*mi)( indices... ); + return mi; + } } #endif diff --git a/src/multi_array.h b/src/multi_array.h index 32681f7..bbb6733 100644 --- a/src/multi_array.h +++ b/src/multi_array.h @@ -202,6 +202,10 @@ namespace MultiArrayTools MultiArray(const std::shared_ptr&... ranges, const std::vector& vec); MultiArray(const std::shared_ptr&... ranges, std::vector&& vec); + // Only if ALL ranges have default extensions: + //MultiArray(const std::vector& vec); + //MultiArray(std::vector&& vec); + // template // MultiArray(const MultiArray,Range3> in); @@ -716,6 +720,7 @@ namespace MultiArrayTools mCont.erase(mCont.begin() + MAB::mRange->size(), mCont.end()); } } + /* template template diff --git a/src/multi_range.h b/src/multi_range.h index 20e637e..85056ef 100644 --- a/src/multi_range.h +++ b/src/multi_range.h @@ -228,7 +228,7 @@ namespace MultiArrayTools MultiRange(const SpaceType& space); SpaceType mSpace; - + public: static const size_t sdim = sizeof...(Ranges); @@ -250,7 +250,8 @@ namespace MultiArrayTools virtual std::shared_ptr index() const override; friend MultiRangeFactory; - + + static const bool defaultable = false; }; } diff --git a/src/op_unit_test.cc b/src/op_unit_test.cc index fd2d938..67de8db 100644 --- a/src/op_unit_test.cc +++ b/src/op_unit_test.cc @@ -57,6 +57,12 @@ namespace { return std::make_tuple(ts...); } + template + auto mkts(Ts&&... ts) -> decltype(std::make_tuple(ts...)) + { + return std::make_tuple(static_cast( ts )...); + } + class OpTest_1Dim : public ::testing::Test { protected: @@ -177,46 +183,89 @@ namespace { { protected: - typedef SpinRangeFactory SRF; + typedef SpinRF SRF; typedef SpinRange SR; typedef MultiRangeFactory SR8F; typedef SR8F::oType SR8; static const size_t s = 65536; - OpTest_Spin() : data(s) + OpTest_Spin() { + data.resize(s); for(size_t i = 0; i != s; ++i){ double arg = static_cast( i - s ) - 0.1; data[i] = sin(arg)/arg; } - - swapFactory(rfbptr); - srptr = std::dynamic_pointer_cast( rfbptr->create() ); - /* - swapMFactory(rfbptr, srptr, srptr, srptr, srptr, - srptr, srptr, srptr, srptr); - mrptr = std::dynamic_pointer_cas( rfbptr->create() ); - */ + SRF f; + sr = std::dynamic_pointer_cast(f.create()); } - std::shared_ptr rfbptr; - std::shared_ptr srptr; - //std::shared_ptr mrptr; std::vector data; + std::shared_ptr sr; }; TEST_F(OpTest_Spin, Contract) { - MultiArray ma(srptr, srptr, srptr, srptr, - srptr, srptr, srptr, srptr, data); + MultiArray ma(sr, sr, sr, sr, sr, sr, sr, sr, data); + MultiArray res1( sr, sr ); + + auto alpha = MAT::getIndex(); + auto beta = MAT::getIndex(); + auto gamma = MAT::getIndex(); + auto delta = MAT::getIndex(); + auto deltap = MAT::getIndex(); + + auto mix = MAT::mkMIndex( alpha, beta, gamma ); - auto alpha = MAT::getIndex(srptr); - auto beta = MAT::getIndex(srptr); - auto gamma = MAT::getIndex(srptr); - auto delta = MAT::getIndex(srptr); + std::clock_t begin = std::clock(); + res1(delta, deltap) = ma(delta, alpha, alpha, beta, beta, gamma, gamma, deltap).c(mix); + std::clock_t end = std::clock(); + std::cout << "MultiArray time: " << static_cast( end - begin ) / CLOCKS_PER_SEC + << std::endl; + + std::vector vres(4*4); - // !!! + std::clock_t begin2 = std::clock(); + for(size_t d = 0; d != 4; ++d){ + for(size_t p = 0; p != 4; ++p){ + const size_t tidx = d*4 + p; + vres[tidx] = 0.; + for(size_t a = 0; a != 4; ++a){ + for(size_t b = 0; b != 4; ++b){ + for(size_t c = 0; c != 4; ++c){ + const size_t sidx = d*4*4*4*4*4*4*4 + a*5*4*4*4*4 + b*5*4*4*4 + c*5*4 + p; + vres[tidx] += data[sidx]; + } + } + } + } + } + std::clock_t end2 = std::clock(); + + EXPECT_EQ( xround(res1.at(mkts(0,0))), xround(vres[0]) ); + EXPECT_EQ( xround(res1.at(mkts(0,1))), xround(vres[1]) ); + EXPECT_EQ( xround(res1.at(mkts(0,2))), xround(vres[2]) ); + EXPECT_EQ( xround(res1.at(mkts(0,3))), xround(vres[3]) ); + + EXPECT_EQ( xround(res1.at(mkts(1,0))), xround(vres[4]) ); + EXPECT_EQ( xround(res1.at(mkts(1,1))), xround(vres[5]) ); + EXPECT_EQ( xround(res1.at(mkts(1,2))), xround(vres[6]) ); + EXPECT_EQ( xround(res1.at(mkts(1,3))), xround(vres[7]) ); + + EXPECT_EQ( xround(res1.at(mkts(2,0))), xround(vres[8]) ); + EXPECT_EQ( xround(res1.at(mkts(2,1))), xround(vres[9]) ); + EXPECT_EQ( xround(res1.at(mkts(2,2))), xround(vres[10]) ); + EXPECT_EQ( xround(res1.at(mkts(2,3))), xround(vres[11]) ); + + EXPECT_EQ( xround(res1.at(mkts(3,0))), xround(vres[12]) ); + EXPECT_EQ( xround(res1.at(mkts(3,1))), xround(vres[13]) ); + EXPECT_EQ( xround(res1.at(mkts(3,2))), xround(vres[14]) ); + EXPECT_EQ( xround(res1.at(mkts(3,3))), xround(vres[15]) ); + + std::cout << "std::vector - for loop time: " << static_cast( end2 - begin2 ) / CLOCKS_PER_SEC + << std::endl; + std::cout << "ratio: " << static_cast( end - begin ) / static_cast( end2 - begin2 ) << std::endl; } TEST_F(OpTest_Performance, PCheck) diff --git a/src/pack_num.h b/src/pack_num.h index 01ec623..16770c8 100644 --- a/src/pack_num.h +++ b/src/pack_num.h @@ -260,6 +260,13 @@ namespace MultiArrayHelper std::get(ip)->print(offset); PackNum::printIndex(ip, offset); } + + template + static void checkDefaultable() + { + static_assert( Range::defaultable, "not defaultable" ); + PackNum::template checkDefaultable(); + } }; template<> @@ -454,6 +461,12 @@ namespace MultiArrayHelper std::get<0>(ip)->print(offset); } + template + static void checkDefaultable() + { + static_assert( Range::defaultable, "not defaultable" ); + } + }; template diff --git a/src/range_types/header.h b/src/range_types/header.h index 6826c94..0632275 100644 --- a/src/range_types/header.h +++ b/src/range_types/header.h @@ -10,12 +10,14 @@ #ifdef __incl_this__ -#ifndef __ranges_header__ #define __ranges_header__ +//#ifndef __ranges_header__ +//#define __ranges_header__ #include "spin_range.h" -#endif +#undef __ranges_header__ +//#endif #endif #undef __incl_this__ diff --git a/src/range_types/spin_range.h b/src/range_types/spin_range.h index 209161a..73c990d 100644 --- a/src/range_types/spin_range.h +++ b/src/range_types/spin_range.h @@ -7,8 +7,8 @@ include_range_type(SPIN,2) #ifdef __ranges_header__ // assert, that this is only used within range_types/header.h -#ifndef __spin_range_h__ -#define __spin_range_h__ +//#ifndef __spin_range_h__ +//#define __spin_range_h__ namespace MultiArrayTools { @@ -19,10 +19,9 @@ namespace MultiArrayTools { public: - typedef SingleRange oType; + typedef SingleRange oType; - SingleRangeFactory() = delete; - SingleRangeFactory(size_t spinNum); // = 4 :) + SingleRangeFactory(); std::shared_ptr create(); }; @@ -32,7 +31,7 @@ namespace MultiArrayTools { public: typedef RangeBase RB; - typedef typename RangeInterface >::IndexType IndexType; + typedef typename RangeInterface >::IndexType IndexType; virtual size_t size() const override; virtual size_t dim() const override; @@ -42,18 +41,22 @@ namespace MultiArrayTools virtual IndexType begin() const override; virtual IndexType end() const override; - virtual std::shared_ptr index() const override; + virtual std::shared_ptr index() const override; friend SingleRangeFactory; + + static const bool defaultable = true; + static const size_t mSpinNum = 4; + + static SingleRangeFactory factory() + { return SingleRangeFactory(); } protected: SingleRange() = default; SingleRange(const SingleRange& in) = delete; - SingleRange(size_t spinNum); - - size_t mSpinNum = 4; + //SingleRange(size_t spinNum); }; typedef SingleRange SpinRange; @@ -70,14 +73,17 @@ namespace MultiArrayTools * SingleRange * ********************/ - SingleRangeFactory::SingleRangeFactory(const std::vector& space) + SingleRangeFactory::SingleRangeFactory() { - mProd = std::shared_ptr( new SingleRange( space ) ); + // Quasi Singleton + if(not mProd){ + mProd = std::shared_ptr( new SingleRange() ); + setSelf(); + } } std::shared_ptr SingleRangeFactory::create() { - setSelf(); return mProd; } @@ -85,12 +91,6 @@ namespace MultiArrayTools * SingleRange * ********************/ - SingleRange::SingleRange(size_t spinNum) : - RangeInterface >() - { - mSpinNum = spinNum; - } - size_t SingleRange::get(size_t pos) const { return pos; @@ -128,15 +128,18 @@ namespace MultiArrayTools } // put this in the interface class !!! - std::shared_ptr SingleRange::index() const + std::shared_ptr SingleRange::index() const { - return std::make_shared > - ( std::dynamic_pointer_cast > - ( std::shared_ptr( RB::mThis ) ) ); + typedef IndexWrapper IW; + return std::make_shared + ( std::make_shared + ( std::dynamic_pointer_cast > + ( std::shared_ptr( RB::mThis ) ) ) ); } + } -#endif // #ifndef __spin_range_h__ +//#endif // #ifndef __spin_range_h__ #endif // #ifdef __ranges_header__ diff --git a/src/single_range.h b/src/single_range.h index d8adf81..10e056c 100644 --- a/src/single_range.h +++ b/src/single_range.h @@ -161,6 +161,8 @@ namespace MultiArrayTools virtual std::shared_ptr index() const override; friend SingleRangeFactory; + + static const bool defaultable = false; protected: