dynamic contractions work

This commit is contained in:
Christian Zimmermann 2020-08-26 17:05:44 +02:00
parent 6444f971a6
commit 304d7ec682
2 changed files with 61 additions and 0 deletions

View file

@ -595,6 +595,8 @@ namespace MultiArrayTools
template <class ET>
inline Contraction& set(ET pos);
T* data() const { assert(0); return nullptr; }
auto rootSteps(std::intptr_t iPtrNum = 0) const // nullptr for simple usage with decltype
-> decltype(mOp.rootSteps(iPtrNum));

View file

@ -45,8 +45,10 @@ namespace
MultiArray<double,CR,DR> ma1;
MultiArray<double,CR,DR> ma2;
MultiArray<double,DR> ma3;
MultiArray<double,DR> ma5;
MultiArray<double,CR,DR> res1;
MultiArray<double,CR,DR> res2;
std::map<std::string,std::shared_ptr<IndexW>> imap;
@ -54,6 +56,8 @@ namespace
std::shared_ptr<DR> dr2;
std::shared_ptr<DR> dr3;
std::shared_ptr<DR> dr4;
std::shared_ptr<DR> dr5;
std::shared_ptr<DR> dr6;
std::shared_ptr<CR> cr1;
OpTest_Dyn()
@ -68,18 +72,23 @@ namespace
dr1 = createRangeE<DR>(cr2,cr2,cr3,cr4);
dr2 = createRangeE<DR>(cr3,cr3,cr4);
dr3 = createRangeE<DR>(cr2,cr5);
dr5 = createRangeE<DR>(cr5);
dr6 = createRangeE<DR>(cr3,cr4);
dr4 = createRangeE<DR>(cr2,cr3,cr4,cr4);
ma1 = mkArray<double>(cr1,dr1);
ma2 = mkArray<double>(cr1,dr2);
ma3 = mkArray<double>(dr3);
ma5 = mkArray<double>(dr5);
res1 = mkArray<double>(cr1,dr4);
res2 = mkArray<double>(cr1,dr6);
setMARandom(ma1, 25);
setMARandom(ma2, 31);
setMARandom(ma3, 47);
setMARandom(ma5, 59);
imap["i2_1"] = mkIndexW(getIndex(cr2));
imap["i2_2"] = mkIndexW(getIndex(cr2));
@ -145,6 +154,56 @@ namespace
TEST_F(OpTest_Dyn, Contract)
{
auto i1 = getIndex(cr1);
auto di1 = getIndex(dr1);
auto di3 = getIndex(dr3);
auto di5 = getIndex(dr5);
auto di6 = getIndex(dr6);
(*di1)({imap["i2_1"],imap["i2_1"],imap["i3_1"],imap["i4_1"]});
(*di3)({imap["i2_1"],imap["i5_1"]});
(*di5)({imap["i5_1"]});
(*di6)({imap["i3_1"],imap["i4_1"]});
auto resx1 = res2;
auto resx2 = res2;
auto resx3 = res2;
res2(i1,di6) += (ma1(i1,di1) * ma5(di5)).c(di3);
resx1(i1,di6) += (mkDynOp(ma1(i1,di1)) * mkDynOp(ma5(di5))).c(di3);
resx2(i1,di6) += mkDynOp((ma1(i1,di1) * ma5(di5)).c(di3));
resx3(i1,di6) += mkDynOp((mkDynOp(ma1(i1,di1)) * mkDynOp(ma5(di5))).c(di3));
auto i2_1 = imap.at("i2_1");
auto i3_1 = imap.at("i3_1");
auto i4_1 = imap.at("i4_1");
auto i5_1 = imap.at("i5_1");
for(size_t ii1 = 0; ii1 != i1->max(); ++ii1){
for(size_t ii3_1 = 0; ii3_1 != i3_1->max(); ++ii3_1){
for(size_t ii4_1 = 0; ii4_1 != i4_1->max(); ++ii4_1){
double vv = 0;
const size_t jr = (ii1*i3_1->max() + ii3_1)*i4_1->max() + ii4_1;
for(size_t ii2_1 = 0; ii2_1 != i2_1->max(); ++ii2_1){
const size_t j1 = (((ii1*i2_1->max() + ii2_1)*i2_1->max() + ii2_1)*i3_1->max() + ii3_1)*i4_1->max() + ii4_1;
for(size_t ii5_1 = 0; ii5_1 != i5_1->max(); ++ii5_1){
const size_t j2 = ii5_1;
vv += ma1.vdata()[j1] * ma5.vdata()[j2];
}
}
auto resv = xround(res2.vdata()[jr]);
auto resx1v = xround(resx1.vdata()[jr]);
auto resx2v = xround(resx2.vdata()[jr]);
auto resx3v = xround(resx3.vdata()[jr]);
auto x12 = xround(vv);
EXPECT_EQ( resv, x12 );
EXPECT_EQ( resx1v, x12 );
EXPECT_EQ( resx2v, x12 );
EXPECT_EQ( resx3v, x12 );
}
}
//std::cout << std::endl;
}
}