add first for tests

This commit is contained in:
Christian Zimmermann 2022-11-22 00:58:50 +01:00
parent 2f5f29f577
commit a3d3d99c8d
3 changed files with 227 additions and 23 deletions

View file

@ -26,25 +26,42 @@ namespace CNORXZ
template <class PosT>
inline decltype(auto) For<L,Xpr,F>::operator()(const PosT& last) const
{
typedef typename std::remove_reference<decltype(mXpr(last + mExt * UPos(0)))>::type OutT;
auto o = OutT();
for(SizeT i = 0; i != mSize; ++i){
const auto pos = last + mExt * UPos(i);
mF(o, mXpr(pos));
if constexpr(std::is_same<F,NoF>::value){
for(SizeT i = 0; i != mSize; ++i){
const auto pos = last + mExt * UPos(i);
mXpr(pos);
}
}
else {
typedef typename
std::remove_reference<decltype(mXpr(last + mExt * UPos(0)))>::type OutT;
auto o = OutT();
for(SizeT i = 0; i != mSize; ++i){
const auto pos = last + mExt * UPos(i);
mF(o, mXpr(pos));
}
return o;
}
return o;
}
template <SizeT L, class Xpr, class F>
inline decltype(auto) For<L,Xpr,F>::operator()() const
{
typedef typename std::remove_reference<decltype(mXpr(mExt * UPos(0)))>::type OutT;
auto o = OutT();
for(SizeT i = 0; i != mSize; ++i){
const auto pos = mExt * UPos(i);
mF(o, mXpr(pos));
if constexpr(std::is_same<F,NoF>::value){
for(SizeT i = 0; i != mSize; ++i){
const auto pos = mExt * UPos(i);
mXpr(pos);
}
}
else {
typedef typename std::remove_reference<decltype(mXpr(mExt * UPos(0)))>::type OutT;
auto o = OutT();
for(SizeT i = 0; i != mSize; ++i){
const auto pos = mExt * UPos(i);
mF(o, mXpr(pos));
}
return o;
}
return o;
}
template <SizeT L, class Xpr, class F>
@ -54,6 +71,21 @@ namespace CNORXZ
return mXpr.rootSteps(id);
}
/************************
* For (non-member) *
************************/
template <SizeT L, class Xpr, class F>
constexpr decltype(auto) mkFor(SizeT size, const IndexId<L>& id, const Xpr& xpr, F&& f)
{
return For<L,Xpr,F>(size, id, xpr, std::forward<F>(f));
}
template <SizeT L, class Xpr>
constexpr decltype(auto) mkFor(SizeT size, const IndexId<L>& id, const Xpr& xpr)
{
return For<L,Xpr>(size, id, xpr, NoF {});
}
/************
* SFor *
@ -63,7 +95,7 @@ namespace CNORXZ
constexpr SFor<N,L,Xpr,F>::SFor(const IndexId<L>& id, const Xpr& xpr, F&& f) :
mId(id),
mXpr(xpr),
mExt(mXpr.RootSteps(mId)),
mExt(mXpr.rootSteps(mId)),
mF(f)
{}
@ -71,13 +103,25 @@ namespace CNORXZ
template <class PosT>
constexpr decltype(auto) SFor<N,L,Xpr,F>::operator()(const PosT& last) const
{
return exec<0>(last);
if constexpr(std::is_same<F,NoF>::value){
exec2<0>(last);
return;
}
else {
return exec<0>(last);
}
}
template <SizeT N, SizeT L, class Xpr, class F>
constexpr decltype(auto) SFor<N,L,Xpr,F>::operator()() const
{
return exec<0>();
if constexpr(std::is_same<F,NoF>::value){
exec2<0>();
return;
}
else {
return exec<0>();
}
}
template <SizeT N, SizeT L, class Xpr, class F>
@ -93,12 +137,11 @@ namespace CNORXZ
{
constexpr SPos<I> i;
const auto pos = last + mExt * i;
auto o = mXpr(pos);
if constexpr(I < N-1){
return mF(o,exec<I+1>(last));
return mF(mXpr(pos),exec<I+1>(last));
}
else {
return o;
return mXpr(pos);
}
}
@ -108,15 +151,62 @@ namespace CNORXZ
{
constexpr SPos<I> i;
const auto pos = mExt * i;
auto o = mXpr(pos);
if constexpr(I < N-1){
return mF(o,exec<I+1>());
return mF(mXpr(pos),exec<I+1>());
}
else {
return o;
return mXpr(pos);
}
}
template <SizeT N, SizeT L, class Xpr, class F>
template <SizeT I, class PosT>
inline void SFor<N,L,Xpr,F>::exec2(const PosT& last) const
{
constexpr SPos<I> i;
const auto pos = last + mExt * i;
if constexpr(I < N-1){
mXpr(pos);
exec2<I+1>(last);
}
else {
mXpr(pos);
}
return;
}
template <SizeT N, SizeT L, class Xpr, class F>
template <SizeT I>
inline void SFor<N,L,Xpr,F>::exec2() const
{
constexpr SPos<I> i;
const auto pos = mExt * i;
if constexpr(I < N-1){
mXpr(pos);
exec2<I+1>();
}
else {
mXpr(pos);
}
return;
}
/*************************
* SFor (non-member) *
*************************/
template <SizeT N, SizeT L, class Xpr, class F>
constexpr decltype(auto) mkSFor(const IndexId<L>& id, const Xpr& xpr, F&& f)
{
return SFor<N,L,Xpr,F>(id, xpr, std::forward<F>(f));
}
template <SizeT N, SizeT L, class Xpr>
constexpr decltype(auto) mkSFor(const IndexId<L>& id, const Xpr& xpr)
{
return SFor<N,L,Xpr>(id, xpr, NoF {});
}
/************
* TFor *
************/

View file

@ -34,7 +34,12 @@ namespace CNORXZ
F mF;
};
template <SizeT L, class Xpr, class F>
constexpr decltype(auto) mkFor(SizeT size, const IndexId<L>& id, const Xpr& xpr, F&& f);
template <SizeT L, class Xpr>
constexpr decltype(auto) mkFor(SizeT size, const IndexId<L>& id, const Xpr& xpr);
// unrolled loop:
template <SizeT N, SizeT L, class Xpr, class F = NoF>
class SFor : public XprInterface<SFor<N,L,Xpr,F>>
@ -60,14 +65,25 @@ namespace CNORXZ
template <SizeT I>
constexpr decltype(auto) exec() const;
template <SizeT I, class PosT>
inline void exec2(const PosT& last) const;
template <SizeT I>
inline void exec2() const;
IndexId<L> mId;
Xpr mXpr;
typedef decltype(mXpr.RootSteps(mId)) XPosT;
typedef decltype(mXpr.rootSteps(mId)) XPosT;
XPosT mExt;
F mF;
};
template <SizeT N, SizeT L, class Xpr, class F>
constexpr decltype(auto) mkSFor(const IndexId<L>& id, const Xpr& xpr, F&& f);
template <SizeT N, SizeT L, class Xpr>
constexpr decltype(auto) mkSFor(const IndexId<L>& id, const Xpr& xpr);
// multi-threading
template <SizeT L, class Xpr, class F = NoF>

View file

@ -34,6 +34,73 @@ namespace
SPos<ss2> mS2p;
};
class For_Test : public ::testing::Test
{
protected:
class TestXpr1
{
public:
constexpr TestXpr1(const IndexId<0>& id) : mId(id) {}
template <class PosT>
inline SizeT operator()(const PosT& last) const
{
const SizeT o = 1u;
return o << last.val();
}
template <SizeT I>
inline UPos rootSteps(const IndexId<I>& id) const
{
return UPos( mId == id ? 1u : 0u );
}
private:
IndexId<0> mId;
};
class TestXpr2
{
public:
constexpr TestXpr2(const IndexId<0>& id, SizeT size) :
mId(id), mSize(size), mCnt(size) {}
template <class PosT>
inline void operator()(const PosT& last) const
{
--mCnt;
EXPECT_EQ(mCnt, mSize-last.val()-1);
}
template <SizeT I>
inline UPos rootSteps(const IndexId<I>& id) const
{
return UPos( mId == id ? 1u : 0u );
}
private:
IndexId<0> mId;
SizeT mSize;
mutable SizeT mCnt;
};
static constexpr SizeT sSize = 7u;
For_Test()
{
mSize = sSize;
mId1 = 10u;
mId2 = 11u;
mId3 = 12u;
}
SizeT mSize;
PtrId mId1;
PtrId mId2;
PtrId mId3;
};
TEST_F(Pos_Test, Basics)
{
EXPECT_EQ( mUp1.size(), 1u );
@ -157,6 +224,37 @@ namespace
EXPECT_EQ(dp5.sub().val(), mS4p.val() * mUp1.val());
}
TEST_F(For_Test, For)
{
auto loop = mkFor(mSize, IndexId<0>(mId1), TestXpr1( IndexId<0>(mId1) ),
[](auto& o, const auto& e) { o += e; });
const UPos rs = loop.rootSteps(IndexId<0>(mId1));
EXPECT_EQ(rs.val(), 1u);
const UPos rs2 = loop.rootSteps(IndexId<0>(mId2));
EXPECT_EQ(rs2.val(), 0u);
const SizeT res = loop();
EXPECT_EQ(res, (1u << mSize) - 1u);
auto loop2 = mkFor(mSize, IndexId<0>(mId1), TestXpr2( IndexId<0>(mId1), mSize ));
loop2();
}
TEST_F(For_Test, SFor)
{
auto loop = mkSFor<sSize>(IndexId<0>(mId1), TestXpr1( IndexId<0>(mId1) ),
[](const auto& a, const auto& b) { return a + b; });
const UPos rs = loop.rootSteps(IndexId<0>(mId1));
EXPECT_EQ(rs.val(), 1u);
const UPos rs2 = loop.rootSteps(IndexId<0>(mId2));
EXPECT_EQ(rs2.val(), 0u);
const SizeT res = loop();
EXPECT_EQ(res, (1u << mSize) - 1u);
auto loop2 = mkSFor<sSize>(IndexId<0>(mId1), TestXpr2( IndexId<0>(mId1), mSize ));
loop2();
}
}
int main(int argc, char** argv)