Skip to content

Commit

Permalink
Fix cse treatment in lambdify with Derivatives
Browse files Browse the repository at this point in the history
Before, when there were Derivatives in expr ans args given to lambdify
with cse enabled, there was an error because the cse treatment changed
the arguments of the Derivative object.

With this implementation, the expression is pre-treated by a function to
mask the instances of Derivative objects, then the cse process is
applied and finally we do a post-treatment to put back the Derivative
expressions in expr.
  • Loading branch information
mleila1312 committed Apr 10, 2024
1 parent e89ee93 commit 5e97b9d
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 5 deletions.
1 change: 1 addition & 0 deletions .mailmap
Original file line number Diff line number Diff line change
Expand Up @@ -1586,6 +1586,7 @@ latot <felipematas@yahoo.com>
luzpaz <luzpaz@users.noreply.github.com> luz paz <luzpaz@pm.me>
luzpaz <luzpaz@users.noreply.github.com> luz.paz <luzpaz@users.noreply.github.com>
mao8 <thisisma08@gmail.com>
mleila1312 <leila.iksil@gmail.com>
mohajain <mohajain99@gmail.com> mohajain <45903778+mohajain@users.noreply.github.com>
mohammedouahman <simofun85@gmail.com>
mohit <39158356+mohitacecode@users.noreply.github.com> mohit <42018918+mohitshah3111999@users.noreply.github.com>
Expand Down
161 changes: 156 additions & 5 deletions sympy/utilities/lambdify.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
from sympy.utilities.exceptions import sympy_deprecation_warning
from sympy.utilities.decorator import doctest_depends_on
from sympy.utilities.iterables import (is_sequence, iterable,
NotIterable, flatten)
NotIterable, flatten, numbered_symbols,)
from sympy.utilities.misc import filldedent


__doctest_requires__ = {('lambdify',): ['numpy', 'tensorflow']}

# Default namespaces, letting us define translations that can't be defined
Expand Down Expand Up @@ -178,6 +179,151 @@ def _import(module, reload=False):
# linecache.
_lambdify_generated_counter = 1

def _replace_recursively(e, dict) :
if isinstance(e, list):
return [_replace_recursively(sub_e, dict)for sub_e in e]
elif isinstance(e, tuple):
return tuple([_replace_recursively(sub_e, dict)for sub_e in e])
else :
return e.xreplace(dict)

def _pre_treatment_cse(args_f, expr):
r"""
This function masks Derivative that are also arguments
of the expression to prevent erros in the cse treament
in lambdify.
The first step is to go through the expression to make
sure that we don't replace the Derivatives by symbols
already in the expression. Then we replace the
Derivatives in the expression with symbols and remember
the changes in a dictionary.
Parameters :
args_f : the arguments of expr given in lambdify
expr : expression given to lambdify
Return :
dictionary : dictionary of the associations
Derivative-new name
new_expr : expression where the Derivatives
have been replaced
"""
#Necessary librairies and dependencies
from sympy.core.function import Derivative
from sympy.core.symbol import Symbol
from sympy.core import Basic
from sympy.matrices.expressions import MatrixSymbol
from sympy.matrices.expressions.matexpr import MatrixElement
from sympy.polys.rootoftools import RootOf

#creation of the dictionary
dictionary={}
# creation of the symbols that can't be used to replace the Derivatives in the expression
excluded_symbols = set()
symbols = numbered_symbols(cls=Symbol)
def _eliminates_symbols(expr):
# function that finds the symbols that can't be used
if not isinstance(expr, Basic):
return

if isinstance(expr, RootOf):
return

if isinstance(expr, Basic) and (
expr.is_Atom or
expr.is_Order or
isinstance(expr, (MatrixSymbol, MatrixElement))):
if expr.is_Symbol:
excluded_symbols.add(expr.name)
return
#recursively goes through the expression
if iterable(expr):
args = expr
else:
args = expr.args
list(map(_eliminates_symbols, args))
return

if iterable(expr):
for e in expr:
if isinstance(e, Basic):
_eliminates_symbols(e)
else:
if isinstance(expr, Basic):
_eliminates_symbols(expr)

#gets the possible symbols to replace Derivatives with
symbols = (_ for _ in symbols if _.name not in excluded_symbols)
new_expr = expr
# replaces the instances of Derivatives in the expression

for arg in args_f:
if isinstance(arg, (Derivative)):
try:
dictionary[arg] = next(symbols)
except StopIteration:
raise ValueError("Symbols iterator ran out of symbols.")

new_expr=_replace_recursively(new_expr, dictionary)
return dictionary, new_expr

def _post_treatment_cse(dictionary, args, expr, cses):
r"""
This function changes back the replaced Derivatives to
their original values after passing through
_pre_treatment_cse and cse in lambdify.
This function returns the Derivatives to their
original value in the expression and cses.
Parameters :
dictionary : dictonary containing associations
of Derivative-new name given by _pre_treatment_cse
args : arguments given to lambdify of expr
expr : expression returned by cse
cses : changes made by the cse process containing
the associations partial expression- new name
Return :
post_cses : cses modified to return Derivatives
to their original value
post_expr : expression where the Derivatives have
been returned back to their original values
"""
from sympy.core.function import Derivative
post_expr = expr
post_cses= cses
for arg in args:
if isinstance(arg, Derivative):
association = []
#checks if if the new name of the Deivative was changed by the cse process
#or if combinations of the Derivatives expressions were replaces
for i in range(len(cses)):
new_a, a = cses[i]
if a.has(dictionary[arg]):
if a == dictionary[arg]:
association = new_a
post_cses.remove((new_a, a))
else :
a = a.xreplace({dictionary[arg] : arg})
cses[i] = new_a, a
#Checks if the new name of the Deivative was changed by the cse process
if association == []:
# if the derivative hasn't been replaced by the cse process
post_expr = _replace_recursively(post_expr, {dictionary[arg] : arg})
else:
# if the derivative has been replaced by the cse process
post_expr = _replace_recursively(post_expr,{association : arg})
return post_cses, post_expr


@doctest_depends_on(modules=('numpy', 'scipy', 'tensorflow',), python_version=(3,))
def lambdify(args, expr, modules=None, printer=None, use_imps=True,
Expand Down Expand Up @@ -277,7 +423,8 @@ def lambdify(args, expr, modules=None, printer=None, use_imps=True,
6
expr : Expr
An expression, list of expressions, or matrix to be evaluated.
An expression, list of expressions, tuple of expressions or
matrix to be evaluated.
Lists may be nested.
If the expression is a list, the output will also be a list.
Expand Down Expand Up @@ -756,7 +903,6 @@ def _lambdifygenerated(x):
"""
from sympy.core.symbol import Symbol
from sympy.core.expr import Expr

# If the user hasn't specified any modules, use what is available.
if modules is None:
try:
Expand Down Expand Up @@ -866,8 +1012,13 @@ def _lambdifygenerated(x):
funcprinter = _EvaluatorPrinter(printer, dummify)

if cse == True:
from sympy.simplify.cse_main import cse as _cse
cses, _expr = _cse(expr, list=False)
#get the dictionary containing the Derivative in the
#arguments and their new name
dictionary, new_expr= _pre_treatment_cse(args, expr)
from sympy.simplify.cse_main import cse as cse_function
cses, _expr = cse_function(new_expr, list=False)
#puts back the instances of Derivatives inthe expression
cses, _expr= _post_treatment_cse(dictionary, args, _expr, cses)
elif callable(cse):
cses, _expr = cse(expr)
else:
Expand Down
79 changes: 79 additions & 0 deletions sympy/utilities/tests/test_lambdify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,3 +1891,82 @@ def test_assoc_legendre_numerical_evaluation():

assert all_close(sympy_result_integer, mpmath_result_integer, tol)
assert all_close(sympy_result_complex, mpmath_result_complex, tol)

def test_derivative_issue_26404():
r"""
test issue fixed when using cse in lambdify when some arguments
are Derivatives and they appear in the expression
"""
from sympy import (cos, sin, Matrix, symbols)
from sympy.physics.mechanics import (dynamicsymbols)
t = symbols("t")
x = Function("x")(t)
xd = x.diff(t)
xdd= xd.diff(t)
assert lambdify((xd, x), xd, cse=True)(1, 1) == 1
assert lambdify((xd, x), xd + x, cse=True)(1, 1) == 2
assert lambdify((xdd, xd, x), xdd*xd + x, cse=True)(3,1, 1) == 4
assert lambdify((xd, xdd, x), xdd*xd + x, cse=True)(3,1, 1) == 4
assert lambdify((xdd, xd, x), cos(xdd*xd) + x, cse=True)(0,1, 1) == 2.0
#test for matrix and cases were Derivative(a,b) becomes x_n
#and cse makes a replacement x_m: x_n**2 or other
#and case where xn(n : int) is already the name of an
#element of the function
x0, m0 = symbols("x0 m0")
l1, m1 = symbols("l1 m1")
m2 = symbols("m2")
g = symbols("g")
q0, q1, q2 = Function("q0")(x0),Function("q1")(l1),Function("q2")(m0)
u1, u2 =q1.diff(l1), q2.diff(m0)
F, T1 = dynamicsymbols("F T1")
massmatrix1 = Matrix([[m0 + m1 + m2, -x0*m1*cos(q1) - x0*m2*cos(q1),
-l1*m2*cos(q2)],
[-x0*m1*cos(q1) - x0*m2*cos(q1), x0**2*m1 + x0**2*m2,
x0*l1*m2*(sin(q1)*sin(q2) + cos(q1)*cos(q2))],
[-l1*m2*cos(q2),
x0*l1*m2*(sin(q1)*sin(q2) + cos(q1)*cos(q2)),
l1**2*m2]])

forcing1 = Matrix([[-x0*m1*u1**2*sin(q1) - x0*m2*u1**2*sin(q1) -
l1*m2*u2**2*sin(q2) + F,
g*x0*m1*sin(q1) + g*x0*m2*sin(q1) -
x0*l1*m2*(sin(q1)*cos(q2) - sin(q2)*cos(q1))*u2**2,
g*l1*m2*sin(q2) - x0*l1*m2*(-sin(q1)*cos(q2) +
sin(q2)*cos(q1))*u1**2],
[-x0*m1*u1**2*sin(q1) - x0*m2*u1**2*sin(q1) -
l1*m2*u2**2*sin(q2) + F,
g*x0*m1*sin(q1) + g*x0*m2*sin(q1) -
x0*l1*m2*(sin(q1)*cos(q2) - sin(q2)*cos(q1))*u2**2,
g*l1*m2*sin(q2) - x0*l1*m2*(-sin(q1)*cos(q2) +
sin(q2)*cos(q1))*u1**2],
[-x0*m1*u1**2*sin(q1) - x0*m2*u1**2*sin(q1) -
l1*m2*u2**2*sin(q2) + F,
g*x0*m1*sin(q1) + g*x0*m2*sin(q1) -
x0*l1*m2*(sin(q1)*cos(q2) - sin(q2)*cos(q1))*u2**2,
g*l1*m2*sin(q2) - x0*l1*m2*(-sin(q1)*cos(q2) +
sin(q2)*cos(q1))*u1**2]])
res_expected=Matrix([[ 1., 0, 0],[-1., 0, 0],[-1., 0, 0]])
res_lamdbify=Matrix((lambdify((x0, m0 ,l1, m1, m2, g, q0, q1, q2, u1, u2, F, T1), massmatrix1 -forcing1, \
cse=True)( 0, 0 ,0, 1, 1, 1, 0, 1, 1 , 1, 1, 1, 1)))
equal=True
for i in range(res_lamdbify.rows*res_lamdbify.cols):
equal=equal and (res_expected[i]==res_lamdbify[i])
assert equal
# test in the case chen a list of expressions is given
expected=[[[0, 2, 18], 5], [18, 1], 0]
t1, t2, t3, t4, t5, t6, t7 = symbols("t1 t2 t3 t4 t5 t6 t7")
x1, x2, x3, x4, x5, x6, x7 = Function('x1')(t1), Function('x2')(t2), Function('x3')(t3),\
Function('x4')(t4), Function('x5')(t5), Function('x6')(t6),\
Function('x7')(t7)
d1, d2, d3, d4, d5, d6, d7 = x1.diff(t1), x2.diff(t2), x3.diff(t3), x4.diff(t4),\
x5.diff(t5), x6.diff(t6), x7.diff(t7)
list_of_list = [[[d5*d1, d3, d4*d7], d6], [d4*d7, d2],d5*d1]
res_list_of_list = lambdify((d1, d2, d3, d4, d5, d6, d7),list_of_list,\
cse=True)(0,1,2,3,4,5,6)
assert (expected[0][0][0] == res_list_of_list[0][0][0] and\
expected[0][0][1] == res_list_of_list[0][0][1] and\
expected[0][0][2] == res_list_of_list[0][0][2] and \
expected[0][1] == res_list_of_list[0][1] and\
expected[1][0] == res_list_of_list[1][0] and \
expected[1][1] == res_list_of_list[1][1] and \
expected[2] == res_list_of_list[2])

0 comments on commit 5e97b9d

Please sign in to comment.