diff --git a/src/include/expressions.h b/src/include/expressions.h index dd8ba3c..2d40b38 100644 --- a/src/include/expressions.h +++ b/src/include/expressions.h @@ -34,6 +34,12 @@ namespace MultiArrayTools template class OpF, class... MAs> using AEX = AEXT>; + + template + using AEX_M = AEXT>>; + + template + using AEX_C = AEXT>>; template class OpF> using AEX_B_MM = AEX,DDMMA>; @@ -85,7 +91,10 @@ namespace MultiArrayTools V_IFOR_X(AEX_B_MC); \ V_IFOR_X(AEX_B_CM); \ V_IFOR_X(AEX_B_CC) - + +#define V_IFOR_A_1(EC) \ + V_IFOR_X(AEX_M); \ + V_IFOR_X(AEX_C) template class E1; @@ -107,7 +116,8 @@ namespace MultiArrayTools V_IFOR_A(EX,MultiArrayTools::minus); V_IFOR_A(EX,MultiArrayTools::multiplies); V_IFOR_A(EX,MultiArrayTools::divides); - + V_IFOR_A_1(EX); + public: template inline MultiArrayTools::ExpressionHolder ifor(size_t step, MultiArrayTools::ExpressionHolder ex) const; @@ -140,6 +150,10 @@ namespace MultiArrayTools D_IFOR_X(AEX_B_CM,Ind); \ D_IFOR_X(AEX_B_CC,Ind) +#define D_IFOR_A_1(EC,Ind) \ + D_IFOR_X(AEX_M,Ind); \ + D_IFOR_X(AEX_C,Ind) + template class E1 : public Expressions1 { @@ -155,7 +169,8 @@ namespace MultiArrayTools D_IFOR_A(EX,MultiArrayTools::minus,mI); D_IFOR_A(EX,MultiArrayTools::multiplies,mI); D_IFOR_A(EX,MultiArrayTools::divides,mI); - + D_IFOR_A_1(EX,mI); + public: E1(const E1& in) = default; E1(E1&& in) = default; diff --git a/src/tests/op_unit_test.cc b/src/tests/op_unit_test.cc index 1321cd7..5917d02 100644 --- a/src/tests/op_unit_test.cc +++ b/src/tests/op_unit_test.cc @@ -654,10 +654,12 @@ namespace { MultiArray ma1(mr1ptr,sr4ptr,v5); MultiArray ma2(sr2ptr,v1); MultiArray res(sr4ptr,mr1ptr); + MultiArray res2(mr1ptr,sr4ptr); DMA dma1 = *std::dynamic_pointer_cast( dynamic( ma1 ) ); DMA dma2 = *std::dynamic_pointer_cast( dynamic( ma2 ) ); DMA dres = *std::dynamic_pointer_cast( dynamic( res ) ); + DMA dres2 = *std::dynamic_pointer_cast( dynamic( res2 ) ); auto si2 = MAT::getIndex( sr2ptr ); auto si3 = MAT::getIndex( sr3ptr ); @@ -668,29 +670,49 @@ namespace { auto di1 = MAT::getIndex( MAT::rptr<0>( dma1 ) ); auto di2 = MAT::getIndex( MAT::rptr<0>( dma2 ) ); auto dir = MAT::getIndex( MAT::rptr<0>( dres ) ); + //auto dirx = MAT::getIndex( MAT::rptr<0>( dres ) ); + auto dir2 = MAT::getIndex( MAT::rptr<0>( dres2 ) ); (*di1)(mi,si4); (*di2)(si2); (*dir)(si4,mi); + (*dir2)(mi,si4); dres(dir) = dma1(di1) + dma2(di2); - - MultiArray res2(sr4ptr,mr1ptr,dres.vdata()); + res = dres.format(sr4ptr,mr1ptr); - EXPECT_EQ( xround( res2.at( mkt('A',mkt('1','a')) ) ), xround( 30.932 + 2.917 ) ); - EXPECT_EQ( xround( res2.at( mkt('A',mkt('1','b')) ) ), xround( -26.205 + 2.917 ) ); - EXPECT_EQ( xround( res2.at( mkt('A',mkt('2','a')) ) ), xround( 21.227 + 9.436 ) ); - EXPECT_EQ( xround( res2.at( mkt('A',mkt('2','b')) ) ), xround( -14.364 + 9.436 ) ); - EXPECT_EQ( xround( res2.at( mkt('A',mkt('3','a')) ) ), xround( -25.703 + 0.373 ) ); - EXPECT_EQ( xround( res2.at( mkt('A',mkt('3','b')) ) ), xround( 23.563 + 0.373 ) ); + EXPECT_EQ( xround( res.at( mkt('A',mkt('1','a')) ) ), xround( 30.932 + 2.917 ) ); + EXPECT_EQ( xround( res.at( mkt('A',mkt('1','b')) ) ), xround( -26.205 + 2.917 ) ); + EXPECT_EQ( xround( res.at( mkt('A',mkt('2','a')) ) ), xround( 21.227 + 9.436 ) ); + EXPECT_EQ( xround( res.at( mkt('A',mkt('2','b')) ) ), xround( -14.364 + 9.436 ) ); + EXPECT_EQ( xround( res.at( mkt('A',mkt('3','a')) ) ), xround( -25.703 + 0.373 ) ); + EXPECT_EQ( xround( res.at( mkt('A',mkt('3','b')) ) ), xround( 23.563 + 0.373 ) ); - EXPECT_EQ( xround( res2.at( mkt('B',mkt('1','a')) ) ), xround( -33.693 + 2.917 ) ); - EXPECT_EQ( xround( res2.at( mkt('B',mkt('1','b')) ) ), xround( -15.504 + 2.917 ) ); - EXPECT_EQ( xround( res2.at( mkt('B',mkt('2','a')) ) ), xround( 17.829 + 9.436 ) ); - EXPECT_EQ( xround( res2.at( mkt('B',mkt('2','b')) ) ), xround( -1.868 + 9.436 ) ); - EXPECT_EQ( xround( res2.at( mkt('B',mkt('3','a')) ) ), xround( 13.836 + 0.373 ) ); - EXPECT_EQ( xround( res2.at( mkt('B',mkt('3','b')) ) ), xround( 41.339 + 0.373 ) ); + EXPECT_EQ( xround( res.at( mkt('B',mkt('1','a')) ) ), xround( -33.693 + 2.917 ) ); + EXPECT_EQ( xround( res.at( mkt('B',mkt('1','b')) ) ), xround( -15.504 + 2.917 ) ); + EXPECT_EQ( xround( res.at( mkt('B',mkt('2','a')) ) ), xround( 17.829 + 9.436 ) ); + EXPECT_EQ( xround( res.at( mkt('B',mkt('2','b')) ) ), xround( -1.868 + 9.436 ) ); + EXPECT_EQ( xround( res.at( mkt('B',mkt('3','a')) ) ), xround( 13.836 + 0.373 ) ); + EXPECT_EQ( xround( res.at( mkt('B',mkt('3','b')) ) ), xround( 41.339 + 0.373 ) ); + + //dres = *std::dynamic_pointer_cast( dynamic( res ) ); + dres2(dir2) = dres(dir); + res2 = dres2.format(mr1ptr,sr4ptr); + EXPECT_EQ( xround( res2.at( mkt(mkt('1','a'),'A') ) ), xround( 30.932 + 2.917 ) ); + EXPECT_EQ( xround( res2.at( mkt(mkt('1','b'),'A') ) ), xround( -26.205 + 2.917 ) ); + EXPECT_EQ( xround( res2.at( mkt(mkt('2','a'),'A') ) ), xround( 21.227 + 9.436 ) ); + EXPECT_EQ( xround( res2.at( mkt(mkt('2','b'),'A') ) ), xround( -14.364 + 9.436 ) ); + EXPECT_EQ( xround( res2.at( mkt(mkt('3','a'),'A') ) ), xround( -25.703 + 0.373 ) ); + EXPECT_EQ( xround( res2.at( mkt(mkt('3','b'),'A') ) ), xround( 23.563 + 0.373 ) ); + + EXPECT_EQ( xround( res2.at( mkt(mkt('1','a'),'B') ) ), xround( -33.693 + 2.917 ) ); + EXPECT_EQ( xround( res2.at( mkt(mkt('1','b'),'B') ) ), xround( -15.504 + 2.917 ) ); + EXPECT_EQ( xround( res2.at( mkt(mkt('2','a'),'B') ) ), xround( 17.829 + 9.436 ) ); + EXPECT_EQ( xround( res2.at( mkt(mkt('2','b'),'B') ) ), xround( -1.868 + 9.436 ) ); + EXPECT_EQ( xround( res2.at( mkt(mkt('3','a'),'B') ) ), xround( 13.836 + 0.373 ) ); + EXPECT_EQ( xround( res2.at( mkt(mkt('3','b'),'B') ) ), xround( 41.339 + 0.373 ) ); + } TEST_F(OpTest_MDim, ExecOp3)