diff --git a/src/base_def.h b/src/base_def.h index 0e9723c..f181565 100644 --- a/src/base_def.h +++ b/src/base_def.h @@ -146,6 +146,10 @@ namespace MultiArrayTools template class Operation; + // multi_array_operation.h + template + class Contraction; + /* // multi_array_operation.h template diff --git a/src/block.cc b/src/block.cc index 113c5ba..a351545 100644 --- a/src/block.cc +++ b/src/block.cc @@ -30,6 +30,7 @@ namespace MultiArrayHelper void BlockBinaryOpSelf::operator()(const BlockClass& arg) { static OpFunc f; + if(mRes.size() == 0) { mRes.assign(arg.size(), static_cast(0)); } assert(mRes.size() == arg.size()); for(size_t i = 0; i != arg.size(); ++i){ mRes[i] = f(mRes[i], arg[i]); @@ -49,46 +50,7 @@ namespace MultiArrayHelper { return mSize; } - /* - template - template - BlockResult BlockBase::operate(const BlockBase& in) - { - assert(mSize == in.size()); - OpFunction f; - BlockResult res(mSize); - CHECK; - for(size_t i = 0; i != mSize; ++i){ - res[i] = f((*this)[i], in[i]); - } - return res; - } - - - template - BlockResult BlockBase::operator+(const BlockBase& in) - { - return operate >(in); - } - template - BlockResult BlockBase::operator-(const BlockBase& in) - { - return operate >(in); - } - - template - BlockResult BlockBase::operator*(const BlockBase& in) - { - return operate >(in); - } - - template - BlockResult BlockBase::operator/(const BlockBase& in) - { - return operate >(in); - } - */ /************************ * MutableBlockBase * ************************/ @@ -244,49 +206,11 @@ namespace MultiArrayHelper } template - BlockResult& BlockResult::assing(const T& val) + BlockResult& BlockResult::assign(size_t size, const T& val) { - mRes.assing(BB::mSize, val); + BB::mSize = size; + mRes.assign(BB::mSize, val); return *this; } - /* - template - BlockResult& BlockResult::operator+=(const BlockBase& in) - { - return operateSelf >(in); - } - - template - BlockResult& BlockResult::operator-=(const BlockBase& in) - { - return operateSelf >(in); - } - - template - BlockResult& BlockResult::operator*=(const BlockBase& in) - { - return operateSelf >(in); - } - - template - BlockResult& BlockResult::operator/=(const BlockBase& in) - { - return operateSelf >(in); - } - - template - template - BlockResult& BlockResult::operateSelf(const BlockBase& in) - { - assert(BB::mSize == in.size()); - OpFunction f; - //BlockResult res(mSize); - for(size_t i = 0; i != BB::mSize; ++i){ - (*this)[i] = f((*this)[i], in[i]); - } - return *this; - } - */ - } // end namespace MultiArrayHelper diff --git a/src/block.h b/src/block.h index 468e6b6..ffb6610 100644 --- a/src/block.h +++ b/src/block.h @@ -50,17 +50,9 @@ namespace MultiArrayHelper BlockBase(size_t size); size_t size() const; - /* - template - BlockResult operate(const BlockBase& in); - - BlockResult operator+(const BlockBase& in); - BlockResult operator-(const BlockBase& in); - BlockResult operator*(const BlockBase& in); - BlockResult operator/(const BlockBase& in); - */ + protected: - size_t mSize; + size_t mSize = 0; }; template @@ -136,7 +128,7 @@ namespace MultiArrayHelper template BlockResult& operator=(const BlockClass& in); - BlockResult& assing(const T& val); + BlockResult& assign(size_t size, const T& val); BlockType type() const; const T& operator[](size_t pos) const; @@ -144,14 +136,6 @@ namespace MultiArrayHelper BlockResult& set(size_t npos); size_t stepSize() const; - //BlockResult& operator+=(const BlockBase& in); - //BlockResult& operator-=(const BlockBase& in); - //BlockResult& operator*=(const BlockBase& in); - //BlockResult& operator/=(const BlockBase& in); - - //template - //BlockResult& operateSelf(const BlockBase& in); - protected: std::vector mRes; }; diff --git a/src/multi_array_operation.cc b/src/multi_array_operation.cc index 6e12a72..647e882 100644 --- a/src/multi_array_operation.cc +++ b/src/multi_array_operation.cc @@ -160,7 +160,7 @@ namespace MultiArrayTools { return Operation,OperationClass,Second>(*mOc, in); } - /* + template template auto OperationTemplate::c(std::shared_ptr& ind) const @@ -168,7 +168,7 @@ namespace MultiArrayTools { return Contraction(*mOc, ind); } - */ + /************************* * OperationMaster * @@ -324,7 +324,6 @@ namespace MultiArrayTools const BlockResult& Operation::get() const { mRes = std::move( PackNum::template unpackArgs(mOps) ); - //CHECK; return mRes; } @@ -350,15 +349,14 @@ namespace MultiArrayTools template Contraction::Contraction(const Op& op, std::shared_ptr ind) : OperationTemplate >(this), - mOp(op) {} + mOp(op), + mInd(ind) {} template const BlockResult& Contraction::get() const { BlockBinaryOpSelf,BlockResult > f(mRes); - mRes.assign( static_cast(0) ); for(*mInd = 0; mInd->pos() != mInd->max(); ++(*mInd)){ - //mRes += mOp.get(); f(mOp.get()); } return mRes; @@ -367,9 +365,7 @@ namespace MultiArrayTools template std::vector Contraction::block(const std::shared_ptr blockIndex) const { - std::vector btv; - PackNum<0>::makeBlockTypeVec(btv, std::make_tuple( mOp ), blockIndex); - return btv; + return mOp.block(blockIndex); } template diff --git a/src/multi_array_operation.h b/src/multi_array_operation.h index 1a9bc39..6e17fa0 100644 --- a/src/multi_array_operation.h +++ b/src/multi_array_operation.h @@ -108,11 +108,11 @@ namespace MultiArrayTools template auto operator/(const Second& in) const -> Operation,OperationClass,Second>; - /* + template auto c(std::shared_ptr& ind) const -> Contraction; - */ + private: OperationClass* mOc; }; @@ -244,7 +244,7 @@ namespace MultiArrayTools protected: - const Op& mOp; + Op mOp; std::shared_ptr mInd; mutable BlockResult mRes; }; diff --git a/src/op_unit_test.cc b/src/op_unit_test.cc index 1ca0767..5c38d68 100644 --- a/src/op_unit_test.cc +++ b/src/op_unit_test.cc @@ -239,6 +239,22 @@ namespace { } + TEST_F(OpTest_MDim, ExecContract) + { + MultiArray res(sr2ptr); + const MultiArray ma1(sr2ptr, v1); + const MultiArray ma2(sr4ptr, v2); + + auto i1 = std::dynamic_pointer_cast( sr2ptr->index() ); + auto i2 = std::dynamic_pointer_cast( sr4ptr->index() ); + + res(i1) = (ma1(i1) * ma2(i2)).c(i2); + + EXPECT_EQ( xround( res.at('1') ), xround(2.917 * 8.870 + 2.917 * 4.790) ); + EXPECT_EQ( xround( res.at('2') ), xround(9.436 * 8.870 + 9.436 * 4.790) ); + EXPECT_EQ( xround( res.at('3') ), xround(0.373 * 8.870 + 0.373 * 4.790) ); + } + TEST_F(OpTest_MDim, ExecOp2) { MultiArray res(mr1ptr,sr4ptr);