Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pre and post treatment of expression when cse=True in lambdify #26463

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .mailmap
Original file line number Diff line number Diff line change
Expand Up @@ -1586,6 +1586,7 @@ latot <felipematas@yahoo.com>
luzpaz <luzpaz@users.noreply.github.com> luz paz <luzpaz@pm.me>
luzpaz <luzpaz@users.noreply.github.com> luz.paz <luzpaz@users.noreply.github.com>
mao8 <thisisma08@gmail.com>
mleila1312 <leila.iksil@gmail.com>
mohajain <mohajain99@gmail.com> mohajain <45903778+mohajain@users.noreply.github.com>
mohammedouahman <simofun85@gmail.com>
mohit <39158356+mohitacecode@users.noreply.github.com> mohit <42018918+mohitshah3111999@users.noreply.github.com>
Expand Down
161 changes: 156 additions & 5 deletions sympy/utilities/lambdify.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
from sympy.utilities.exceptions import sympy_deprecation_warning
from sympy.utilities.decorator import doctest_depends_on
from sympy.utilities.iterables import (is_sequence, iterable,
NotIterable, flatten)
NotIterable, flatten, numbered_symbols,)
from sympy.utilities.misc import filldedent


__doctest_requires__ = {('lambdify',): ['numpy', 'tensorflow']}

# Default namespaces, letting us define translations that can't be defined
Expand Down Expand Up @@ -178,6 +179,151 @@ def _import(module, reload=False):
# linecache.
_lambdify_generated_counter = 1

def _replace_recursively(e, dict) :
if isinstance(e, list):
return [_replace_recursively(sub_e, dict)for sub_e in e]
elif isinstance(e, tuple):
return tuple([_replace_recursively(sub_e, dict)for sub_e in e])
else :
return e.xreplace(dict)

def _pre_treatment_cse(args_f, expr):
r"""
This function masks Derivative that are also arguments
of the expression to prevent erros in the cse treament
in lambdify.

The first step is to go through the expression to make
sure that we don't replace the Derivatives by symbols
already in the expression. Then we replace the
Derivatives in the expression with symbols and remember
the changes in a dictionary.

Parameters :
args_f : the arguments of expr given in lambdify

expr : expression given to lambdify
Return :
dictionary : dictionary of the associations
Derivative-new name

new_expr : expression where the Derivatives
have been replaced

"""
mleila1312 marked this conversation as resolved.
Show resolved Hide resolved
#Necessary librairies and dependencies
from sympy.core.function import Derivative
from sympy.core.symbol import Symbol
from sympy.core import Basic
from sympy.matrices.expressions import MatrixSymbol
from sympy.matrices.expressions.matexpr import MatrixElement
from sympy.polys.rootoftools import RootOf

#creation of the dictionary
dictionary={}
# creation of the symbols that can't be used to replace the Derivatives in the expression
excluded_symbols = set()
symbols = numbered_symbols(cls=Symbol)
def _eliminates_symbols(expr):
# function that finds the symbols that can't be used
if not isinstance(expr, Basic):
return

if isinstance(expr, RootOf):
return

if isinstance(expr, Basic) and (
expr.is_Atom or
expr.is_Order or
isinstance(expr, (MatrixSymbol, MatrixElement))):
if expr.is_Symbol:
excluded_symbols.add(expr.name)
return
#recursively goes through the expression
if iterable(expr):
args = expr
else:
args = expr.args
list(map(_eliminates_symbols, args))
return

if iterable(expr):
for e in expr:
if isinstance(e, Basic):
_eliminates_symbols(e)
else:
if isinstance(expr, Basic):
_eliminates_symbols(expr)

#gets the possible symbols to replace Derivatives with
symbols = (_ for _ in symbols if _.name not in excluded_symbols)
new_expr = expr
# replaces the instances of Derivatives in the expression

for arg in args_f:
if isinstance(arg, (Derivative)):
try:
dictionary[arg] = next(symbols)
except StopIteration:
raise ValueError("Symbols iterator ran out of symbols.")

new_expr=_replace_recursively(new_expr, dictionary)
return dictionary, new_expr

def _post_treatment_cse(dictionary, args, expr, cses):
r"""
This function changes back the replaced Derivatives to
their original values after passing through
_pre_treatment_cse and cse in lambdify.

This function returns the Derivatives to their
original value in the expression and cses.

Parameters :
dictionary : dictonary containing associations
of Derivative-new name given by _pre_treatment_cse

args : arguments given to lambdify of expr

expr : expression returned by cse

cses : changes made by the cse process containing
the associations partial expression- new name

Return :
post_cses : cses modified to return Derivatives
to their original value

post_expr : expression where the Derivatives have
been returned back to their original values

"""
from sympy.core.function import Derivative
post_expr = expr
post_cses= cses
for arg in args:
if isinstance(arg, Derivative):
association = []
#checks if if the new name of the Deivative was changed by the cse process
#or if combinations of the Derivatives expressions were replaces
for i in range(len(cses)):
new_a, a = cses[i]
if a.has(dictionary[arg]):
if a == dictionary[arg]:
association = new_a
post_cses.remove((new_a, a))
else :
a = a.xreplace({dictionary[arg] : arg})
cses[i] = new_a, a
#Checks if the new name of the Deivative was changed by the cse process
if association == []:
# if the derivative hasn't been replaced by the cse process
post_expr = _replace_recursively(post_expr, {dictionary[arg] : arg})
else:
# if the derivative has been replaced by the cse process
post_expr = _replace_recursively(post_expr,{association : arg})
return post_cses, post_expr


@doctest_depends_on(modules=('numpy', 'scipy', 'tensorflow',), python_version=(3,))
def lambdify(args, expr, modules=None, printer=None, use_imps=True,
Expand Down Expand Up @@ -277,7 +423,8 @@ def lambdify(args, expr, modules=None, printer=None, use_imps=True,
6

expr : Expr
An expression, list of expressions, or matrix to be evaluated.
An expression, list of expressions, tuple of expressions or
matrix to be evaluated.

Lists may be nested.
If the expression is a list, the output will also be a list.
Expand Down Expand Up @@ -756,7 +903,6 @@ def _lambdifygenerated(x):
"""
from sympy.core.symbol import Symbol
from sympy.core.expr import Expr

# If the user hasn't specified any modules, use what is available.
if modules is None:
try:
Expand Down Expand Up @@ -866,8 +1012,13 @@ def _lambdifygenerated(x):
funcprinter = _EvaluatorPrinter(printer, dummify)

if cse == True:
from sympy.simplify.cse_main import cse as _cse
cses, _expr = _cse(expr, list=False)
#get the dictionary containing the Derivative in the
#arguments and their new name
dictionary, new_expr= _pre_treatment_cse(args, expr)
from sympy.simplify.cse_main import cse as cse_function
cses, _expr = cse_function(new_expr, list=False)
#puts back the instances of Derivatives inthe expression
cses, _expr= _post_treatment_cse(dictionary, args, _expr, cses)
elif callable(cse):
cses, _expr = cse(expr)
else:
Expand Down
79 changes: 79 additions & 0 deletions sympy/utilities/tests/test_lambdify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,3 +1891,82 @@ def test_assoc_legendre_numerical_evaluation():

assert all_close(sympy_result_integer, mpmath_result_integer, tol)
assert all_close(sympy_result_complex, mpmath_result_complex, tol)

def test_derivative_issue_26404():
r"""
test issue fixed when using cse in lambdify when some arguments
are Derivatives and they appear in the expression
"""
from sympy import (cos, sin, Matrix, symbols)
from sympy.physics.mechanics import (dynamicsymbols)
t = symbols("t")
x = Function("x")(t)
xd = x.diff(t)
xdd= xd.diff(t)
assert lambdify((xd, x), xd, cse=True)(1, 1) == 1
assert lambdify((xd, x), xd + x, cse=True)(1, 1) == 2
assert lambdify((xdd, xd, x), xdd*xd + x, cse=True)(3,1, 1) == 4
assert lambdify((xd, xdd, x), xdd*xd + x, cse=True)(3,1, 1) == 4
assert lambdify((xdd, xd, x), cos(xdd*xd) + x, cse=True)(0,1, 1) == 2.0
#test for matrix and cases were Derivative(a,b) becomes x_n
#and cse makes a replacement x_m: x_n**2 or other
#and case where xn(n : int) is already the name of an
#element of the function
x0, m0 = symbols("x0 m0")
l1, m1 = symbols("l1 m1")
m2 = symbols("m2")
g = symbols("g")
q0, q1, q2 = Function("q0")(x0),Function("q1")(l1),Function("q2")(m0)
u1, u2 =q1.diff(l1), q2.diff(m0)
F, T1 = dynamicsymbols("F T1")
massmatrix1 = Matrix([[m0 + m1 + m2, -x0*m1*cos(q1) - x0*m2*cos(q1),
-l1*m2*cos(q2)],
[-x0*m1*cos(q1) - x0*m2*cos(q1), x0**2*m1 + x0**2*m2,
x0*l1*m2*(sin(q1)*sin(q2) + cos(q1)*cos(q2))],
[-l1*m2*cos(q2),
x0*l1*m2*(sin(q1)*sin(q2) + cos(q1)*cos(q2)),
l1**2*m2]])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These expressions would be more appropriate if the u values were all swapped with derivatives: u1 = Derivative(q1(t) t), for example.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

u values are : u1, u2 =q1.diff(l1), q2.diff(m0), so already derivatives, did you mean putting more of them in the expression of the matrix? Or do you mean putting Derivatives directly in the matrix? If so, I haven't handled the case were the expression of a Derivative is directly in the expression and not also as an argument, I will add it


forcing1 = Matrix([[-x0*m1*u1**2*sin(q1) - x0*m2*u1**2*sin(q1) -
l1*m2*u2**2*sin(q2) + F,
g*x0*m1*sin(q1) + g*x0*m2*sin(q1) -
x0*l1*m2*(sin(q1)*cos(q2) - sin(q2)*cos(q1))*u2**2,
g*l1*m2*sin(q2) - x0*l1*m2*(-sin(q1)*cos(q2) +
sin(q2)*cos(q1))*u1**2],
[-x0*m1*u1**2*sin(q1) - x0*m2*u1**2*sin(q1) -
l1*m2*u2**2*sin(q2) + F,
g*x0*m1*sin(q1) + g*x0*m2*sin(q1) -
x0*l1*m2*(sin(q1)*cos(q2) - sin(q2)*cos(q1))*u2**2,
g*l1*m2*sin(q2) - x0*l1*m2*(-sin(q1)*cos(q2) +
sin(q2)*cos(q1))*u1**2],
[-x0*m1*u1**2*sin(q1) - x0*m2*u1**2*sin(q1) -
l1*m2*u2**2*sin(q2) + F,
g*x0*m1*sin(q1) + g*x0*m2*sin(q1) -
x0*l1*m2*(sin(q1)*cos(q2) - sin(q2)*cos(q1))*u2**2,
g*l1*m2*sin(q2) - x0*l1*m2*(-sin(q1)*cos(q2) +
sin(q2)*cos(q1))*u1**2]])
res_expected=Matrix([[ 1., 0, 0],[-1., 0, 0],[-1., 0, 0]])
res_lamdbify=Matrix((lambdify((x0, m0 ,l1, m1, m2, g, q0, q1, q2, u1, u2, F, T1), massmatrix1 -forcing1, \
cse=True)( 0, 0 ,0, 1, 1, 1, 0, 1, 1 , 1, 1, 1, 1)))
equal=True
for i in range(res_lamdbify.rows*res_lamdbify.cols):
equal=equal and (res_expected[i]==res_lamdbify[i])
assert equal
# test in the case chen a list of expressions is given
expected=[[[0, 2, 18], 5], [18, 1], 0]
t1, t2, t3, t4, t5, t6, t7 = symbols("t1 t2 t3 t4 t5 t6 t7")
x1, x2, x3, x4, x5, x6, x7 = Function('x1')(t1), Function('x2')(t2), Function('x3')(t3),\
Function('x4')(t4), Function('x5')(t5), Function('x6')(t6),\
Function('x7')(t7)
d1, d2, d3, d4, d5, d6, d7 = x1.diff(t1), x2.diff(t2), x3.diff(t3), x4.diff(t4),\
x5.diff(t5), x6.diff(t6), x7.diff(t7)
list_of_list = [[[d5*d1, d3, d4*d7], d6], [d4*d7, d2],d5*d1]
res_list_of_list = lambdify((d1, d2, d3, d4, d5, d6, d7),list_of_list,\
cse=True)(0,1,2,3,4,5,6)
assert (expected[0][0][0] == res_list_of_list[0][0][0] and\
expected[0][0][1] == res_list_of_list[0][0][1] and\
expected[0][0][2] == res_list_of_list[0][0][2] and \
expected[0][1] == res_list_of_list[0][1] and\
expected[1][0] == res_list_of_list[1][0] and \
expected[1][1] == res_list_of_list[1][1] and \
expected[2] == res_list_of_list[2])