# coding: utf-8
"""Print to F90 standard. Trying to follow the information provided at
www.fortran90.org as much as possible."""
import string
from itertools import groupby, chain
import numpy as np
from sympy import Lambda
from sympy.core import Symbol
from sympy.core import Float as sp_Float, Integer as sp_Integer
from sympy.core import S, Add, N
from sympy.core import Tuple
from sympy.core.function import Function
from sympy.core.compatibility import string_types
from sympy.printing.precedence import precedence
from sympy import Eq, Ne, true, false
from sympy import Atom, Indexed
from sympy import preorder_traversal
from sympy.core.numbers import NegativeInfinity as NINF
from sympy.core.numbers import Infinity as INF
from sympy import Mod
from sympy.utilities.iterables import iterable
from sympy.logic.boolalg import Boolean, BooleanTrue, BooleanFalse
from sympy.logic.boolalg import And, Not, Or, true, false
from pyccel.ast import Zeros, Array, Int, Shape, Sum, Rand,Real,Complex
from pyccel.ast.core import get_initial_value
from pyccel.ast.core import get_iterable_ranges
from pyccel.ast.core import AddOp, MulOp, SubOp, DivOp
from pyccel.ast.core import String
from pyccel.ast.core import ClassDef
from pyccel.ast.core import Nil
from pyccel.ast.core import Module
from pyccel.ast.core import SeparatorComment, CommentBlock
from pyccel.ast.core import ConstructorCall
from pyccel.ast.core import FunctionDef, Interface
from pyccel.ast.core import Subroutine
from pyccel.ast.core import ZerosLike
from pyccel.ast.core import Return
from pyccel.ast.core import ValuedArgument
from pyccel.ast.core import ErrorExit, Exit
from pyccel.ast.core import Range, Product, Block , Zip, Enumerate, Map
from pyccel.ast.core import get_assigned_symbols
from pyccel.ast.core import (Assign, AugAssign, Variable, CodeBlock,
Declare, ValuedVariable,
Len, FunctionalFor,
IndexedElement, Slice, List, Dlist,
DottedName, AsName, DottedVariable,
Print, If, Nil)
from pyccel.ast.datatypes import DataType, is_pyccel_datatype
from pyccel.ast.datatypes import is_iterable_datatype, is_with_construct_datatype
from pyccel.ast.datatypes import NativeBool, NativeSymbol, NativeString, NativeList
from pyccel.ast.datatypes import NativeComplex, NativeReal, NativeInteger
from pyccel.ast.datatypes import NativeRange, NativeTensor
from pyccel.ast.datatypes import CustomDataType
from pyccel.codegen.printing.codeprinter import CodePrinter
from pyccel.ast.parallel.mpi import MPI
from pyccel.ast.parallel.openmp import OMP_For
from pyccel.ast.parallel.openacc import ACC_For
from collections import OrderedDict
import functools
import operator
# TODO: add examples
# TODO: use _get_statement when returning a string
__all__ = ["FCodePrinter", "fcode"]
known_functions = {
"sin": "sin",
"cos": "cos",
"tan": "tan",
"asin": "asin",
"acos": "acos",
"atan": "atan",
"atan2": "atan2",
"sinh": "sinh",
"cosh": "cosh",
"tanh": "tanh",
"log": "log",
"exp": "exp",
"erf": "erf",
"Abs": "abs",
"sign": "sign",
"conjugate": "conjg"
}
_default_methods = {
'__init__': 'create',
'__del__' : 'free',
}
[docs]class FCodePrinter(CodePrinter):
"""A printer to convert sympy expressions to strings of Fortran code"""
printmethod = "_fcode"
language = "Fortran"
_default_settings = {
'order': None,
'full_prec': 'auto',
'precision': 15,
'user_functions': {},
'human': True,
'source_format': 'fixed',
'tabwidth': 2,
'contract': True,
'standard': 77
}
_operators = {
'and': '.and.',
'or': '.or.',
'xor': '.neqv.',
'equivalent': '.eqv.',
'not': '.not. ',
}
_relationals = {
'!=': '/=',
}
def __init__(self, settings={}):
CodePrinter.__init__(self, settings)
self.known_functions = dict(known_functions)
userfuncs = settings.get('user_functions', {})
self.known_functions.update(userfuncs)
def _get_statement(self, codestring):
return codestring
def _get_comment(self, text):
return "! {0}".format(text)
def _format_code(self, lines):
return self._wrap_fortran(self.indent_code(lines))
def _traverse_matrix_indices(self, mat):
rows, cols = mat.shape
return ((i, j) for j in range(cols) for i in range(rows))
# ============ Elements ============ #
def _print_Module(self, expr):
name = self._print(expr.name)
name = name.replace('.', '_')
if not name.startswith('mod_'):
name = 'mod_{0}'.format(name)
imports = '\n'.join(self._print(i) for i in expr.imports)
decs = '\n'.join(self._print(i) for i in expr.declarations)
body = ''
# ...
sep = self._print(SeparatorComment(40))
interfaces = ''
if expr.interfaces:
interfaces = '\n'.join(self._print(i) for i in expr.interfaces if not i.hide)
for interface in expr.interfaces:
if not interface.hide:
for i in interface.functions:
body = ('{body}\n'
'{sep}\n'
'{f}\n'
'{sep}\n').format(body=body, sep=sep, f=self._print(i))
if expr.funcs:
for i in expr.funcs:
body = ('{body}\n'
'{sep}\n'
'{f}\n'
'{sep}\n').format(body=body, sep=sep, f=self._print(i))
# ...
# ...
classes = ''
for i in expr.classes:
# update decs with declarations from ClassDef
c_decs, c_funcs = self._print(i)
decs = '{0}\n{1}'.format(decs, c_decs)
body = '{0}\n{1}\n'.format(body, c_funcs)
# ...
if expr.funcs or expr.classes or expr.interfaces:
body = '\n contains\n{0}'.format(body)
return ('module {name}\n'
'{imports}\n'
'implicit none\n'
'{decs}\n'
'{interfaces}\n'
'{body}\n'
'end module\n').format(name=name,
imports=imports,
decs=decs,
interfaces=interfaces,
body=body)
def _print_Program(self, expr):
name = 'prog_{0}'.format(self._print(expr.name))
name = name.replace('.', '_')
modules = ''
mpi = False
#we use this to detect of we are using so that we can add
# mpi_init and mpi_finalize in the code instruction
# TODO should we find a better way to do this?
imports = list(expr.imports)
for i in expr.imports:
if 'mpi4py' == str(i.target[0]):
mpi = True
imports = '\n'.join(self._print(i) for i in imports)
funcs = ''
body = '\n'.join(self._print(i) for i in expr.body)
decs = expr.declarations
func_in_func = False
for func in expr.funcs:
for i in func.body:
if isinstance(i, FunctionDef):
func_in_func = True
break
if expr.classes or expr.interfaces or func_in_func:
# TODO shall we use expr.variables? or have a more involved algo
# we will need to walk through the expression and see what are
# the variables that are needed in the definitions of classes
variables = []
for i in expr.interfaces:
variables += i.functions[0].global_vars
for i in expr.funcs:
variables += i.global_vars
variables =list(set(variables))
for i in range(len(decs)):
#remove variables that are declared in the modules
if decs[i].variable in variables:
decs[i] = None
decs = [i for i in decs if i]
module_utils = Module(expr.name, list(variables),
expr.funcs, expr.interfaces, expr.classes,
imports=expr.imports)
modules = self._print(module_utils)
imports = ('{imports}\n'
'use mod_{name}\n').format(imports=imports, name=expr.name)
else:
# ... uncomment this later and remove it from the top
# decs = '\n'.join(self._print(i) for i in expr.declarations)
# ...
# ...
sep = self._print(SeparatorComment(40))
funcs = ''
if expr.funcs:
for i in expr.funcs:
funcs = ('{funcs}\n'
'{sep}\n'
'{f}\n'
'{sep}\n').format(funcs=funcs, sep=sep, f=self._print(i))
funcs = 'contains\n{0}'.format(funcs)
# ...
decs = '\n'.join(self._print(i) for i in decs)
if mpi:
#TODO shuold we add them like this ?
body = 'call mpi_init(ierr)\n'+\
'\nallocate(status(0:-1 + mpi_status_size)) '+\
'\n status = 0\n'+\
body +\
'\ncall mpi_finalize(ierr)'
decs += '\ninteger :: ierr = -1' +\
'\n integer, allocatable :: status (:)'
return ('{modules}\n'
'program {name}\n'
'{imports}\n'
'implicit none\n'
'{decs}\n'
'{body}\n'
'{funcs}\n'
'end program {name}\n').format(name=name,
imports=imports,
decs=decs,
body=body,
funcs=funcs,
modules=modules)
def _print_Import(self, expr):
prefix_as = ''
source = ''
if expr.source is None:
prefix = 'use'
else:
if isinstance(expr.source, DottedName):
source = '_'.join(self._print(j) for j in expr.source.name)
else:
source = self._print(expr.source)
prefix = 'use {}, only:'.format(source)
prefix_as = 'use {},'.format(source)
# TODO - improve
# importing of pyccel extensions is not printed
if source in ['numpy', 'scipy', 'itertools','math']:
return ''
if 'mpi4py' == str(expr.target[0]):
return 'use mpi'
code = ''
for i in expr.target:
if isinstance(i, AsName):
target = '{name} => {target}'.format(name=self._print(i.name),
target=self._print(i.target))
line = '{prefix} {target}'.format(prefix=prefix_as,
target=target)
elif isinstance(i, DottedName):
target = '_'.join(self._print(j) for j in i.name)
line = '{prefix} {target}'.format(prefix=prefix,
target=target)
elif isinstance(i, str):
line = '{prefix} {target}'.format(prefix=prefix,
target=str(i))
elif isinstance(i, Symbol):
line = '{prefix} {target}'.format(prefix=prefix,
target=str(i.name))
else:
raise TypeError('Expecting str, Symbol, DottedName or AsName, '
'given {}'.format(type(i)))
# TODO keep `\n` ?
# code = '{code}{line}'.format(code=code, line=line)
code = '{code}\n{line}'.format(code=code, line=line)
# in some cases, the source is given as a string (when using metavar)
code = code.replace("'", '')
return self._get_statement(code)
def _print_TupleImport(self, expr):
code = '\n'.join(self._print(i) for i in expr.imports)
return self._get_statement(code)
# TODO
def _print_FromImport(self, expr):
fil = self._print(expr.fil)
if isinstance(expr.fil, DottedName):
# pyccel-extension case
if expr.fil.name[0] == 'pyccelext':
fil = '_'.join(self._print(i) for i in expr.fil.name)
fil = 'mod_{0}'.format(fil)
else:
fil = '_'.join(self._print(i) for i in expr.fil.name)
fil = 'mod_{0}'.format(fil)
if not expr.funcs:
return 'use {0}'.format(fil)
elif isinstance(expr.funcs, str):
funcs = self._print(expr.funcs)
return 'use {0}, only: {1}'.format(fil, funcs)
elif isinstance(expr.funcs, (tuple, list, Tuple)):
funcs = ', '.join(self._print(f) for f in expr.funcs)
return 'use {0}, only: {1}'.format(fil, funcs)
else:
raise TypeError('Wrong type for funcs')
def _print_Print(self, expr):
args = []
for f in expr.expr:
if isinstance(f, str):
args.append("'{}'".format(f))
elif isinstance(f, Tuple):
for i in f:
args.append("{}".format(self._print(i)))
else:
args.append("{}".format(self._print(f)))
fs = ', '.join(i for i in args)
code = 'print *, {0}'.format(fs)
return self._get_statement(code)
def _print_SymbolicPrint(self, expr):
# for every expression we will generate a print
_iprint = lambda e: "print *, 'sympy> {}'".format(e)
code = ''
for a in expr.expr:
code = '{code}{p}'.format(code=code, p=_iprint(a))
return self._get_statement(code)
def _print_Comment(self, expr):
txt = self._print(expr.text)
comments = []
while len(txt)>60:
try:
index = txt[60:].index(' ')+60
except:
index = 60
comments.append(txt[:index])
txt = txt[index:]
else:
comments.append(txt)
comments = ['! '+ comment for comment in comments]
comments = '\n'.join(comment for comment in comments)
return comments
def _print_CommentBlock(self, expr):
txts = expr.comments
comments = []
for txt in txts:
while len(txt)>60:
try:
index = txt[60:].index(' ')+60
except:
index = 60
comments.append(txt[:index])
txt = txt[index:]
else:
comments.append(txt)
txts = comments
ln = max(len(i) for i in txts)
if ln<20:
ln = 20
top = '!' + '_'*int((ln-12)/2) + 'CommentBlock' + '_'*int((ln-12)/2) + '!'
ln = len(top)
bottom = '!' + '_'*(ln-2) + '!'
for i in range(len(txts)):
txts[i] = '!' + txts[i] + ' '*(ln -2 - len(txts[i])) + '!'
body = '\n'.join(i for i in txts)
return ('{0}\n'
'{1}\n'
'{2}').format(top, body, bottom)
def _print_EmptyLine(self, expr):
return ''
def _print_NewLine(self, expr):
return '\n'
def _print_AnnotatedComment(self, expr):
accel = self._print(expr.accel)
txt = str(expr.txt)
return '!${0} {1}'.format(accel, txt)
def _print_Tuple(self, expr):
import numpy
shape = numpy.shape(expr)
if len(shape)>1:
arg = functools.reduce(operator.concat, expr)
elements = ','.join(self._print(i) for i in arg)
return 'reshape((/ '+ elements + ' /), '+ self._print(Tuple(*shape)) + ')'
fs = ', '.join(self._print(f) for f in expr)
return '(/ {0} /)'.format(fs)
def _print_Variable(self, expr):
return self._print(expr.name)
def _print_Constant(self, expr):
val = sp_Float(expr.value)
return self._print(val)
def _print_ValuedArgument(self, expr):
name = self._print(expr.name)
value = self._print(expr.value)
code = '{0}={1}'.format(name, value)
return code
def _print_DottedVariable(self, expr):
if isinstance(expr.args[1], Function):
func = expr.args[1].func
name = func.__name__
# ...
code_args = ''
code_args = ', '.join(self._print(i) for i in expr.args[1].args)
code = '{0}({1})'.format(name, code_args)
# ...
# ...
code = '{0}%{1}'.format(self._print(expr.args[0]), code)
if isinstance(func, Subroutine):
code = 'call {0}'.format(code)
return code
return self._print(expr.args[0]) + '%' +self._print(expr.args[1])
def _print_DottedName(self, expr):
return ' % '.join(self._print(n) for n in expr.name)
def _print_Concatinate(self, expr):
args = expr.args
if expr.is_list:
code = ','.join(self._print(a) for a in expr.args)
return '[' + code + ']'
else:
code = '//'.join('trim('+self._print(a)+')' for a in expr.args)
return code
def _print_Lambda(self, expr):
return '"{args} -> {expr}"'.format(args=expr.variables, expr=expr.expr)
def _print_ZerosLike(self, expr):
lhs = self._print(expr.lhs)
rhs = self._print(expr.rhs)
if isinstance(expr.rhs, IndexedElement):
shape = []
for i in expr.rhs.indices:
if isinstance(i, Slice):
shape.append(i)
rank = len(shape)
else:
rank = expr.rhs.rank
rs = []
for i in range(1, rank+1):
l = 'lbound({0},{1})'.format(rhs, str(i))
u = 'ubound({0},{1})'.format(rhs, str(i))
r = '{0}:{1}'.format(l, u)
rs.append(r)
shape = ', '.join(self._print(i) for i in rs)
init_value = self._print(expr.init_value)
code = ('allocate({lhs}({shape}))\n'
'{lhs} = {init_value}').format(lhs=lhs,
shape=shape,
init_value=init_value)
return self._get_statement(code)
def _print_SumFunction(self, expr):
return str(expr)
def _print_Len(self, expr):
return 'size(%s,1)'%(self._print(expr.arg))
def _print_Sum(self, expr):
return expr.fprint(self._print)
def _print_Shape(self, expr):
return expr.fprint(self._print)
def _print_Zeros(self, expr):
return expr.fprint(self._print)
def _print_Array(self, expr):
return expr.fprint(self._print)
def _print_Int(self, expr):
return expr.fprint(self._print)
def _print_Real(self, expr):
return expr.fprint(self._print)
def _print_Rand(self, expr):
return expr.fprint(self._print)
def _print_Min(self, expr):
args = expr.args
if len(args) == 1:
arg = args[0]
code = 'minval({0})'.format(self._print(arg))
else:
code = ','.join(self._print(arg) for arg in args)
code = 'min('+code+')'
return self._get_statement(code)
def _print_Max(self, expr):
args = expr.args
args = ','.join(self._print(arg) for arg in args)
code = 'max({0})'.format(args)
return self._get_statement(code)
def _print_Dot(self, expr):
return self._get_statement('dot_product(%s,%s)'%(self._print(expr.expr_l), self._print(expr.expr_r)))
def _print_Ceil(self, expr):
return self._get_statement('ceiling(%s)'%(self._print(expr.rhs)))
def _print_Mod(self, expr):
args = ','.join(self._print(i) for i in expr.args)
return 'modulo({})'.format(args)
def _print_Sign(self, expr):
# TODO use the appropriate precision from rhs
return self._get_statement('sign(1.0d0,%s)'%(self._print(expr.rhs)))
# ... MACROS
def _print_MacroShape(self, expr):
var = expr.argument
if not isinstance(var, (Variable, IndexedElement)):
raise TypeError('Expecting a variable, given {}'.format(type(var)))
shape = None
if isinstance(var, Variable):
shape = var.shape
if shape is None:
rank = var.rank
shape = []
for i in range(0, rank):
l = 'lbound({var},{i})'.format(var=self._print(var),
i=self._print(i+1))
u = 'ubound({var},{i})'.format(var=self._print(var),
i=self._print(i+1))
s = '{u}-{l}+1'.format(u=u, l=l)
shape.append(s)
if len(shape) == 1:
shape = shape[0]
elif not(expr.index is None):
if expr.index < len(shape):
shape = shape[expr.index]
else:
shape = '1'
code = '{}'.format(self._print(shape))
return self._get_statement(code)
# ...
def _print_MacroType(self, expr):
dtype = self._print(expr.argument.dtype)
prec = expr.argument.precision
if dtype == 'integer':
if prec==4:
return 'MPI_INT'
else:
raise NotImplementedError('TODO')
elif dtype == 'real':
if prec==8:
return 'MPI_DOUBLE'
if prec==4:
return 'MPI_FLOAT'
else:
raise NotImplementedError('TODO')
else:
raise NotImplementedError('TODO')
def _print_MacroCount(self, expr):
var = expr.argument
#TODO calculate size when type is pointer
# it must work according to fortran documentation
# but it raises somehow an error when it's a pointer
# and shape is None
if isinstance(var, Variable):
shape = var.shape
if not isinstance(shape,(tuple,list,Tuple)):
shape = [shape]
rank = len(shape)
if shape is None:
return 'size({})'.format(self._print(var))
elif isinstance(var, IndexedElement):
_shape = var.base.shape
if _shape is None:
return 'size({})'.format(self._print(var))
shape = []
for (s, i) in zip(_shape, var.indices):
if isinstance(i, Slice):
if i.start is None and i.end is None:
shape.append(s)
elif i.start is None:
if (isinstance(i.end, (int, sp_Integer)) and i.end>0) or not(isinstance(i.end, (int, sp_Integer))):
shape.append(i.end)
elif i.end is None:
if (isinstance(i.start, (int, sp_Integer)) and i.start<s-1) or not(isinstance(i.start, (int, sp_Integer))):
shape.append(s-i.start)
else:
shape.append(i.end-i.start+1)
rank = len(shape)
else:
raise NotImplementedError('TODO')
if rank == 0:
return '1'
return str(functools.reduce(operator.mul, shape ))
def _print_Declare(self, expr):
# ... ignored declarations
# we don't print the declaration if iterable object
if is_iterable_datatype(expr.dtype):
return ''
if is_with_construct_datatype(expr.dtype):
return ''
if isinstance(expr.dtype, NativeSymbol):
return ''
if isinstance(expr.dtype, (NativeRange, NativeTensor)):
return ''
# meta-variables
if (isinstance(expr.variable, Variable) and
str(expr.variable.name).startswith('__')):
return ''
# ...
# ... TODO improve
# Group the variables by intent
var = expr.variable
arg_types = type(var)
rank = var.rank
allocatable = var.allocatable
shape = var.shape
is_pointer = var.is_pointer
is_target = var.is_target
is_polymorphic = var.is_polymorphic
is_optional = var.is_optional
is_static = expr.static
intent = expr.intent
if isinstance(shape, tuple) and len(shape) ==1:
shape = shape[0]
# ...
# ... print datatype
if isinstance(expr.dtype, CustomDataType):
dtype = expr.dtype
name = dtype.__class__.__name__
prefix = dtype.prefix
alias = dtype.alias
if not var.is_polymorphic:
sig = 'type'
elif dtype.is_polymorphic:
sig = 'class'
else:
sig = 'type'
if alias is None:
name = name.replace(prefix, '')
else:
name = alias
dtype = '{0}({1})'.format(sig, name)
else:
dtype = self._print(expr.dtype)
# ...
if isinstance(expr.dtype, NativeString):
if expr.intent:
dtype = dtype[:9] +'(len =*)'
#TODO improve ,this is the case of character as argument
else:
dtype += '(kind={0})'.format(str(expr.variable.precision))
code_value = ''
if expr.value:
code_value = ' = {0}'.format(expr.value)
decs = []
vstr = self._print(expr.variable.name)
# arrays are 0-based in pyccel, to avoid ambiguity with range
s = '0'
e = ''
enable_alloc = True
if not(is_static) and (allocatable or (var.shape is None)):
s = ''
rankstr = ''
allocatablestr = ''
# TODO improve
if ((rank == 1) and (isinstance(shape, (int, sp_Integer, Variable, Add))) and
(not(allocatable or is_pointer) or is_static)):
rankstr = '({0}:{1})'.format(self._print(s), self._print(shape-1))
enable_alloc = False
elif ((rank > 0) and (isinstance(shape, (Tuple, tuple))) and
(is_target or not(allocatable or is_pointer) or is_static)):
#TODO fix bug when we inclue shape of type list
rankstr = ','.join('{0}:{1}'.format(self._print(s),
self._print(i-1)) for i in shape)
rankstr = '({rank})'.format(rank=rankstr)
enable_alloc = False
elif (rank > 0) and allocatable and intent:
rankstr = ','.join('0:' for f in range(0, rank))
rankstr = '(' + rankstr + ')'
elif (rank > 0) and (allocatable or is_pointer):
rankstr = ','.join(':' for f in range(0, rank))
rankstr = '(' + rankstr + ')'
# else:
# raise NotImplementedError('Not treated yet')
if not is_static:
if is_pointer:
allocatablestr = ', pointer'
elif is_target:
allocatablestr = ', target'
elif allocatable and not intent:
allocatablestr = ', allocatable'
optionalstr = ''
if is_optional:
optionalstr = ', optional'
allocatablestr = allocatablestr + optionalstr
if intent:
decs.append('{0}, intent({1}) {2} :: {3} {4}'.
format(dtype, intent, allocatablestr, vstr, rankstr))
else:
args = [dtype, allocatablestr, vstr, rankstr, code_value]
decs.append('{0}{1} :: {2} {3} {4}'.
format(*args))
return '\n'.join(decs)
def _print_AliasAssign(self, expr):
code = ''
lhs = expr.lhs
rhs = expr.rhs
if isinstance(rhs, Dlist):
return 'allocate({lhs}(0:{length}-1))\n {lhs} = {init_value}'.format(
lhs = self._print(lhs),
length=self._print(rhs.length),
init_value=self._print(rhs.val))
# TODO improve
op = '=>'
if isinstance(lhs, Variable) and (lhs.rank > 0) and (not lhs.is_pointer or not isinstance(rhs, Atom)):
if not isinstance(rhs, Atom) and not isinstance(rhs, Indexed):
# case of rhs an expression and lhs is pointer we then allocate the memory for it
for i in list(preorder_traversal(rhs)):
if isinstance(i, (Variable, DottedVariable)) and i.rank>0:
rhs = i
break
#TODO improve we only need to allocate the variable without setting it to zero
stmt = ZerosLike(lhs=lhs, rhs=rhs)
code += self._print(stmt)
code += '\n'
op = '='
code += '{lhs} {op} {rhs}'.format(lhs=self._print(expr.lhs),
op=op,
rhs=self._print(expr.rhs))
return self._get_statement(code)
def _print_CodeBlock(self, expr):
return '\n'.join(self._print(i) for i in expr.body)
def _print_Assign(self, expr):
lhs_code = self._print(expr.lhs)
is_procedure = False
# we don't print Range, Tensor
# TODO treat the case of iterable classes
if isinstance(expr.rhs, NINF):
rhs_code = '-Huge({0})'.format(lhs_code)
return '{0} = {1}'.format(lhs_code, rhs_code)
if isinstance(expr.rhs, INF):
rhs_code = 'Huge({0})'.format(lhs_code)
return '{0} = {1}'.format(lhs_code, rhs_code)
if isinstance(expr.rhs, (Range, Product)):
return ''
if isinstance(expr.rhs, Len):
rhs_code = self._print(expr.rhs)
return '{0} = {1}'.format(lhs_code, rhs_code)
if isinstance(expr.rhs, (Int, Real, Complex)):
lhs = self._print(expr.lhs)
rhs = expr.rhs.fprint(self._print)
return '{0} = {1}'.format(lhs,rhs)
if isinstance(expr.rhs, (Zeros, Array, Shape)):
return expr.rhs.fprint(self._print, expr.lhs)
if isinstance(expr.rhs, ZerosLike):
return self._print(ZerosLike(lhs=expr.lhs,rhs=expr.rhs.rhs))
if isinstance(expr.rhs, Mod):
lhs = self._print(expr.lhs)
args = ','.join(self._print(i) for i in expr.rhs.args)
rhs = 'modulo({})'.format(args)
return '{0} = {1}'.format(lhs,rhs)
elif isinstance(expr.rhs, Shape):
a = expr.rhs.rhs
lhs = self._print(expr.lhs)
rhs = self._print(a)
if isinstance(a, IndexedElement):
shape = []
for i in a.indices:
if isinstance(i, Slice):
shape.append(i)
rank = len(shape)
else:
rank = a.rank
code = 'allocate({0}(0:{1}-1)) ; {0} = 0'.format(lhs, rank)
rs = []
for i in range(0, rank):
l = 'lbound({0},{1})'.format(rhs, str(i+1))
u = 'ubound({0},{1})'.format(rhs, str(i+1))
r = '{3}({2}) = {1}-{0}'.format(l, u, str(i), lhs)
rs.append(r)
sizes = '\n'.join(self._print(i) for i in rs)
code = '{0}\n{1}'.format(code, sizes)
return self._get_statement(code)
elif isinstance(expr.rhs, FunctionDef):
rhs_code = self._print(expr.rhs.name)
is_procedure = expr.rhs.is_procedure
elif isinstance(expr.rhs, ConstructorCall):
func = expr.rhs.func
name = str(func.name)
# TODO uncomment later
# # we don't print the constructor call if iterable object
# if this.dtype.is_iterable:
# return ''
#
# # we don't print the constructor call if with construct object
# if this.dtype.is_with_construct:
# return ''
if name == "__init__":
name = "create"
rhs_code = self._print(name)
rhs_code = '{0} % {1}'.format(lhs_code, rhs_code)
#TODO use is_procedure property
is_procedure = (expr.rhs.kind == 'procedure')
code_args = ', '.join(self._print(i) for i in expr.rhs.arguments)
return 'call {0}({1})'.format(rhs_code, code_args)
elif isinstance(expr.rhs, Function):
# in the case of a function that returns a list,
# we should append them to the procedure arguments
name = type(expr.rhs).__name__
rhs_code = self._print(name)
args = expr.rhs.args
code_args = ', '.join(self._print(i) for i in args)
if isinstance(expr.lhs, (tuple, list, Tuple)):
lhs_code = ', '.join(self._print(i) for i in expr.lhs)
code = 'call {0}({1}, {2})'.format(rhs_code, code_args, lhs_code)
return self._get_statement(code)
rhs_code = '{0}({1})'.format(rhs_code, code_args)
code = '{0} = {1}'.format(lhs_code, rhs_code)
return self._get_statement(code)
elif (isinstance(expr.lhs, Variable) and
expr.lhs.dtype == NativeSymbol()):
return ''
else:
rhs_code = self._print(expr.rhs)
# print("ASSIGN = ", rhs_code)
code = ''
if (expr.status == 'unallocated') and not (expr.like is None):
stmt = ZerosLike(lhs=lhs_code, rhs=expr.like)
code += self._print(stmt)
code += '\n'
if not is_procedure:
code += '{0} = {1}'.format(lhs_code, rhs_code)
# else:
# code_args = ''
# func = expr.rhs
# # func here is of instance FunctionCall
# cls_name = func.func.cls_name
# keys = func.func.arguments
# # for MPI statements, we need to add the lhs as the last argument
# # TODO improve
# if isinstance(func.func, MPI):
# if not func.arguments:
# code_args = lhs_code
# else:
# code_args = ', '.join(self._print(i) for i in func.arguments)
# code_args = '{0}, {1}'.format(code_args, lhs_code)
# else:
# _ij_print = lambda i, j: '{0}={1}'.format(self._print(i), \
# self._print(j))
#
# code_args = ', '.join(_ij_print(i, j) \
# for i, j in zip(keys, func.arguments))
# if (not func.arguments is None) and (len(func.arguments) > 0):
# if (not cls_name):
# code_args = ', '.join(self._print(i) for i in func.arguments)
# code_args = '{0}, {1}'.format(code_args, lhs_code)
# else:
# print('code_args > {0}'.format(code_args))
# code = 'call {0}({1})'.format(rhs_code, code_args)
return self._get_statement(code)
def _print_NativeBool(self, expr):
return 'logical'
def _print_NativeInteger(self, expr):
return 'integer'
def _print_NativeReal(self, expr):
return 'real'
def _print_NativeComplex(self, expr):
return 'complex'
def _print_BooleanTrue(self, expr):
return '.true.'
def _print_BooleanFalse(self, expr):
return '.false.'
def _print_NativeString(self, expr):
return 'character(len=280)'
#TODO fix improve later
def _print_DataType(self, expr):
return self._print(expr.name)
def _print_Equality(self, expr):
return '{0} == {1} '.format(self._print(expr.lhs), self._print(expr.rhs))
def _print_Unequality(self, expr):
return '{0} /= {1} '.format(self._print(expr.lhs), self._print(expr.rhs))
def _print_BooleanTrue(self, expr):
return '.True.'
def _print_BooleanFalse(self, expr):
return '.False.'
def _print_String(self, expr):
return expr.arg
def _print_Interface(self, expr):
# ... we don't print 'hidden' functions
name = self._print(expr.name)
if expr.functions[0].cls_name:
for k, m in list(_default_methods.items()):
name = name.replace(k, m)
cls_name = expr.cls_name
if not (cls_name == '__UNDEFINED__'):
name = '{0}_{1}'.format(cls_name, name)
else:
for i in _default_methods:
# because we may have a class Point with init: Point___init__
if i in name:
name = name.replace(i, _default_methods[i])
interface = 'interface ' + name +'\n'
functions = []
for f in expr.functions:
interface += 'module procedure ' + str(f.name)+'\n'
interface += 'end interface\n'
return interface
# def _print_With(self, expr):
# test = 'call '+self._print(expr.test) + '%__enter__()'
# body = '\n'.join(self._print(i) for i in expr.body)
# end = 'call '+self._print(expr.test) + '%__exit__()'
# code = ('{test}\n'
# '{body}\n'
# '{end}').format(test=test, body=body, end=end)
#TODO return code later
# expr.block
# return ''
def _print_Block(self, expr):
decs=[]
for i in expr.variables:
dec = Declare(i.dtype, i)
decs += [dec]
body = expr.body
body_code = '\n'.join(self._print(i) for i in body)
prelude = '\n'.join(self._print(i) for i in decs)
#case of no local variables
if len(decs) == 0:
return body_code
return ('{name} : Block\n'
'{prelude}\n'
'{body}\n'
'end Block {name}').format(name=expr.name, prelude=prelude, body=body_code)
def _print_FunctionDef(self, expr):
# ... we don't print 'hidden' functions
if expr.hide:
return ''
# ...
name = self._print(expr.name)
is_static = expr.is_static
if expr.cls_name:
for k, m in list(_default_methods.items()):
name = name.replace(k, m)
cls_name = expr.cls_name
if not (cls_name == '__UNDEFINED__'):
name = '{0}_{1}'.format(cls_name, name)
else:
for i in _default_methods:
# because we may have a class Point with init: Point___init__
if i in name:
name = name.replace(i, _default_methods[i])
out_args = []
decs = OrderedDict()
args_decs = OrderedDict()
# ... local variables declarations
for i in expr.local_vars:
dec = Declare(i.dtype, i)
decs[str(i)] = dec
# ...
# ...
body = expr.body
func_end = ''
if not expr.is_procedure:
result = expr.results[0]
# TODO uncomment and validate this
# expr = subs(expr, result, str(expr.name))
body = []
functions = []
for stmt in expr.body:
if isinstance(stmt, Declare):
pass
elif isinstance(stmt, FunctionDef):
functions += [stmt]
elif not isinstance(stmt, list): # for list of Results
body.append(stmt)
ret_type = self._print(result.dtype)
ret_type += '(kind={0})'.format(str(result.precision))
func_type = 'function'
rec = ''
if expr.is_recursive:
rec = 'recursive '
if result.allocatable or (result.rank > 0):
sig = '{0}function {1}'.format(rec, name)
var = Variable(result.dtype, result.name, \
rank=result.rank, \
allocatable=True, \
shape=result.shape)
dec = Declare(result.dtype, var)
args_decs[str(var)] = dec
else:
sig = '{0} {1}function {2}'.format(ret_type, rec, name)
func_end = ' result({0})'.format(result.name)
else:
# TODO compute intent
# a static function is always treated as a procedure
#TODO improve for functions without return
out_args = [result for result in expr.results]
for result in expr.results:
if result in expr.arguments:
dec = Declare(result.dtype, result, intent='inout', static=is_static)
else:
dec = Declare(result.dtype, result, intent='out', static=is_static)
args_decs[str(result)] = dec
sig = 'subroutine ' + name
func_type = 'subroutine'
names = [str(res.name) for res in expr.results]
body = []
functions = []
for stmt in expr.body:
if isinstance(stmt, Declare):
pass
elif isinstance(stmt, FunctionDef):
functions += [stmt]
elif not isinstance(stmt, Return):
body.append(stmt)
elif isinstance(stmt,Return):
body += [stmt]
# ... TODO improve to treat variables that are assigned within blocks: if, etc
symbols = get_assigned_symbols(expr.body)
assigned_names = [str(i) for i in symbols]
# ...
results_names = [str(i) for i in expr.results]
for arg in expr.arguments:
if str(arg) in results_names + assigned_names:
dec = Declare(arg.dtype, arg, intent='inout', static=is_static)
elif str(arg) == 'self':
dec = Declare(arg.dtype, arg, intent='inout', static=is_static)
else:
dec = Declare(arg.dtype, arg, intent='in', static=is_static)
args_decs[str(arg)] = dec
args_decs.update(decs)
decs = [v for k,v in args_decs.items()]
#remove parametres intent(inout) from out_args to prevent repetition
for i in expr.arguments:
if i in out_args:
out_args.remove(i)
arg_code = ', '.join(self._print(i) for i in chain( expr.arguments, out_args ))
body_code = '\n'.join(self._print(i) for i in body)
prelude = '\n'.join(self._print(i) for i in decs)
if len(functions)>0:
functions_code = '\n'.join(self._print(i) for i in functions)
body_code = body_code +'\ncontains \n' +functions_code
body_code = prelude + '\n\n' + body_code
imports = '\n'.join(self._print(i) for i in expr.imports)
return ('{0}({1}) {2}\n'
'{3}\n'
'implicit none\n'
# 'integer, parameter:: dp=kind(0.d0)\n'
'{4}\n'
'end {5}').format(sig, arg_code, func_end, imports, body_code, func_type)
def _print_Pass(self, expr):
return ''
def _print_Nil(self, expr):
return ''
def _print_Return(self, expr):
code = ''
if expr.stmt:
code += self._print(expr.stmt)+'\n'
code +='return'
return code
def _print_Del(self, expr):
# TODO: treate class case
code = ''
for var in expr.variables:
if isinstance(var, Variable):
dtype = var.dtype
if is_pyccel_datatype(dtype):
code = 'call {0} % free()'.format(self._print(var))
else:
code = 'deallocate({0}){1}'.format(self._print(var), code)
else:
msg = 'Only Variable is treated.'
msg += ' Given {0}'.format(type(var))
raise NotImplementedError(msg)
return code
def _print_ClassDef(self, expr):
# ... we don't print 'hidden' classes
if expr.hide:
return '', ''
# ...
name = self._print(expr.name)
base = None # TODO: add base in ClassDef
decs = '\n'.join(self._print(Declare(i.dtype, i)) for i in expr.attributes)
aliases = []
names = []
ls = [self._print(i.name) for i in expr.methods]
for i in ls:
j = i
if i in _default_methods:
j = _default_methods[i]
aliases.append(j)
names.append('{0}_{1}'.format(name, self._print(j)))
methods = '\n'.join('procedure :: {0} => {1}'.format(i, j) for i, j in zip(aliases, names))
for i in expr.interfaces:
names = ','.join('{0}_{1}'.format(name, self._print(j.name)) for j in i.functions)
methods += '\ngeneric, public :: {0} => {1}'.format(self._print(i.name), names)
methods += '\nprocedure :: {0}'.format(names)
options = ', '.join(i for i in expr.options)
sig = 'type, {0}'.format(options)
if not(base is None):
sig = '{0}, extends({1})'.format(sig, base)
code = ('{0} :: {1}').format(sig, name)
if len(decs) > 0:
code = ('{0}\n'
'{1}').format(code, decs)
if len(methods) > 0:
code = ('{0}\n'
'contains\n'
'{1}').format(code, methods)
decs = ('{0}\n'
'end type {1}').format(code, name)
sep = self._print(SeparatorComment(40))
# we rename all methods because of the aliasing
cls_methods = [i.rename('{0}'.format(i.name)) for i in expr.methods]
for i in expr.interfaces:
cls_methods += [j.rename('{0}'.format(j.name)) for j in i.functions]
methods = ''
for i in cls_methods:
methods = ('{methods}\n'
'{sep}\n'
'{f}\n'
'{sep}\n').format(methods=methods, sep=sep, f=self._print(i))
return decs, methods
def _print_Break(self, expr):
return 'exit'
def _print_Continue(self, expr):
return 'cycle'
def _print_AugAssign(self, expr):
lhs = expr.lhs
op = expr.op
rhs = expr.rhs
strict = expr.strict
status = expr.status
like = expr.like
if isinstance(op, AddOp):
rhs = lhs + rhs
elif isinstance(op, MulOp):
rhs = lhs * rhs
elif isinstance(op, SubOp):
rhs = lhs - rhs
# TODO fix bug with division of integers
elif isinstance(op, DivOp):
rhs = lhs / rhs
else:
raise ValueError('Unrecongnized operation', op)
stmt = Assign(lhs, rhs, strict=strict, status=status, like=like)
return self._print(stmt)
def _print_Range(self, expr):
start = self._print(expr.start)
stop = self._print(expr.stop-1)
step = self._print(expr.step)
return '{0}, {1}, {2}'.format(start, stop, step)
def _print_Tile(self, expr):
start = self._print(expr.start)
stop = self._print(expr.stop)
return '{0}, {1}'.format(start, stop)
def _print_FunctionalFor(self, expr):
allocate = ''
if expr.target and len(expr.target.shape)>0:
allocate = ','.join('0:{0}'.format(str(i)) for i in expr.target.shape)
allocate ='allocate({0}({1}))\n'.format(expr.target.name, allocate)
loops = '\n'.join(self._print(i) for i in expr.loops)
return allocate + loops
def _print_For(self, expr):
prolog = ''
epilog = ''
# ...
def _do_range(target, iter, prolog, epilog):
if not isinstance(iter, Range):
msg = "Only iterable currently supported is Range"
raise NotImplementedError(msg)
tar = self._print(target)
range_code = self._print(iter)
prolog += 'do {0} = {1}\n'.format(tar, range_code)
epilog = 'end do\n' + epilog
return prolog, epilog
# ...
# ...
def _iprint(i):
if isinstance(i, Block):
_prelude, _body = self._print_Block(i)
return '{0}'.format(_body)
else:
return '{0}'.format(self._print(i))
# ...
if not isinstance(expr.iterable, (Range, Product , Zip, Enumerate, Map)):
msg = "Only iterable currently supported are Range, "
msg += "Product"
raise NotImplementedError(msg)
if isinstance(expr.iterable, Range):
prolog, epilog = _do_range(expr.target, expr.iterable, \
prolog, epilog)
elif isinstance(expr.iterable, Product):
for i, a in zip(expr.target, expr.iterable.args):
itr_ = Range(a.shape[0])
prolog, epilog = _do_range(i, itr_, \
prolog, epilog)
elif isinstance(expr.iterable, Zip):
itr_ = Range(expr.iterable.element.shape[0])
prolog, epilog = _do_range(expr.target, itr_, \
prolog, epilog)
elif isinstance(expr.iterable, Enumerate):
itr_ = Range(Len(expr.iterable.element))
prolog, epilog = _do_range(expr.target, itr_, \
prolog, epilog)
elif isinstance(expr.iterable, Map):
itr_ = Range(Len(expr.iterable.args[1]))
prolog, epilog = _do_range(expr.target, itr_, \
prolog, epilog)
body = '\n'.join(_iprint(i) for i in expr.body)
return ('{prolog}'
'{body}\n'
'{epilog}').format(prolog=prolog, body=body, epilog=epilog)
# .....................................................
# OpenMP statements
# .....................................................
def _print_OMP_Parallel(self, expr):
clauses = ' '.join(self._print(i) for i in expr.clauses)
body = '\n'.join(self._print(i) for i in expr.body)
# ... TODO adapt get_statement to have continuation with OpenMP
prolog = '!$omp parallel {clauses}\n'.format(clauses=clauses)
epilog = '!$omp end parallel\n'
# ...
# ...
code = ('{prolog}'
'{body}\n'
'{epilog}').format(prolog=prolog, body=body, epilog=epilog)
# ...
return self._get_statement(code)
def _print_OMP_For(self, expr):
# ...
loop = self._print(expr.loop)
clauses = ' '.join(self._print(i) for i in expr.clauses)
nowait = ''
if not(expr.nowait is None):
nowait = 'nowait'
# ...
# ... TODO adapt get_statement to have continuation with OpenMP
prolog = '!$omp do {clauses}\n'.format(clauses=clauses)
epilog = '!$omp end do {0}\n'.format(nowait)
# ...
# ...
code = ('{prolog}'
'{loop}\n'
'{epilog}').format(prolog=prolog, loop=loop, epilog=epilog)
# ...
return self._get_statement(code)
def _print_OMP_NumThread(self, expr):
return 'num_threads({})'.format(self._print(expr.num_threads))
def _print_OMP_Default(self, expr):
status = expr.status
if status:
status = self._print(expr.status)
else:
status = ''
return 'default({})'.format(status)
def _print_OMP_ProcBind(self, expr):
status = expr.status
if status:
status = self._print(expr.status)
else:
status = ''
return 'proc_bind({})'.format(status)
def _print_OMP_Private(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'private({})'.format(args)
def _print_OMP_Shared(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'shared({})'.format(args)
def _print_OMP_FirstPrivate(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'firstprivate({})'.format(args)
def _print_OMP_LastPrivate(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'lastprivate({})'.format(args)
def _print_OMP_Copyin(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'copyin({})'.format(args)
def _print_OMP_Reduction(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
op = self._print(expr.operation)
return "reduction({0}: {1})".format(op, args)
def _print_OMP_Schedule(self, expr):
kind = self._print(expr.kind)
chunk_size = ''
if expr.chunk_size:
chunk_size = ', {0}'.format(self._print(expr.chunk_size))
return 'schedule({0}{1})'.format(kind, chunk_size)
def _print_OMP_Ordered(self, expr):
n_loops = ''
if expr.n_loops:
n_loops = '({0})'.format(self._print(expr.n_loops))
return 'ordered{0}'.format(n_loops)
def _print_OMP_Collapse(self, expr):
n_loops = '{0}'.format(self._print(expr.n_loops))
return 'collapse({0})'.format(n_loops)
def _print_OMP_Linear(self, expr):
variables= ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
step = self._print(expr.step)
return "linear({0}: {1})".format(variables, step)
def _print_OMP_If(self, expr):
return 'if({})'.format(self._print(expr.test))
# .....................................................
# .....................................................
# OpenACC statements
# .....................................................
def _print_ACC_Parallel(self, expr):
clauses = ' '.join(self._print(i) for i in expr.clauses)
body = '\n'.join(self._print(i) for i in expr.body)
# ... TODO adapt get_statement to have continuation with OpenACC
prolog = '!$acc parallel {clauses}\n'.format(clauses=clauses)
epilog = '!$acc end parallel\n'
# ...
# ...
code = ('{prolog}'
'{body}\n'
'{epilog}').format(prolog=prolog, body=body, epilog=epilog)
# ...
return self._get_statement(code)
def _print_ACC_For(self, expr):
# ...
loop = self._print(expr.loop)
clauses = ' '.join(self._print(i) for i in expr.clauses)
# ...
# ... TODO adapt get_statement to have continuation with OpenACC
prolog = '!$acc loop {clauses}\n'.format(clauses=clauses)
epilog = '!$acc end loop\n'
# ...
# ...
code = ('{prolog}'
'{loop}\n'
'{epilog}').format(prolog=prolog, loop=loop, epilog=epilog)
# ...
return self._get_statement(code)
def _print_ACC_Async(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'async({})'.format(args)
def _print_ACC_Auto(self, expr):
return 'auto'
def _print_ACC_Bind(self, expr):
return 'bind({})'.format(self._print(expr.variable))
def _print_ACC_Collapse(self, expr):
return 'collapse({0})'.format(self._print(expr.n_loops))
def _print_ACC_Copy(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'copy({})'.format(args)
def _print_ACC_Copyin(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'copyin({})'.format(args)
def _print_ACC_Copyout(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'copyout({})'.format(args)
def _print_ACC_Create(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'create({})'.format(args)
def _print_ACC_Default(self, expr):
return 'default({})'.format(self._print(expr.status))
def _print_ACC_DefaultAsync(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'default_async({})'.format(args)
def _print_ACC_Delete(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'delete({})'.format(args)
def _print_ACC_Device(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'device({})'.format(args)
def _print_ACC_DeviceNum(self, expr):
return 'collapse({0})'.format(self._print(expr.n_device))
def _print_ACC_DevicePtr(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'deviceptr({})'.format(args)
def _print_ACC_DeviceResident(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'device_resident({})'.format(args)
def _print_ACC_DeviceType(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'device_type({})'.format(args)
def _print_ACC_Finalize(self, expr):
return 'finalize'
def _print_ACC_FirstPrivate(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'firstprivate({})'.format(args)
def _print_ACC_Gang(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'gang({})'.format(args)
def _print_ACC_Host(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'host({})'.format(args)
def _print_ACC_If(self, expr):
return 'if({})'.format(self._print(expr.test))
def _print_ACC_Independent(self, expr):
return 'independent'
def _print_ACC_Link(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'link({})'.format(args)
def _print_ACC_NoHost(self, expr):
return 'nohost'
def _print_ACC_NumGangs(self, expr):
return 'num_gangs({0})'.format(self._print(expr.n_gang))
def _print_ACC_NumWorkers(self, expr):
return 'num_workers({0})'.format(self._print(expr.n_worker))
def _print_ACC_Present(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'present({})'.format(args)
def _print_ACC_Private(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'private({})'.format(args)
def _print_ACC_Reduction(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
op = self._print(expr.operation)
return "reduction({0}: {1})".format(op, args)
def _print_ACC_Self(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'self({})'.format(args)
def _print_ACC_Seq(self, expr):
return 'seq'
def _print_ACC_Tile(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'tile({})'.format(args)
def _print_ACC_UseDevice(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'use_device({})'.format(args)
def _print_ACC_Vector(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'vector({})'.format(args)
def _print_ACC_VectorLength(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'vector_length({})'.format(self._print(expr.n))
def _print_ACC_Wait(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'wait({})'.format(args)
def _print_ACC_Worker(self, expr):
args = ', '.join('{0}'.format(self._print(i)) for i in expr.variables)
return 'worker({})'.format(args)
# .....................................................
def _print_ForIterator(self, expr):
return self._print_For(expr)
depth = expr.depth
prolog = ''
epilog = ''
code = ''
# ...
def _do_range(target, iter, prolog, epilog):
tar = self._print(target)
range_code = self._print(iter)
prolog += 'do {0} = {1}\n'.format(tar, range_code)
epilog = 'end do\n' + epilog
return prolog, epilog
# ...
# ...
def _iprint(i):
if isinstance(i, Block):
_prelude, _body = self._print_Block(i)
return '{0}'.format(_body)
else:
return '{0}'.format(self._print(i))
# ...
# ...
if not isinstance(expr.iterable, (Variable, ConstructorCall)):
raise TypeError('iterable must be Variable or ConstructorCall.')
# ...
# ...
targets = expr.target
if isinstance(expr.iterable, Variable):
iters = expr.ranges
elif isinstance(expr.iterable, ConstructorCall):
iters = get_iterable_ranges(expr.iterable)
# ...
# ...
for i,a in zip(targets, iters):
prolog, epilog = _do_range(i, a, \
prolog, epilog)
body = '\n'.join(_iprint(i) for i in expr.body)
# ...
return ('{prolog}'
'{body}\n'
'{epilog}').format(prolog=prolog, body=body, epilog=epilog)
#def _print_Block(self, expr):
# body = '\n'.join(self._print(i) for i in expr.body)
# prelude = '\n'.join(self._print(i) for i in expr.declarations)
# return prelude, body
def _print_While(self,expr):
body = '\n'.join(self._print(i) for i in expr.body)
return ('do while ({test}) \n'
'{body}\n'
'end do').format(test=self._print(expr.test), body=body)
def _print_ErrorExit(self, expr):
# TODO treat the case of MPI
return 'STOP'
def _print_Assert(self, expr):
# we first create an If statement
# TODO: depending on a debug flag we should print 'PASSED' or not.
DEBUG = True
err = ErrorExit()
args = [(Not(expr.test), [Print(["'Assert Failed'"]), err])]
if DEBUG:
args.append((True, Print(["'PASSED'"])))
stmt = If(*args)
code = self._print(stmt)
return self._get_statement(code)
def _print_Is(self, expr):
if not isinstance(expr.rhs, Nil):
raise NotImplementedError('Only None rhs is allowed in Is statement')
lhs = self._print(expr.lhs)
return 'present({})'.format(lhs)
def _print_If(self, expr):
# ...
def _iprint(i):
if isinstance(i, Block):
_prelude, _body = self._print_Block(i)
return '{0}'.format(_body)
else:
return '{0}'.format(self._print(i))
# ...
lines = []
for i, (c, e) in enumerate(expr.args):
if i == 0:
lines.append("if (%s) then" % _iprint(c))
elif i == len(expr.args) - 1 and c == True:
lines.append("else")
else:
lines.append("else if (%s) then" % _iprint(c))
if isinstance(e, (list, tuple, Tuple)):
for ee in e:
lines.append(_iprint(ee))
else:
lines.append(_iprint(e))
lines.append("end if")
return "\n".join(lines)
def _print_MatrixElement(self, expr):
return "{0}({1}, {2})".format(expr.parent, expr.i + 1, expr.j + 1)
def _print_Add(self, expr):
# purpose: print complex numbers nicely in Fortran.
# collect the purely real and purely imaginary parts:
pure_real = []
pure_imaginary = []
mixed = []
for arg in expr.args:
if arg.is_number and arg.is_real:
pure_real.append(arg)
elif arg.is_number and arg.is_imaginary:
pure_imaginary.append(arg)
else:
mixed.append(arg)
if len(pure_imaginary) > 0:
if len(mixed) > 0:
PREC = precedence(expr)
term = Add(*mixed)
t = self._print(term)
if t.startswith('-'):
sign = "-"
t = t[1:]
else:
sign = "+"
if precedence(term) < PREC:
t = "(%s)" % t
return "cmplx(%s,%s) %s %s" % (
self._print(Add(*pure_real)),
self._print(-S.ImaginaryUnit*Add(*pure_imaginary)),
sign, t,
)
else:
return "cmplx(%s,%s)" % (
self._print(Add(*pure_real)),
self._print(-S.ImaginaryUnit*Add(*pure_imaginary)),
)
else:
return CodePrinter._print_Add(self, expr)
def _print_Header(self, expr):
return ''
def _print_ConstructorCall(self, expr):
func = expr.func
name = func.name
if name == "__init__":
name = "create"
name = self._print(name)
code_args = ''
if not(expr.arguments) is None:
code_args = ', '.join(self._print(i) for i in expr.arguments)
code = '{0}({1})'.format(name, code_args)
return self._get_statement(code)
def _print_Function(self, expr):
args = expr.args
name = type(expr).__name__
code_args = ', '.join(self._print(i) for i in args)
code = '{0}({1})'.format(name, code_args)
if isinstance(expr.func, Subroutine):
code = 'call ' + code
return self._get_statement(code)
def _print_ImaginaryUnit(self, expr):
# purpose: print complex numbers nicely in Fortran.
return "cmplx(0,1)"
def _print_int(self, expr):
return str(expr)
def _print_Mul(self, expr):
# purpose: print complex numbers nicely in Fortran.
if expr.is_number and expr.is_imaginary:
return "cmplx(0,%s)" % (
self._print(-S.ImaginaryUnit*expr)
)
else:
return CodePrinter._print_Mul(self, expr)
def _print_Pow(self, expr):
PREC = precedence(expr)
if expr.exp == -1:
one = sp_Float(1.0)
code = '{0}/{1}'.format(self._print(one), \
self.parenthesize(expr.base, PREC))
return code
elif expr.exp == 0.5:
if expr.base.is_integer:
# Fortan intrinsic sqrt() does not accept integer argument
if expr.base.is_Number:
return 'sqrt(%s.0d0)' % self._print(expr.base)
else:
return 'sqrt(dble(%s))' % self._print(expr.base)
else:
return 'sqrt(%s)' % self._print(expr.base)
else:
return CodePrinter._print_Pow(self, expr)
def _print_Float(self, expr):
printed = CodePrinter._print_Float(self, expr)
e = printed.find('e')
if e > -1:
return "%sd%s" % (printed[:e], printed[e + 1:])
return "%sd0" % printed
def _print_IndexedBase(self, expr):
return self._print(expr.name)
def _print_Indexed(self, expr):
inds = [i for i in expr.indices]
#indices of indexedElement of len==1 shouldn't be a Tuple
for i, ind in enumerate(inds):
if isinstance(ind, Tuple) and len(ind) == 1:
inds[i] = ind[0]
inds = [self._print(i) for i in inds]
return "%s(%s)" % (self._print(expr.base.label), ", ".join(inds))
def _print_Idx(self, expr):
return self._print(expr.label)
def _print_Slice(self, expr):
if expr.start is None or isinstance(expr.start, Nil):
start = ''
else:
start = self._print(expr.start)
if (expr.end is None) or isinstance(expr.end, Nil):
end = ''
else:
end = expr.end - 1
end = self._print(end)
return '{0}:{1}'.format(start, end)
def _pad_leading_columns(self, lines):
result = []
for line in lines:
if line.startswith('!'):
result.append("! " + line[1:].lstrip())
else:
result.append(line)
return result
def _wrap_fortran(self, lines):
"""Wrap long Fortran lines
Argument:
lines -- a list of lines (without \\n character)
A comment line is split at white space. Code lines are split with a more
complex rule to give nice results.
"""
# routine to find split point in a code line
my_alnum = set("_+-." + string.digits + string.ascii_letters)
my_white = set(" \t()")
def split_pos_code(line, endpos):
if len(line) <= endpos:
return len(line)
pos = endpos
split = lambda pos: \
(line[pos] in my_alnum and line[pos - 1] not in my_alnum) or \
(line[pos] not in my_alnum and line[pos - 1] in my_alnum) or \
(line[pos] in my_white and line[pos - 1] not in my_white) or \
(line[pos] not in my_white and line[pos - 1] in my_white)
while not split(pos):
pos -= 1
if pos == 0:
return endpos
return pos
# split line by line and add the splitted lines to result
result = []
trailing = ' &'
for line in lines:
if line.startswith("! "):
# comment line
if len(line) > 72:
pos = line.rfind(" ", 6, 72)
if pos == -1:
pos = 72
hunk = line[:pos]
line = line[pos:].lstrip()
result.append(hunk)
while len(line) > 0:
pos = line.rfind(" ", 0, 66)
if pos == -1 or len(line) < 66:
pos = 66
hunk = line[:pos]
line = line[pos:].lstrip()
result.append("%s%s" % ("! ", hunk))
else:
result.append(line)
elif (line[72:].count("'" )+line[72:].count('"'))%2-1:
# code line
pos = split_pos_code(line, 72)
hunk = line[:pos].rstrip()
line = line[pos:].lstrip()
if line:
hunk += trailing
result.append(hunk)
while len(line) > 0:
pos = split_pos_code(line, 65)
hunk = line[:pos].rstrip()
line = line[pos:].lstrip()
if line:
hunk += trailing
result.append("%s%s"%(" " , hunk))
else:
# we don't seperate lines in those cases mentioned above
#TODO improve find the postion of the caractere and split there
result.append(line)
return result
[docs] def indent_code(self, code):
"""Accepts a string of code or a list of code lines"""
if isinstance(code, string_types):
code_lines = self.indent_code(code.splitlines(True))
return ''.join(code_lines)
code = [line.lstrip(' \t') for line in code]
inc_keyword = ('do ', 'if(', 'if ', 'do\n',
'else', 'type', 'subroutine', 'function')
dec_keyword = ('end do', 'enddo', 'end if', 'endif',
'else', 'endtype', 'end type',
'endfunction', 'end function',
'endsubroutine', 'end subroutine')
increase = [int(any(map(line.startswith, inc_keyword)))
for line in code]
decrease = [int(any(map(line.startswith, dec_keyword)))
for line in code]
continuation = [int(any(map(line.endswith, ['&', '&\n'])))
for line in code]
level = 0
cont_padding = 0
tabwidth = self._default_settings['tabwidth']
new_code = []
for i, line in enumerate(code):
if line == '' or line == '\n':
new_code.append(line)
continue
level -= decrease[i]
padding = " "*(level*tabwidth + cont_padding)
line = "%s%s" % (padding, line)
new_code.append(line)
if continuation[i]:
cont_padding = 2*tabwidth
else:
cont_padding = 0
level += increase[i]
return new_code
[docs]def fcode(expr, assign_to=None, **settings):
"""Converts an expr to a string of c code
expr : Expr
A sympy expression to be converted.
assign_to : optional
When given, the argument is used as the name of the variable to which
the expression is assigned. Can be a string, ``Symbol``,
``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of
line-wrapping, or for expressions that generate multi-line statements.
precision : integer, optional
The precision for numbers such as pi [default=15].
user_functions : dict, optional
A dictionary where keys are ``FunctionClass`` instances and values are
their string representations. Alternatively, the dictionary value can
be a list of tuples i.e. [(argument_test, cfunction_string)]. See below
for examples.
"""
return FCodePrinter(settings).doprint(expr, assign_to)