first contraction test works

This commit is contained in:
Christian Zimmermann 2017-11-05 18:46:38 +01:00
parent 8dfa81a99e
commit 54dfcdb85d
6 changed files with 35 additions and 111 deletions

View file

@ -146,6 +146,10 @@ namespace MultiArrayTools
template <typename T, class OpFunction, class... Ops> template <typename T, class OpFunction, class... Ops>
class Operation; class Operation;
// multi_array_operation.h
template <typename T, class Op, class IndexType>
class Contraction;
/* /*
// multi_array_operation.h // multi_array_operation.h
template <typename T, class InRange, class TotalInRange, class OutRange, class TotalRange> template <typename T, class InRange, class TotalInRange, class OutRange, class TotalRange>

View file

@ -30,6 +30,7 @@ namespace MultiArrayHelper
void BlockBinaryOpSelf<T,OpFunc,BlockClass>::operator()(const BlockClass& arg) void BlockBinaryOpSelf<T,OpFunc,BlockClass>::operator()(const BlockClass& arg)
{ {
static OpFunc f; static OpFunc f;
if(mRes.size() == 0) { mRes.assign(arg.size(), static_cast<T>(0)); }
assert(mRes.size() == arg.size()); assert(mRes.size() == arg.size());
for(size_t i = 0; i != arg.size(); ++i){ for(size_t i = 0; i != arg.size(); ++i){
mRes[i] = f(mRes[i], arg[i]); mRes[i] = f(mRes[i], arg[i]);
@ -49,46 +50,7 @@ namespace MultiArrayHelper
{ {
return mSize; return mSize;
} }
/*
template <typename T>
template <class OpFunction>
BlockResult<T> BlockBase<T>::operate(const BlockBase<T>& in)
{
assert(mSize == in.size());
OpFunction f;
BlockResult<T> res(mSize);
CHECK;
for(size_t i = 0; i != mSize; ++i){
res[i] = f((*this)[i], in[i]);
}
return res;
}
template <typename T>
BlockResult<T> BlockBase<T>::operator+(const BlockBase<T>& in)
{
return operate<std::plus<T> >(in);
}
template <typename T>
BlockResult<T> BlockBase<T>::operator-(const BlockBase<T>& in)
{
return operate<std::minus<T> >(in);
}
template <typename T>
BlockResult<T> BlockBase<T>::operator*(const BlockBase<T>& in)
{
return operate<std::multiplies<T> >(in);
}
template <typename T>
BlockResult<T> BlockBase<T>::operator/(const BlockBase<T>& in)
{
return operate<std::divides<T> >(in);
}
*/
/************************ /************************
* MutableBlockBase * * MutableBlockBase *
************************/ ************************/
@ -244,49 +206,11 @@ namespace MultiArrayHelper
} }
template <typename T> template <typename T>
BlockResult<T>& BlockResult<T>::assing(const T& val) BlockResult<T>& BlockResult<T>::assign(size_t size, const T& val)
{ {
mRes.assing(BB::mSize, val); BB::mSize = size;
mRes.assign(BB::mSize, val);
return *this; return *this;
} }
/*
template <typename T>
BlockResult<T>& BlockResult<T>::operator+=(const BlockBase<T>& in)
{
return operateSelf<std::plus<T> >(in);
}
template <typename T>
BlockResult<T>& BlockResult<T>::operator-=(const BlockBase<T>& in)
{
return operateSelf<std::minus<T> >(in);
}
template <typename T>
BlockResult<T>& BlockResult<T>::operator*=(const BlockBase<T>& in)
{
return operateSelf<std::multiplies<T> >(in);
}
template <typename T>
BlockResult<T>& BlockResult<T>::operator/=(const BlockBase<T>& in)
{
return operateSelf<std::divides<T> >(in);
}
template <typename T>
template <class OpFunction>
BlockResult<T>& BlockResult<T>::operateSelf(const BlockBase<T>& in)
{
assert(BB::mSize == in.size());
OpFunction f;
//BlockResult<T> res(mSize);
for(size_t i = 0; i != BB::mSize; ++i){
(*this)[i] = f((*this)[i], in[i]);
}
return *this;
}
*/
} // end namespace MultiArrayHelper } // end namespace MultiArrayHelper

View file

@ -50,17 +50,9 @@ namespace MultiArrayHelper
BlockBase(size_t size); BlockBase(size_t size);
size_t size() const; size_t size() const;
/*
template <class OpFunction>
BlockResult<T> operate(const BlockBase& in);
BlockResult<T> operator+(const BlockBase& in);
BlockResult<T> operator-(const BlockBase& in);
BlockResult<T> operator*(const BlockBase& in);
BlockResult<T> operator/(const BlockBase& in);
*/
protected: protected:
size_t mSize; size_t mSize = 0;
}; };
template <typename T> template <typename T>
@ -136,7 +128,7 @@ namespace MultiArrayHelper
template <class BlockClass> template <class BlockClass>
BlockResult& operator=(const BlockClass& in); BlockResult& operator=(const BlockClass& in);
BlockResult& assing(const T& val); BlockResult& assign(size_t size, const T& val);
BlockType type() const; BlockType type() const;
const T& operator[](size_t pos) const; const T& operator[](size_t pos) const;
@ -144,14 +136,6 @@ namespace MultiArrayHelper
BlockResult& set(size_t npos); BlockResult& set(size_t npos);
size_t stepSize() const; size_t stepSize() const;
//BlockResult<T>& operator+=(const BlockBase<T>& in);
//BlockResult<T>& operator-=(const BlockBase<T>& in);
//BlockResult<T>& operator*=(const BlockBase<T>& in);
//BlockResult<T>& operator/=(const BlockBase<T>& in);
//template <class OpFunction>
//BlockResult<T>& operateSelf(const BlockBase<T>& in);
protected: protected:
std::vector<T> mRes; std::vector<T> mRes;
}; };

View file

@ -160,7 +160,7 @@ namespace MultiArrayTools
{ {
return Operation<T,std::divides<T>,OperationClass,Second>(*mOc, in); return Operation<T,std::divides<T>,OperationClass,Second>(*mOc, in);
} }
/*
template <typename T, class OperationClass> template <typename T, class OperationClass>
template <class IndexType> template <class IndexType>
auto OperationTemplate<T,OperationClass>::c(std::shared_ptr<IndexType>& ind) const auto OperationTemplate<T,OperationClass>::c(std::shared_ptr<IndexType>& ind) const
@ -168,7 +168,7 @@ namespace MultiArrayTools
{ {
return Contraction<T,OperationClass,IndexType>(*mOc, ind); return Contraction<T,OperationClass,IndexType>(*mOc, ind);
} }
*/
/************************* /*************************
* OperationMaster * * OperationMaster *
@ -324,7 +324,6 @@ namespace MultiArrayTools
const BlockResult<T>& Operation<T,OpFunction,Ops...>::get() const const BlockResult<T>& Operation<T,OpFunction,Ops...>::get() const
{ {
mRes = std::move( PackNum<sizeof...(Ops)-1>::template unpackArgs<T,OpFunction>(mOps) ); mRes = std::move( PackNum<sizeof...(Ops)-1>::template unpackArgs<T,OpFunction>(mOps) );
//CHECK;
return mRes; return mRes;
} }
@ -350,15 +349,14 @@ namespace MultiArrayTools
template <typename T, class Op, class IndexType> template <typename T, class Op, class IndexType>
Contraction<T,Op,IndexType>::Contraction(const Op& op, std::shared_ptr<IndexType> ind) : Contraction<T,Op,IndexType>::Contraction(const Op& op, std::shared_ptr<IndexType> ind) :
OperationTemplate<T,Contraction<T,Op,IndexType> >(this), OperationTemplate<T,Contraction<T,Op,IndexType> >(this),
mOp(op) {} mOp(op),
mInd(ind) {}
template <typename T, class Op, class IndexType> template <typename T, class Op, class IndexType>
const BlockResult<T>& Contraction<T,Op,IndexType>::get() const const BlockResult<T>& Contraction<T,Op,IndexType>::get() const
{ {
BlockBinaryOpSelf<T,std::plus<T>,BlockResult<T> > f(mRes); BlockBinaryOpSelf<T,std::plus<T>,BlockResult<T> > f(mRes);
mRes.assign( static_cast<T>(0) );
for(*mInd = 0; mInd->pos() != mInd->max(); ++(*mInd)){ for(*mInd = 0; mInd->pos() != mInd->max(); ++(*mInd)){
//mRes += mOp.get();
f(mOp.get()); f(mOp.get());
} }
return mRes; return mRes;
@ -367,9 +365,7 @@ namespace MultiArrayTools
template <typename T, class Op, class IndexType> template <typename T, class Op, class IndexType>
std::vector<BTSS> Contraction<T,Op,IndexType>::block(const std::shared_ptr<IndexBase> blockIndex) const std::vector<BTSS> Contraction<T,Op,IndexType>::block(const std::shared_ptr<IndexBase> blockIndex) const
{ {
std::vector<BTSS> btv; return mOp.block(blockIndex);
PackNum<0>::makeBlockTypeVec(btv, std::make_tuple( mOp ), blockIndex);
return btv;
} }
template <typename T, class Op, class IndexType> template <typename T, class Op, class IndexType>

View file

@ -108,11 +108,11 @@ namespace MultiArrayTools
template <class Second> template <class Second>
auto operator/(const Second& in) const auto operator/(const Second& in) const
-> Operation<T,std::divides<T>,OperationClass,Second>; -> Operation<T,std::divides<T>,OperationClass,Second>;
/*
template <class IndexType> template <class IndexType>
auto c(std::shared_ptr<IndexType>& ind) const auto c(std::shared_ptr<IndexType>& ind) const
-> Contraction<T,OperationClass,IndexType>; -> Contraction<T,OperationClass,IndexType>;
*/
private: private:
OperationClass* mOc; OperationClass* mOc;
}; };
@ -244,7 +244,7 @@ namespace MultiArrayTools
protected: protected:
const Op& mOp; Op mOp;
std::shared_ptr<IndexType> mInd; std::shared_ptr<IndexType> mInd;
mutable BlockResult<T> mRes; mutable BlockResult<T> mRes;
}; };

View file

@ -239,6 +239,22 @@ namespace {
} }
TEST_F(OpTest_MDim, ExecContract)
{
MultiArray<double,SRange> res(sr2ptr);
const MultiArray<double,SRange> ma1(sr2ptr, v1);
const MultiArray<double,SRange> ma2(sr4ptr, v2);
auto i1 = std::dynamic_pointer_cast<SRange::IndexType>( sr2ptr->index() );
auto i2 = std::dynamic_pointer_cast<SRange::IndexType>( sr4ptr->index() );
res(i1) = (ma1(i1) * ma2(i2)).c(i2);
EXPECT_EQ( xround( res.at('1') ), xround(2.917 * 8.870 + 2.917 * 4.790) );
EXPECT_EQ( xround( res.at('2') ), xround(9.436 * 8.870 + 9.436 * 4.790) );
EXPECT_EQ( xround( res.at('3') ), xround(0.373 * 8.870 + 0.373 * 4.790) );
}
TEST_F(OpTest_MDim, ExecOp2) TEST_F(OpTest_MDim, ExecOp2)
{ {
MultiArray<double,MRange,SRange> res(mr1ptr,sr4ptr); MultiArray<double,MRange,SRange> res(mr1ptr,sr4ptr);