diff --git a/src/include/multi_array_operation.cc.h b/src/include/multi_array_operation.cc.h index b849726..b9d0207 100644 --- a/src/include/multi_array_operation.cc.h +++ b/src/include/multi_array_operation.cc.h @@ -519,66 +519,91 @@ namespace MultiArrayTools } template - template - auto OperationRoot::assign(const OpClass& in) const - -> decltype(mIndex.ifor(1,in.loop(AssignmentExpr2,OpClass,OpIndexAff::TARGET> + template + auto OperationRoot::asx(const OpClass& in) const + -> decltype(mIndex.ifor(1,in.loop(AssignmentExpr,OpClass,OpIndexAff::TARGET> (mOrigDataPtr,*this,in)))) { static_assert( OpClass::SIZE == decltype(in.rootSteps())::SIZE, "Ext Size mismatch" ); - return mIndex.ifor(1,in.loop(AssignmentExpr2,OpClass,OpIndexAff::TARGET> + return mIndex.ifor(1,in.loop(AssignmentExpr,OpClass,OpIndexAff::TARGET> (mOrigDataPtr,*this,in))); } + template + template + auto OperationRoot::asxExpr(const OpClass& in) const + -> decltype(in.loop(AssignmentExpr,OpClass> + (mOrigDataPtr,*this,in))) + { + static_assert( OpClass::SIZE == decltype(in.rootSteps())::SIZE, "Ext Size mismatch" ); + return in.loop(AssignmentExpr,OpClass> + (mOrigDataPtr,*this,in)); + } + + template + template + auto OperationRoot::asx(const OpClass& in, const std::shared_ptr& i) const + -> decltype(i->ifor(1,in.loop(AssignmentExpr,OpClass> + (mOrigDataPtr,*this,in)))) + { + static_assert( OpClass::SIZE == decltype(in.rootSteps())::SIZE, "Ext Size mismatch" ); + return i->ifor(1,in.loop(AssignmentExpr,OpClass> + (mOrigDataPtr,*this,in))); + } + + template + template + auto OperationRoot::assign(const OpClass& in) const + -> decltype(this->template asx(in)) + { + return this->template asx(in); + } + template template auto OperationRoot::assignExpr(const OpClass& in) const - -> decltype(in.loop(AssignmentExpr2,OpClass> - (mOrigDataPtr,*this,in))) + -> decltype(this->template asxExpr(in)) { - static_assert( OpClass::SIZE == decltype(in.rootSteps())::SIZE, "Ext Size mismatch" ); - return in.loop(AssignmentExpr2,OpClass> - (mOrigDataPtr,*this,in)); + return this->template asxExpr(in); } template template auto OperationRoot::assign(const OpClass& in, const std::shared_ptr& i) const - -> decltype(i->ifor(1,in.loop(AssignmentExpr2,OpClass> - (mOrigDataPtr,*this,in)))) + -> decltype(this->template asx(in,i)) { - static_assert( OpClass::SIZE == decltype(in.rootSteps())::SIZE, "Ext Size mismatch" ); - return i->ifor(1,in.loop(AssignmentExpr2,OpClass> - (mOrigDataPtr,*this,in))); + return this->template asx(in,i); } template template auto OperationRoot::plus(const OpClass& in) const - -> decltype(mIndex.ifor(1,in.loop(AddExpr,OpClass,OpIndexAff::TARGET> - (mOrigDataPtr,*this,in)))) + -> decltype(this->template asx(in)) { - static_assert( OpClass::SIZE == decltype(in.rootSteps())::SIZE, "Ext Size mismatch" ); - return mIndex.ifor(1,in.loop(AddExpr,OpClass,OpIndexAff::TARGET> - (mOrigDataPtr,*this,in))); + return this->template asx(in); } template template auto OperationRoot::plus(const OpClass& in, const std::shared_ptr& i) const - -> decltype(i->ifor(1,in.loop(AddExpr,OpClass> - (mOrigDataPtr,*this,in)))) + -> decltype(this->template asx(in,i)) { - static_assert( OpClass::SIZE == decltype(in.rootSteps())::SIZE, "Ext Size mismatch" ); - return i->ifor(1,in.loop(AddExpr,OpClass> - (mOrigDataPtr,*this,in))); + return this->template asx(in,i); } template template OperationRoot& OperationRoot::operator=(const OpClass& in) { - assign(in)(); + auto x = this->template asx::type>>(in); + const size_t inum = x.vec(VType::MULT); + if(x.rootSteps(inum) == 1){ + x(); + } + else { + assign(in)(); + } return *this; } @@ -613,7 +638,7 @@ namespace MultiArrayTools template inline V& OperationRoot::vget(ET pos) const { - return *(reinterpret_cast(mDataPtr)+pos.val()); + return *(reinterpret_cast(mDataPtr)+pos.val()); } template diff --git a/src/include/multi_array_operation.h b/src/include/multi_array_operation.h index 2d3787d..53e284b 100644 --- a/src/include/multi_array_operation.h +++ b/src/include/multi_array_operation.h @@ -538,29 +538,39 @@ namespace MultiArrayTools OperationRoot(T* data, const IndexType& ind); - template - auto assign(const OpClass& in) const - -> decltype(mIndex.ifor(1,in.loop(AssignmentExpr2,OpClass,OpIndexAff::TARGET> + template + auto asx(const OpClass& in) const + -> decltype(mIndex.ifor(1,in.loop(AssignmentExpr,OpClass,OpIndexAff::TARGET> (mOrigDataPtr,*this,in)))); - template - auto assignExpr(const OpClass& in) const - -> decltype(in.loop(AssignmentExpr2,OpClass>(mOrigDataPtr,*this,in))); + template + auto asxExpr(const OpClass& in) const + -> decltype(in.loop(AssignmentExpr,OpClass>(mOrigDataPtr,*this,in))); - template - auto assign(const OpClass& in, const std::shared_ptr& i) const - -> decltype(i->ifor(1,in.loop(AssignmentExpr2,OpClass> + template + auto asx(const OpClass& in, const std::shared_ptr& i) const + -> decltype(i->ifor(1,in.loop(AssignmentExpr,OpClass> (mOrigDataPtr,*this,in)))); template + auto assign(const OpClass& in) const + -> decltype(this->template asx(in)); + + template + auto assignExpr(const OpClass& in) const + -> decltype(this->template asxExpr(in)); + + template + auto assign(const OpClass& in, const std::shared_ptr& i) const + -> decltype(this->template asx(in,i)); + + template auto plus(const OpClass& in) const - -> decltype(mIndex.ifor(1,in.loop(AddExpr,OpClass,OpIndexAff::TARGET> - (mOrigDataPtr,*this,in)))); + -> decltype(this->template asx(in)); template auto plus(const OpClass& in, const std::shared_ptr& i) const - -> decltype(i->ifor(1,in.loop(AddExpr,OpClass> - (mOrigDataPtr,*this,in)))); + -> decltype(this->template asx(in,i)); template OperationRoot& operator=(const OpClass& in); diff --git a/src/include/pack_num.h b/src/include/pack_num.h index 7b6af61..234e6bf 100644 --- a/src/include/pack_num.h +++ b/src/include/pack_num.h @@ -88,7 +88,7 @@ namespace MultiArrayHelper static_assert(LAST >= NextOpType::SIZE, "inconsistent array positions"); static constexpr size_t NEXT = LAST - NextOpType::SIZE; typedef decltype(std::get(ops).template vget(Getter::template getX( pos ))) ArgT; - return PackNum::template mkVOpExpr + return PackNum::template mkVOpExpr ( f, pos, ops, std::get(ops).template vget(Getter::template getX( pos )), args...); } diff --git a/src/include/xfor/xfor.h b/src/include/xfor/xfor.h index 292b666..1ee58f5 100644 --- a/src/include/xfor/xfor.h +++ b/src/include/xfor/xfor.h @@ -35,6 +35,9 @@ namespace MultiArrayHelper virtual size_t size() const = 0; virtual const size_t& val() const = 0; //virtual size_t rootSteps() const = 0; + virtual bool operator==(const ExtBase& in) const = 0; + virtual bool operator==(size_t in) const = 0; + virtual std::shared_ptr operator+(const ExtBase& in) const = 0; virtual std::shared_ptr operator*(size_t in) const = 0; virtual void zero() = 0; @@ -75,6 +78,12 @@ namespace MultiArrayHelper virtual const size_t& val() const override final { return mExt.val(); } virtual void zero() override final { mExt.zero(); } + virtual bool operator==(const ExtBase& in) const override final + { return mExt == dynamic_cast&>(in).mExt; } + + virtual bool operator==(size_t in) const override final + { return mExt == in; } + virtual DExt operator+(const ExtBase& in) const override final { return std::make_shared>( mExt + dynamic_cast&>(in).mExt ); } virtual DExt operator*(size_t in) const override final @@ -122,6 +131,11 @@ namespace MultiArrayHelper template DExtTX(const Y& y) : mDExt(std::make_shared>(y)) {} */ + bool operator==(const DExtTX& in) const + { return *mDExt == *in.mDExt and mNext == in.mNext; } + + bool operator==(size_t in) const + { return *mDExt == in and mNext == in; } template DExtTX(const DExtTX& in) : mDExt(in.mDExt), mNext(in.mNext) {}