#cython: language_level=3
#cython: c_string_type=str
#cython: c_string_encoding=ascii
#cython: boundscheck=False
#cython: wraparound=False
#cython: cdivision=True
#cython: embedsignature=True

from libcpp.memory cimport shared_ptr, make_shared, dynamic_pointer_cast
from libcpp.string cimport string
from libcpp cimport bool
from cython.operator cimport dereference as deref
from range cimport cpp_RangeBase
from range_factory cimport cpp_RangeFactoryBase, cpp_CRangeFactory
from array cimport cpp_CArrayBase, cpp_MArray, cpp_AIndex
from index cimport cpp_DIndex
from cereal cimport cpp_writeJSONFile, cpp_writeBINARYFile, cpp_readJSONFile, cpp_readBINARYFile


## ============
##    Range
## ============

cdef class Range:
    """Python-visible handle on a C++ range held through a shared pointer.

    The wrapped pointer ``cpp_range`` is assigned from Cython code (for
    example by ``RangeFactory.create()``). An instance built directly with
    ``Range()`` holds a null pointer; every method below raises
    ``RuntimeError`` in that state instead of dereferencing the null pointer.
    """
    cdef shared_ptr[cpp_RangeBase] cpp_range

    cdef int _require_range(self) except -1:
        # Guard used by all public methods: a null pointer must never be
        # dereferenced, so turn that state into a Python exception.
        if self.cpp_range.get() == NULL:
            raise RuntimeError("Range is not initialized")
        return 0

    def size(self):
        """Return the value of ``size()`` of the wrapped C++ range."""
        self._require_range()
        return self.cpp_range.get().size()

    def dim(self):
        """Return the value of ``dim()`` of the wrapped C++ range."""
        self._require_range()
        return self.cpp_range.get().dim()

    def sub(self):
        """Return the result of ``getSubRange`` applied to this range."""
        self._require_range()
        return getSubRange(self)

    def index(self):
        """Return the result of ``getRangeIndex`` applied to this range."""
        self._require_range()
        return getRangeIndex(self)


cdef class RangeFactory:
    """Factory producing ``Range`` objects through a C++ range factory.

    Constructor arguments:
        rangetype: selects the kind of factory; only ``'C'`` is handled.
        **kvargs:  type specific options; ``'C'`` reads the key ``'size'``.

    Raises:
        KeyError:  ``'C'`` was requested without a ``size`` keyword.
        Exception: ``rangetype`` is not a known range type.
    """
    cdef shared_ptr[cpp_RangeFactoryBase] cpp_rfactory

    def __cinit__(self,rangetype,**kvargs):
        cdef size_t size = 0
        if rangetype == 'C':
            size = kvargs['size']
            # Store the concrete factory through its base-class pointer.
            self.cpp_rfactory = dynamic_pointer_cast[cpp_RangeFactoryBase,cpp_CRangeFactory](
                make_shared[cpp_CRangeFactory](size))
        else:
            # str() keeps the intended message for non-string arguments,
            # where plain concatenation would raise an unrelated TypeError.
            raise Exception('unknown range type:' + str(rangetype))

    def create(self):
        """Return a new ``Range`` wrapping the range created by the factory."""
        cdef Range r = Range()
        r.cpp_range = self.cpp_rfactory.get().create()
        return r


## ===========
##    Index   
## ===========

cdef class Index:
    """Base class of the index wrappers.

    It carries no C++ index of its own: ``lex()`` and ``dim()`` return 0,
    ``stringMeta()`` returns an empty string, and iterating it yields
    nothing. Subclasses override these methods.
    """

    def __iter__(self):
        return self

    def __next__(self):
        # The base index has no positions to visit. Returning ``self`` here
        # would make ``for i in Index()`` loop forever, so end immediately.
        raise StopIteration

    def lex(self):
        """Return the position of the index; always 0 for the base class."""
        return 0

    def dim(self):
        """Return the dimension of the index; always 0 for the base class."""
        return 0

    def stringMeta(self):
        """Return the meta data as a string; always empty for the base class."""
        return ''

    
cdef class DIndex (Index):
    """Index over a ``Range``, wrapping a shared C++ ``cpp_DIndex``.

    Constructor arguments:
        _range:  the ``Range`` to index; must hold a C++ range.
        _lexpos: initial position passed to the C++ index (default 0).

    Raises:
        TypeError:  ``_range`` is not a ``Range``.
        ValueError: ``_range`` is ``None`` or holds no C++ range.

    Iteration resets the position to 0 and yields this same object once per
    position; the object is advanced in place between steps.
    """
    cdef shared_ptr[cpp_DIndex] cpp_index
    cdef bool itercall

    def __cinit__(self,_range,_lexpos=0):
        cdef Range r = _range
        cdef size_t l = _lexpos
        # A typed assignment accepts None; reject it (and a null range)
        # before the pointer reaches the C++ constructor.
        if r is None or r.cpp_range.get() == NULL:
            raise ValueError("DIndex requires an initialized Range")
        self.cpp_index = make_shared[cpp_DIndex] (r.cpp_range,l)
        self.itercall = False

    def __iter__(self):
        self.cpp_index.get().setlpos(0)
        self.itercall = True # otherwise first (i.e. zeroth) will be excluded
        return self

    def __next__(self):
        cdef size_t n = self.cpp_index.get().range().get().size()
        cdef size_t pos
        if self.itercall:
            self.itercall = False
            # An empty range has no zeroth element to hand out.
            if n == 0:
                raise StopIteration
            return self
        pos = self.lex()
        # Compare as pos + 1 < n: the unsigned expression n - 1 would wrap
        # around for n == 0 and let the iteration run past the end.
        if pos + 1 < n:
            self.cpp_index = make_shared[cpp_DIndex] (self.cpp_index.get().plus(1))
            return self
        raise StopIteration

    def lex(self):
        """Return ``lex()`` of the wrapped C++ index."""
        return self.cpp_index.get().lex()

    def dim(self):
        """Return ``dim()`` of the wrapped C++ index."""
        return self.cpp_index.get().dim()

    def stringMeta(self):
        """Return ``stringMeta()`` of the wrapped C++ index."""
        return self.cpp_index.get().stringMeta()

def getRangeIndex(_range):
    """Build and return a ``DIndex`` for ``_range`` at its default position."""
    range_index = DIndex(_range)
    return range_index


## ===========
##   Array
## ===========
    
cdef class Array:
    """Base class of the array wrappers; it holds no C++ array itself."""

    def size(self):
        """Return 0, as the base class stores no elements."""
        return 0

    def range(self):
        """Return a freshly constructed ``Range`` with no C++ range assigned."""
        return Range()
    

cdef class Array_Double (Array):
    """Array of doubles, wrapping a shared C++ ``cpp_CArrayBase[double]``.

    Constructor arguments:
        _range: optional ``Range``; when given, a ``cpp_MArray[double]`` is
                created over it. Without it the array stays uninitialized
                until ``cpp_array`` is assigned from Cython code.

    Raises:
        TypeError:    ``_range`` is neither ``None`` nor a ``Range``.
        ValueError:   ``_range`` holds no C++ range.
        RuntimeError: (from ``size``/``range``/``index``) the array is still
                      uninitialized.
    """
    cdef shared_ptr[cpp_CArrayBase[double]] cpp_array

    def __cinit__(self,_range=None):
        cdef Range r
        if _range is not None:
            r = _range
            if r.cpp_range.get() == NULL:
                raise ValueError("Array_Double requires an initialized Range")
            self.cpp_array = dynamic_pointer_cast[cpp_CArrayBase[double],cpp_MArray[double]] (make_shared[cpp_MArray[double]] (r.cpp_range) )

    cdef int _require_array(self) except -1:
        # The no-argument constructor leaves a null pointer behind; raise
        # rather than dereference it.
        if self.cpp_array.get() == NULL:
            raise RuntimeError("Array_Double is not initialized")
        return 0

    def size(self):
        """Return ``size()`` of the wrapped C++ array."""
        self._require_array()
        return self.cpp_array.get().size()

    def range(self):
        """Return a ``Range`` wrapping ``range()`` of the C++ array."""
        cdef Range r
        self._require_array()
        r = Range()
        r.cpp_range = self.cpp_array.get().range()
        return r

    def index(self):
        """Return an ``AIndex_Double`` over this array."""
        self._require_array()
        return AIndex_Double(self)

cdef class Array_Range (Array):
    """Array whose elements are ranges (shared ``cpp_RangeBase`` pointers).

    Constructor arguments:
        _range: optional ``Range``; when given, a ``cpp_MArray`` of range
                pointers is created over it. Without it the array stays
                uninitialized until ``cpp_array`` is assigned from Cython
                code.

    Raises:
        TypeError:    ``_range`` is neither ``None`` nor a ``Range``.
        ValueError:   ``_range`` holds no C++ range.
        RuntimeError: (from ``size``/``range``/``index``) the array is still
                      uninitialized.
    """
    cdef shared_ptr[cpp_CArrayBase[shared_ptr[cpp_RangeBase]]] cpp_array

    def __cinit__(self,_range=None):
        cdef Range r
        if _range is not None:
            r = _range
            if r.cpp_range.get() == NULL:
                raise ValueError("Array_Range requires an initialized Range")
            self.cpp_array = dynamic_pointer_cast[cpp_CArrayBase[shared_ptr[cpp_RangeBase]],cpp_MArray[shared_ptr[cpp_RangeBase]]] (make_shared[cpp_MArray[shared_ptr[cpp_RangeBase]]] (r.cpp_range) )

    cdef int _require_array(self) except -1:
        # The no-argument constructor leaves a null pointer behind; raise
        # rather than dereference it.
        if self.cpp_array.get() == NULL:
            raise RuntimeError("Array_Range is not initialized")
        return 0

    def size(self):
        """Return ``size()`` of the wrapped C++ array."""
        self._require_array()
        return self.cpp_array.get().size()

    def range(self):
        """Return a ``Range`` wrapping ``range()`` of the C++ array."""
        cdef Range r
        self._require_array()
        r = Range()
        r.cpp_range = self.cpp_array.get().range()
        return r

    def index(self):
        """Return an ``AIndex_Range`` over this array."""
        self._require_array()
        return AIndex_Range(self)

    
def getSubRange(_range):
    """Return an ``Array_Range`` built from ``sub()`` of the given ``Range``.

    ``_range`` must be a ``Range``; a new ``cpp_MArray`` of range pointers is
    constructed from the value returned by the C++ range's ``sub()`` and
    stored in the returned array through its base-class pointer.
    """
    cdef Range src = _range
    cdef Array_Range result = Array_Range()
    cdef shared_ptr[cpp_MArray[shared_ptr[cpp_RangeBase]]] storage
    storage = make_shared[cpp_MArray[shared_ptr[cpp_RangeBase]]] (src.cpp_range.get().sub())
    result.cpp_array = dynamic_pointer_cast[cpp_CArrayBase[shared_ptr[cpp_RangeBase]],cpp_MArray[shared_ptr[cpp_RangeBase]]] (storage)
    return result

def writeFile(_fname,_array,_format):
    """Write an ``Array_Double`` to the file ``_fname``.

    Arguments:
        _fname:  name of the output file.
        _array:  the ``Array_Double`` to write.
        _format: ``"JSON"`` or ``"BINARY"`` (compared case-insensitively).

    Raises:
        TypeError:  ``_array`` is ``None`` or not an ``Array_Double``.
        Exception:  ``_format`` names an unknown file format.
        ValueError: the array holds no ``cpp_MArray[double]`` (it is
                    uninitialized or wraps a different array type).
    """
    cdef Array_Double a = _array
    cdef shared_ptr[cpp_MArray[double]] ap
    # A typed assignment accepts None; reject it before touching attributes.
    if a is None:
        raise TypeError("writeFile requires an Array_Double, got None")
    fmt = _format.upper()
    if fmt != "JSON" and fmt != "BINARY":
        raise Exception("unknown array file format '{}'".format(_format))
    ap = dynamic_pointer_cast[cpp_MArray[double],cpp_CArrayBase[double]](a.cpp_array)
    # The downcast yields a null pointer when the array is uninitialized or
    # not an MArray; dereferencing it below would be undefined behaviour.
    if ap.get() == NULL:
        raise ValueError("array holds no writable data (uninitialized or not an MArray)")
    if fmt == "JSON":
        cpp_writeJSONFile[double](_fname, deref(ap.get()))
    else:
        cpp_writeBINARYFile[double](_fname, deref(ap.get()))

def readFile(_fname,_format):
    """Read an array of doubles from the file ``_fname``.

    ``_format`` selects the reader: ``"JSON"`` or ``"BINARY"`` (compared
    case-insensitively). A new ``Array_Double`` backed by a default
    constructed ``cpp_MArray[double]`` is filled by the reader and returned.
    Any other format raises ``Exception``.
    """
    cdef Array_Double result = Array_Double()
    cdef shared_ptr[cpp_MArray[double]] storage = make_shared[cpp_MArray[double]]()
    # The wrapper keeps the storage alive through its base-class pointer.
    result.cpp_array = dynamic_pointer_cast[cpp_CArrayBase[double],cpp_MArray[double]](storage)
    fmt = _format.upper()
    if fmt == "JSON":
        cpp_readJSONFile[double](_fname, deref(storage.get()))
        return result
    if fmt == "BINARY":
        cpp_readBINARYFile[double](_fname, deref(storage.get()))
        return result
    raise Exception("unknown array file format '{}'".format(_format))


cdef class AIndex_Double (Index):
    """Index into an ``Array_Double``, wrapping a C++ ``cpp_AIndex[double]``.

    Constructor arguments:
        _array:  the ``Array_Double`` to index; must hold a C++ array.
        _lexpos: offset added to the array's ``begin()`` index (default 0).

    Raises:
        TypeError:  ``_array`` is not an ``Array_Double``.
        ValueError: ``_array`` is ``None`` or uninitialized.

    Iteration resets the position to 0 and yields this same object once per
    position; the object is advanced in place between steps.
    """
    cdef shared_ptr[cpp_CArrayBase[double]] cpp_array # keep the instance alive
    cdef shared_ptr[cpp_AIndex[double]] cpp_index
    cdef bool itercall

    def __cinit__(self,_array,_lexpos=0):
        cdef size_t l = _lexpos
        cdef Array_Double a = _array
        # A typed assignment accepts None, and Array_Double() may hold a
        # null pointer; both would be dereferenced by begin() below.
        if a is None or a.cpp_array.get() == NULL:
            raise ValueError("AIndex_Double requires an initialized Array_Double")
        self.cpp_array = a.cpp_array
        self.cpp_index = make_shared[cpp_AIndex[double]] (a.cpp_array.get().begin().A_plus(l))
        self.itercall = False

    def __iter__(self):
        self.cpp_index.get().setlpos(0)
        self.itercall = True # otherwise first (i.e. zeroth) will be excluded
        return self

    def __next__(self):
        cdef size_t n = self.cpp_index.get().range().get().size()
        cdef size_t pos
        if self.itercall:
            self.itercall = False
            # An empty range has no zeroth element to hand out.
            if n == 0:
                raise StopIteration
            return self
        pos = self.lex()
        # Compare as pos + 1 < n: the unsigned expression n - 1 would wrap
        # around for n == 0 and let the iteration run past the end.
        if pos + 1 < n:
            self.cpp_index = make_shared[cpp_AIndex[double]] (self.cpp_index.get().A_plus(1))
            return self
        raise StopIteration

    def lex(self):
        """Return ``lex()`` of the wrapped C++ index."""
        return self.cpp_index.get().lex()

    def dim(self):
        """Return ``dim()`` of the wrapped C++ index."""
        return self.cpp_index.get().dim()

    def stringMeta(self):
        """Return ``stringMeta()`` of the wrapped C++ index."""
        return self.cpp_index.get().stringMeta()

    def get(self):
        """Return the value of ``A_get()`` of the wrapped C++ index."""
        return self.cpp_index.get().A_get()

    
cdef class AIndex_Range (Index):
    """Index into an ``Array_Range``, wrapping a C++ ``cpp_AIndex`` of ranges.

    Constructor arguments:
        _array:  the ``Array_Range`` to index; must hold a C++ array.
        _lexpos: offset added to the array's ``begin()`` index (default 0).

    Raises:
        TypeError:  ``_array`` is not an ``Array_Range``.
        ValueError: ``_array`` is ``None`` or uninitialized.

    Iteration resets the position to 0 and yields this same object once per
    position; the object is advanced in place between steps.
    """
    cdef shared_ptr[cpp_CArrayBase[shared_ptr[cpp_RangeBase]]] cpp_array # keep the instance alive
    cdef shared_ptr[cpp_AIndex[shared_ptr[cpp_RangeBase]]] cpp_index
    cdef bool itercall

    def __cinit__(self,_array,_lexpos=0):
        cdef size_t l = _lexpos
        cdef Array_Range a = _array
        # A typed assignment accepts None, and Array_Range() may hold a
        # null pointer; both would be dereferenced by begin() below.
        if a is None or a.cpp_array.get() == NULL:
            raise ValueError("AIndex_Range requires an initialized Array_Range")
        self.cpp_array = a.cpp_array
        self.cpp_index = make_shared[cpp_AIndex[shared_ptr[cpp_RangeBase]]] (a.cpp_array.get().begin().A_plus(l))
        self.itercall = False

    def __iter__(self):
        self.cpp_index.get().setlpos(0)
        self.itercall = True # otherwise first (i.e. zeroth) will be excluded
        return self

    def __next__(self):
        cdef size_t n = self.cpp_index.get().range().get().size()
        cdef size_t pos
        if self.itercall:
            self.itercall = False
            # An empty range has no zeroth element to hand out.
            if n == 0:
                raise StopIteration
            return self
        pos = self.lex()
        # Compare as pos + 1 < n: the unsigned expression n - 1 would wrap
        # around for n == 0 and let the iteration run past the end.
        if pos + 1 < n:
            self.cpp_index = make_shared[cpp_AIndex[shared_ptr[cpp_RangeBase]]] (self.cpp_index.get().A_plus(1))
            return self
        raise StopIteration

    def lex(self):
        """Return ``lex()`` of the wrapped C++ index."""
        return self.cpp_index.get().lex()

    def dim(self):
        """Return ``dim()`` of the wrapped C++ index."""
        return self.cpp_index.get().dim()

    def stringMeta(self):
        """Return ``stringMeta()`` of the wrapped C++ index."""
        return self.cpp_index.get().stringMeta()

    def get(self):
        """Return a ``Range`` wrapping the value of ``A_get()`` of the index."""
        cdef Range r = Range()
        r.cpp_range = self.cpp_index.get().A_get()
        return r