# coding: utf-8
from sympy.core.basic import Basic
from sympy.core.function import Lambda
from sympy.core.symbol import Symbol
from sympy.core import Add, Mul, Pow, S
from sympy.core.compatibility import default_sort_key, string_types
from sympy.core.sympify import _sympify
from sympy.core.mul import _keep_coeff
from sympy.printing.str import StrPrinter
from sympy.printing.precedence import precedence
from pyccel.ast.core import Assign,DottedVariable
from pyccel.ast.core import FunctionDef
from pyccel.ast.core import ZerosLike
from pyccel.ast import Real
# TODO: add examples
__all__ = ["CodePrinter"]
[docs]class CodePrinter(StrPrinter):
"""
The base class for code-printing subclasses.
"""
_operators = {
'and': '&&',
'or': '||',
'not': '!',
}
[docs] def doprint(self, expr, assign_to=None):
"""
Print the expression as code.
expr : Expression
The expression to be printed.
assign_to : Symbol, MatrixSymbol, or string (optional)
If provided, the printed code will set the expression to a
variable with name ``assign_to``.
"""
if isinstance(assign_to, string_types):
assign_to = Symbol(assign_to)
elif not isinstance(assign_to, (Basic, type(None))):
raise TypeError("{0} cannot assign to object of type {1}".format(
type(self).__name__, type(assign_to)))
if assign_to:
expr = Assign(assign_to, expr)
else:
expr = _sympify(expr)
# Do the actual printing
lines = self._print(expr).splitlines()
# Format the output
return "\n".join(self._format_code(lines))
def _get_statement(self, codestring):
"""Formats a codestring with the proper line ending."""
raise NotImplementedError("This function must be implemented by "
"subclass of CodePrinter.")
def _get_comment(self, text):
"""Formats a text string as a comment."""
raise NotImplementedError("This function must be implemented by "
"subclass of CodePrinter.")
def _declare_number_const(self, name, value):
"""Declare a numeric constant at the top of a function"""
raise NotImplementedError("This function must be implemented by "
"subclass of CodePrinter.")
def _format_code(self, lines):
"""Take in a list of lines of code, and format them accordingly.
This may include indenting, wrapping long lines, etc..."""
raise NotImplementedError("This function must be implemented by "
"subclass of CodePrinter.")
def _print_NumberSymbol(self, expr):
return str(expr)
def _print_Dummy(self, expr):
# dummies must be printed as unique symbols
return "%s_%i" % (expr.name, expr.dummy_index) # Dummy
def _print_And(self, expr):
PREC = precedence(expr)
return (" %s " % self._operators['and']).join(self.parenthesize(a, PREC)
for a in sorted(expr.args, key=default_sort_key))
def _print_Or(self, expr):
PREC = precedence(expr)
return (" %s " % self._operators['or']).join(self.parenthesize(a, PREC)
for a in sorted(expr.args, key=default_sort_key))
def _print_Xor(self, expr):
if self._operators.get('xor') is None:
return self._print_not_supported(expr)
PREC = precedence(expr)
return (" %s " % self._operators['xor']).join(self.parenthesize(a, PREC)
for a in expr.args)
def _print_Equivalent(self, expr):
if self._operators.get('equivalent') is None:
return self._print_not_supported(expr)
PREC = precedence(expr)
return (" %s " % self._operators['equivalent']).join(self.parenthesize(a, PREC)
for a in expr.args)
def _print_Not(self, expr):
PREC = precedence(expr)
return self._operators['not'] + self.parenthesize(expr.args[0], PREC)
def _print_Mul(self, expr):
prec = precedence(expr)
c, e = expr.as_coeff_Mul()
if c < 0:
expr = _keep_coeff(-c, e)
sign = "-"
else:
sign = ""
a = [] # items in the numerator
b = [] # items that are in the denominator (if any)
if self.order not in ('old', 'none'):
args = expr.as_ordered_factors()
else:
# use make_args in case expr was something like -x -> x
args = Mul.make_args(expr)
# Gather args for numerator/denominator
flag = True
for item in args:
if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
base = item.base
if base.is_integer and flag:
flag = False
base = Real(item.base)
#we only need to do it once
#to one of the denominator args
if item.exp != -1:
b.append(Pow(base, -item.exp, evaluate=False))
else:
b.append(Pow(base, -item.exp))
else:
a.append(item)
a = a or [S.One]
a_str = [self.parenthesize(x, prec) for x in a]
b_str = [self.parenthesize(x, prec) for x in b]
if len(b) == 0:
return sign + '*'.join(a_str)
elif len(b) == 1:
if len(a) == 1 and not (a[0].is_Atom or a[0].is_Add):
return sign + "%s/" % a_str[0] + '*'.join(b_str)
else:
return sign + '*'.join(a_str) + "/%s" % b_str[0]
else:
return sign + '*'.join(a_str) + "/(%s)" % '*'.join(b_str)
def _print_not_supported(self, expr):
raise TypeError("{0} not supported in {1}".format(type(expr), self.language))
# Number constants
_print_Catalan = _print_NumberSymbol
_print_EulerGamma = _print_NumberSymbol
_print_GoldenRatio = _print_NumberSymbol
_print_Exp1 = _print_NumberSymbol
_print_Pi = _print_NumberSymbol
# The following can not be simply translated into C or Fortran
_print_Basic = _print_not_supported
_print_ComplexInfinity = _print_not_supported
_print_Derivative = _print_not_supported
_print_dict = _print_not_supported
_print_ExprCondPair = _print_not_supported
_print_GeometryEntity = _print_not_supported
_print_Infinity = _print_not_supported
_print_Integral = _print_not_supported
_print_Interval = _print_not_supported
_print_Limit = _print_not_supported
_print_list = _print_not_supported
_print_Matrix = _print_not_supported
_print_ImmutableMatrix = _print_not_supported
_print_MutableDenseMatrix = _print_not_supported
_print_MatrixBase = _print_not_supported
_print_DeferredVector = _print_not_supported
_print_NaN = _print_not_supported
_print_NegativeInfinity = _print_not_supported
_print_Normal = _print_not_supported
_print_Order = _print_not_supported
_print_PDF = _print_not_supported
_print_RootOf = _print_not_supported
_print_RootsOf = _print_not_supported
_print_RootSum = _print_not_supported
_print_Sample = _print_not_supported
_print_SparseMatrix = _print_not_supported
_print_tuple = _print_not_supported
_print_Uniform = _print_not_supported
_print_Unit = _print_not_supported
_print_Wild = _print_not_supported
_print_WildFunction = _print_not_supported