# coding: utf-8
from .basic import Basic
from sympy.core.singleton import Singleton
from sympy.core.compatibility import with_metaclass
from sympy import Eq, Ne, Lt, Gt, Le, Ge
# Default precision (in bytes) associated with each basic datatype kind.
default_precision = {
    'real': 8,
    'int': 4,
    'complex': 8,
    'bool': 1,
}

# Maps a user-facing type name to a ``(kind, precision)`` pair.
# For complex kinds the precision is that of each real component
# (complex64 -> 4, complex128 -> 8).
# NOTE: the misspelling in the name is kept because other modules may
# import it under this exact name.
dtype_and_precsision_registry = {
    'real':       ('real', 8),
    'double':     ('real', 8),
    'float':      ('real', 8),
    'float32':    ('real', 4),
    'float64':    ('real', 8),
    'complex':    ('complex', 8),
    'complex64':  ('complex', 4),
    'complex128': ('complex', 8),
    'int8':       ('int', 1),
    'int16':      ('int', 2),
    'int32':      ('int', 4),
    'int64':      ('int', 8),
    'int':        ('int', 4),
    'integer':    ('int', 4),
    'bool':       ('bool', 1),
}
class DataType(with_metaclass(Singleton, Basic)):
    """Base class representing native datatypes.

    Concrete subclasses override the class attribute ``_name``; the printed
    form of an instance is that name in lower case.

    NOTE(review): the class is built with sympy's ``Singleton`` metaclass, so
    each concrete subclass presumably yields one shared instance — confirm
    for subclasses whose ``__init__`` takes arguments.
    """

    _name = '__UNDEFINED__'

    @property
    def name(self):
        """Declared name of this datatype (e.g. ``'Int'``)."""
        return self._name

    def __str__(self):
        text = str(self.name)
        # 'Int' is printed as 'int', 'NdArrayReal' as 'ndarrayreal', ...
        return text.lower()
class NativeBool(DataType):
    """Native boolean datatype."""
    _name = 'Bool'


class NativeInteger(DataType):
    """Native integer datatype."""
    _name = 'Int'


class NativeReal(DataType):
    """Native real (floating point) datatype."""
    _name = 'Real'


class NativeComplex(DataType):
    """Native complex datatype."""
    _name = 'Complex'


class NativeString(DataType):
    """Native string datatype."""
    _name = 'String'


class NativeVoid(DataType):
    """Native void datatype."""
    _name = 'Void'


class NativeNil(DataType):
    """Native nil datatype."""
    _name = 'Nil'
class NativeList(DataType):
    """Native list datatype (element type unspecified)."""
    _name = 'List'


class NativeIntegerList(NativeInteger, NativeList):
    """List of integers; is-a NativeInteger and is-a NativeList."""
    _name = 'IntegerList'


class NativeRealList(NativeReal, NativeList):
    """List of reals; is-a NativeReal and is-a NativeList."""
    _name = 'RealList'


class NativeComplexList(NativeComplex, NativeList):
    """List of complex numbers; is-a NativeComplex and is-a NativeList."""
    _name = 'ComplexList'
class NativeRange(DataType):
    """Native range datatype (treated as iterable)."""
    _name = 'Range'


class NativeTensor(DataType):
    """Native tensor datatype (treated as iterable)."""
    _name = 'Tensor'


class NativeParallelRange(NativeRange):
    """Range variant named 'ParallelRange'; inherits from NativeRange."""
    _name = 'ParallelRange'


class NativeSymbol(DataType):
    """Native symbol datatype."""
    _name = 'Symbol'
class NdArray(DataType):
    """N-dimensional array datatype (element type unspecified)."""
    _name = 'NdArray'


class NdArrayInt(NdArray, NativeInteger):
    """N-dimensional array of integers."""
    _name = 'NdArrayInt'


class NdArrayReal(NdArray, NativeReal):
    """N-dimensional array of reals."""
    _name = 'NdArrayReal'


class NdArrayComplex(NdArray, NativeComplex):
    """N-dimensional array of complex numbers."""
    _name = 'NdArrayComplex'
# TODO to be removed
class CustomDataType(DataType):
    """Datatype whose name is supplied at construction time.

    Used as the default base class of ``DataTypeFactory``-generated classes.
    """

    _name = '__UNDEFINED__'

    def __init__(self, name='__UNDEFINED__'):
        # The instance attribute shadows the class-level default name.
        self._name = name
class NativeGeneric(DataType):
    """Generic datatype (registered under the key ``'*'``)."""
    _name = 'Generic'
# ...
class VariableType(DataType):
    """Datatype bound to an alias that reuses the name of another datatype.

    Parameters
    ----------
    rhs : DataType
        Datatype whose ``_name`` this type adopts.
    alias :
        Alias recorded for this type (exposed via ``alias``).
    """

    def __init__(self, rhs, alias):
        self._alias = alias
        self._rhs = rhs
        target_name = rhs._name
        self._name = target_name

    @property
    def alias(self):
        """Alias given at construction."""
        return self._alias
class FunctionType(DataType):
    """Datatype describing a function signature as a sequence of domains.

    The first element of ``domains`` is stored as ``domain`` and the
    remaining ones as ``codomain``. The name joins the string form of every
    element with ``' -> '``.
    """

    def __init__(self, domains):
        head = domains[0]
        tail = domains[1:]
        self._domain = head
        self._codomain = tail
        self._domains = domains
        self._name = ' -> '.join(map('{}'.format, domains))

    @property
    def domain(self):
        """First element of the domains sequence."""
        return self._domain

    @property
    def codomain(self):
        """All elements of the domains sequence after the first."""
        return self._codomain
# ...
# Module-level instances of the native datatypes defined above; these are the
# objects stored in ``dtype_registry``.
# NOTE(review): the classes use sympy's Singleton metaclass, so these are
# presumably the unique shared instances of each class — confirm.
Bool = NativeBool()
Int = NativeInteger()
Real = NativeReal()
Complex = NativeComplex()
Void = NativeVoid()
Nil = NativeNil()
String = NativeString()
# Leading underscore avoids clashing with sympy's ``Symbol`` name.
_Symbol = NativeSymbol()
IntegerList = NativeIntegerList()
RealList = NativeRealList()
ComplexList = NativeComplexList()
# NOTE: the four assignments below rebind the class names NdArray,
# NdArrayInt, NdArrayReal and NdArrayComplex to instances. After this point
# those module names refer to instances, so they cannot be used as classes
# (e.g. in isinstance checks or to instantiate again).
NdArray = NdArray()
NdArrayInt = NdArrayInt()
NdArrayReal = NdArrayReal()
NdArrayComplex = NdArrayComplex()
Generic = NativeGeneric()
# Maps lower-case type keys to the module-level datatype instances.
# Keys prefixed with '*' denote list types; '*' alone is the generic type.
dtype_registry = {
    'bool':           Bool,
    'int':            Int,
    'integer':        Int,
    'real':           Real,
    'complex':        Complex,
    'void':           Void,
    'nil':            Nil,
    'symbol':         _Symbol,
    '*int':           IntegerList,
    '*real':          RealList,
    '*complex':       ComplexList,
    'ndarrayint':     NdArrayInt,
    'ndarrayinteger': NdArrayInt,
    'ndarrayreal':    NdArrayReal,
    'ndarraycomplex': NdArrayComplex,
    '*':              Generic,
    'str':            String,
}
class UnionType(Basic):
    """A union of datatypes.

    The whole ``args`` collection is stored as a single sympy argument, and
    the ``args`` property hands back that stored collection.
    """

    def __new__(cls, args):
        instance = Basic.__new__(cls, args)
        return instance

    @property
    def args(self):
        stored = self._args
        return stored[0]
def DataTypeFactory(name, argnames=("_name",),
                    BaseClass=CustomDataType,
                    prefix=None,
                    alias=None,
                    is_iterable=False,
                    is_with_construct=False,
                    is_polymorphic=True):
    """Create a new datatype class.

    Parameters
    ----------
    name : str
        Base name of the class. The generated class is called
        ``'Pyccel' + name`` (or ``'Pyccel' + prefix + name``). Instances get
        ``name`` with a trailing ``'Class'`` suffix removed, if present, as
        their ``_name``.
    argnames : collection of str
        Keyword arguments accepted by the generated ``__init__``; any other
        keyword raises ``TypeError``.
    BaseClass : type
        Base class of the generated class; its ``__init__`` is called with
        ``name=...``.
    prefix : str or None
        Optional text inserted after ``'Pyccel'`` in the class name.
    alias, is_iterable, is_with_construct, is_polymorphic :
        Stored unchanged as class attributes of the generated class.

    Returns
    -------
    type
        The newly created class (not an instance).
    """
    suffix = 'Class'
    # Only strip a trailing 'Class'. Slicing it off unconditionally would
    # truncate names that do not end with it (e.g. 'Point' -> '').
    if name.endswith(suffix):
        instance_name = name[:-len(suffix)]
    else:
        instance_name = name

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            # here, the argnames variable is the one passed to the
            # DataTypeFactory call
            if key not in argnames:
                raise TypeError("Argument %s not valid for %s"
                                % (key, self.__class__.__name__))
            setattr(self, key, value)
        # NOTE(review): this runs after the kwargs loop, so a ``_name``
        # keyword argument is overwritten here — confirm that is intended.
        BaseClass.__init__(self, name=instance_name)

    if prefix is None:
        prefix = 'Pyccel'
    else:
        prefix = 'Pyccel{0}'.format(prefix)

    newclass = type(prefix + name, (BaseClass,),
                    {"__init__": __init__,
                     "_name": name,
                     "prefix": prefix,
                     "alias": alias,
                     "is_iterable": is_iterable,
                     "is_with_construct": is_with_construct,
                     "is_polymorphic": is_polymorphic})
    return newclass
def is_pyccel_datatype(expr):
    """Return True if *expr* is a user-defined (``CustomDataType``) datatype.

    This covers classes generated by ``DataTypeFactory`` with the default
    base class, as well as direct ``CustomDataType`` instances.
    """
    return isinstance(expr, CustomDataType)
def is_iterable_datatype(dtype):
    """Return True if *dtype* is an iterable datatype.

    - For a ``CustomDataType``, returns its ``is_iterable`` attribute (set by
      ``DataTypeFactory``), or False when the attribute is absent.
    - Native ranges and tensors are always iterable.
    - Everything else is not iterable.
    """
    if is_pyccel_datatype(dtype):
        # A bare CustomDataType (not built by DataTypeFactory) has no
        # ``is_iterable`` attribute; treat it as non-iterable instead of
        # swallowing every exception with a bare ``except``.
        return getattr(dtype, 'is_iterable', False)
    return isinstance(dtype, (NativeRange, NativeTensor))
def get_default_value(dtype):
    """Return the default (zero) value of a native datatype.

    Parameters
    ----------
    dtype : DataType
        Checked with isinstance, so subclasses (e.g. list or ndarray variants
        of integer/real/complex) map to their element kind's default.

    Returns
    -------
    0 for integers, 0.0 for reals and complex, sympy ``false`` for booleans.

    Raises
    ------
    TypeError
        If *dtype* is not an integer, real, complex or boolean datatype.
    """
    if isinstance(dtype, NativeInteger):
        return 0
    if isinstance(dtype, NativeReal):
        return 0.0
    if isinstance(dtype, NativeComplex):
        # NOTE(review): kept as the real literal 0.0; confirm whether callers
        # would rather receive a complex zero.
        return 0.0
    if isinstance(dtype, NativeBool):
        # Imported locally: the module header never imported BooleanFalse,
        # so this branch previously raised NameError.
        from sympy.logic.boolalg import BooleanFalse
        return BooleanFalse()
    raise TypeError('Unknown type')
def is_with_construct_datatype(dtype):
    """Return True if *dtype* is a datatype flagged for ``with`` constructs.

    For a ``CustomDataType``, returns its ``is_with_construct`` attribute
    (set by ``DataTypeFactory``), or False when absent. Returns False for
    every other datatype.
    """
    if is_pyccel_datatype(dtype):
        # A bare CustomDataType has no ``is_with_construct`` attribute; use a
        # default instead of a bare ``except`` that hides unrelated errors.
        return getattr(dtype, 'is_with_construct', False)
    return False
# TODO check the use of Reals
def datatype(arg):
    """Return the datatype singleton for the given dtype.

    Parameters
    ----------
    arg : str or DataType
        A key of ``dtype_registry`` such as 'bool', 'int', 'integer', 'real',
        'complex', 'void' or 'str' (matched case-insensitively), or a
        DataType instance.

    Returns
    -------
    DataType
        The registered instance. For a DataType argument whose lower-cased
        name is not a registry key, the argument itself is returned.

    Raises
    ------
    ValueError
        If *arg* is a string that is not a registered datatype name.
    TypeError
        If *arg* is neither a string nor a DataType.
    """
    if isinstance(arg, str):
        key = arg.lower()
        if key not in dtype_registry:
            raise ValueError("Unrecognized datatype " + arg)
        # Index with the normalised key; the membership test above is
        # case-insensitive, so the raw string could raise KeyError ('Int').
        return dtype_registry[key]
    if isinstance(arg, DataType):
        # DataType exposes ``name`` (it has no ``dtype`` attribute).
        return dtype_registry.get(arg.name.lower(), arg)
    raise TypeError('Expecting a DataType')
def sp_dtype(expr):
    """Classify a sympy expression as 'int', 'real', 'complex' or 'bool'.

    The assumption flags are tested from most to least specific, so an
    integer expression reports 'int' even though it is also real. Relational
    expressions (Eq, Ne, Lt, Gt, Le, Ge) report 'bool'.

    Raises
    ------
    TypeError
        If no category matches.
    """
    checks = (('is_integer', 'int'),
              ('is_real', 'real'),
              ('is_complex', 'complex'),
              ('is_Boolean', 'bool'))
    for flag, label in checks:
        if getattr(expr, flag):
            return label
    if isinstance(expr, (Eq, Ne, Lt, Gt, Le, Ge)):
        return 'bool'
    raise TypeError('Unknown datatype {0}'.format(str(expr)))
def str_dtype(dtype):
    """Return the canonical string name of a datatype.

    Parameters
    ----------
    dtype : str or native DataType
        A string is returned unchanged, except 'int', which becomes
        'integer'. A native datatype is mapped to 'integer', 'real',
        'complex' or 'bool'; subclasses map like their base kind.

    Raises
    ------
    TypeError
        If *dtype* is a non-string that matches none of the native kinds.
    """
    if isinstance(dtype, str):
        return 'integer' if dtype == 'int' else dtype

    # Checked in this order; the first matching kind wins.
    for kind, label in ((NativeInteger, 'integer'),
                        (NativeReal, 'real'),
                        (NativeComplex, 'complex'),
                        (NativeBool, 'bool')):
        if isinstance(dtype, kind):
            return label
    raise TypeError('Unknown datatype {0}'.format(str(dtype)))