Lambdify doesn't recognize derivative symbol if cse is enabled #26404

Open
tjstienstra opened this issue Mar 25, 2024 · 12 comments · May be fixed by #26463

Comments

@tjstienstra
Contributor

tjstienstra commented Mar 25, 2024

Here is a minimal reproducer:

>>> import sympy as sm   
>>> t = sm.symbols("t")
>>> x = sm.Function("x")(t)
>>> xd = x.diff(t)
>>> sm.lambdify((xd, x), xd + x)(1, 1)
2
>>> sm.lambdify((xd, x), xd, cse=True)(1, 1)     
1
>>> sm.lambdify((xd, x), xd + x, cse=True)(1, 1) 
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "...\sympy\utilities\lambdify.py", line 875, in lambdify
    funcstr = funcprinter.doprint(funcname, iterable_args, _expr, cses=cses)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\sympy\utilities\lambdify.py", line 1166, in doprint
    str_expr = _recursive_to_string(self._exprrepr, expr)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\sympy\utilities\lambdify.py", line 961, in _recursive_to_string
    return doprint(arg)
           ^^^^^^^^^^^^
  File "...\sympy\printing\codeprinter.py", line 172, in doprint
    lines = self._print(expr).splitlines()
            ^^^^^^^^^^^^^^^^^
  File "...\sympy\printing\printer.py", line 331, in _print
    return printmethod(expr, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\sympy\printing\str.py", line 57, in _print_Add
    t = self._print(term)
        ^^^^^^^^^^^^^^^^^
  File "...\sympy\printing\printer.py", line 331, in _print
    return printmethod(expr, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\sympy\printing\codeprinter.py", line 582, in _print_not_supported
    raise PrintMethodNotImplementedError("Unsupported by %s: %s" % (str(type(self)), str(type(expr))) + \
sympy.printing.codeprinter.PrintMethodNotImplementedError: Unsupported by <class 'sympy.printing.numpy.SciPyPrinter'>: <class 'sympy.core.function.Derivative'>
Set the printer option 'strict' to False in order to generate partially printed code.

I am not sure whether this issue is easily solvable. The problem seems to be that it tries to print the expression x0 + Derivative(x0, t) with [(x0, x(t))] as the subexpressions.
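To see what the printer is handed, the sketch below calls `sm.cse` with `list=False` and then undoes the replacements. It assumes this matches how lambdify invokes cse on a single expression. If the back-substituted result equals the original, the cse output is mathematically fine and the failure is confined to printing `Derivative(x0, t)`.

```python
import sympy as sm

t = sm.symbols("t")
x = sm.Function("x")(t)
xd = x.diff(t)
expr = xd + x

# Replacements and the reduced expression, as cse returns them for a single expression.
replacements, reduced = sm.cse(expr, list=False)
print(replacements, reduced)

# Undo the replacements, last one first, since later ones may refer to earlier symbols.
restored = reduced
for sym, sub in reversed(replacements):
    restored = restored.xreplace({sym: sub})

assert restored == expr
```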

@moorepants
Member

moorepants commented Mar 25, 2024

I thought this would work, but it doesn't:

In [1]: import sympy as sm

In [2]: t = sm.symbols("t")

In [3]: x = sm.Function("x")(t)

In [4]: xd = x.diff(t)

In [5]: sm.lambdify((xd, x), xd + x, cse=True, dummify=True)(1, 1)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[5], line 1
----> 1 sm.lambdify((xd, x), xd + x, cse=True, dummify=True)(1, 1) 

File <lambdifygenerated-1>:5, in _lambdifygenerated(_Dummy_36, _Dummy_37)
      1 def _lambdifygenerated(_Dummy_36, _Dummy_37):
      2     x0 = _Dummy_37
      3     return (  # Not supported in Python with SciPy and NumPy:
      4   # Derivative
----> 5 x0 + Derivative(x0, t))

NameError: name 'Derivative' is not defined

@moorepants
Member

cse is called on the expressions first, and then the cse results are passed to the printer, which runs a preprocessor on them that does the dummifying. I'm guessing those two operations should be swapped for things to work correctly.
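Until the order is changed inside lambdify, the swap can be done by hand: replace the derivative and the function with plain symbols before calling lambdify, so cse never sees `x(t)` inside a `Derivative`. The stand-in names `x_s` and `xd_s` are choices made for this sketch only.

```python
import sympy as sm

t = sm.symbols("t")
x = sm.Function("x")(t)
xd = x.diff(t)

# Plain stand-in symbols (names chosen for this sketch only).
x_s, xd_s = sm.symbols("x_s xd_s")

expr = xd + x
# xreplace matches top-down, so Derivative(x(t), t) is replaced as a whole
# before the x(t) inside it is ever visited.
expr_s = expr.xreplace({xd: xd_s, x: x_s})

f = sm.lambdify((xd_s, x_s), expr_s, cse=True)
print(f(1, 1))  # 2
```

A single `xreplace` call is enough here because the derivative node is matched before its arguments are visited; two separate calls would need the derivative replaced first.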

@mleila1312

Hello,

I've looked into the code a bit, and I think the issue is specific to using xd = x.diff(t). This case is not tested (see test_lambdify.py), and with the current implementation the Dummy_XY symbols are not placed correctly in the expression computed by cse: when x0 is replaced with a dummy, you would need to search and replace recursively inside Derivative(x0, t) as well (although a recursive search might not be needed if we could implement a simple way to link the two).

I'll look into it more. Indeed, if we call cse after the printer it works perfectly, but then the cse results are not used.

@moorepants
Member

The dummy replacement works recursively already. Maybe just look into having the preprocessor apply the dummy subs to all of the cse output.

@moorepants
Member

Actually, judging from this code, the preprocessor is applied to all outputs of cse:

https://github.com/sympy/sympy/blob/master/sympy/utilities/lambdify.py#L1133-L1140

@moorepants
Member

moorepants commented Apr 4, 2024

It looks like in the CSE process, which happens first, the argument of the derivative gets swapped out:

ipdb> expr
[x0 + Derivative(x0, t), x(t)]

This would then cause the dummy replacement to be missed because it is looking for Derivative(x, t) (probably). So making the dummy replacement happen first might be better.

In [3]: x + xd
Out[3]: x(t) + Derivative(x(t), t)

In [4]: sm.cse(x + xd)
Out[4]: ([(x0, x(t))], [x0 + Derivative(x0, t)])
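The effect described above can be reproduced with `xreplace` alone. In this sketch the symbols `a` and `b` stand in for the dummies lambdify would create; the point is only that the order of the two replacements decides whether the derivative is still found.

```python
import sympy as sm

t = sm.symbols("t")
x = sm.Function("x")(t)
xd = x.diff(t)
a, b = sm.symbols("a b")  # stand-ins for lambdify's dummies

expr = x + xd

# Replacing x(t) first also rewrites the argument inside the derivative ...
x_first = expr.xreplace({x: a})
# ... so a later lookup for Derivative(x(t), t) finds nothing to replace.
assert x_first.xreplace({xd: b}) == x_first

# Replacing the derivative first keeps it intact as a single unit.
xd_first = expr.xreplace({xd: b}).xreplace({x: a})
assert xd_first == a + b
```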

@moorepants
Member

One option would be a flag to cse() to have it skip Derivative(x, t) type terms.

@moorepants
Member

moorepants commented Apr 4, 2024

Maybe a pre and post optimization could be used for cse, something like:

In [27]: def pre(expr): return expr.xreplace({xd: sm.Symbol('a')})

In [28]: def post(expr): return expr.xreplace({sm.Symbol('a'): xd})

In [29]: sm.cse(x + xd, optimizations=[(pre, post)])
Out[29]: ([], [x(t) + Derivative(x(t), t)])

In [31]: sm.cse(x + x*x + xd, optimizations=[(pre, post)])
Out[31]: ([(x0, x(t))], [x0**2 + x0 + Derivative(x(t), t)])
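One caveat with the pair above: `Symbol('a')` is an ordinary symbol, so if the expression already contains an `a`, `post` turns it into a derivative as well. A `Dummy` placeholder avoids that, since it never compares equal to a symbol the user wrote. The helper names in this sketch are illustrative.

```python
import sympy as sm

t = sm.symbols("t")
x = sm.Function("x")(t)
xd = x.diff(t)
a = sm.Symbol("a")
expr = x + xd + a  # the user's expression already contains a


def pre(e):
    return e.xreplace({xd: a})


def post(e):
    return e.xreplace({a: xd})


_, (clashing,) = sm.cse(expr, optimizations=[(pre, post)])
# pre turned xd into a, so a was counted twice; post then maps both back to xd.
assert clashing != expr

d = sm.Dummy("d")


def pre_dummy(e):
    return e.xreplace({xd: d})


def post_dummy(e):
    return e.xreplace({d: xd})


_, (safe,) = sm.cse(expr, optimizations=[(pre_dummy, post_dummy)])
assert safe == expr
```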

@mleila1312

mleila1312 commented Apr 4, 2024

This would then cause the dummy replacement to be missed because it is looking for Derivative(x, t) (probably).

That's what I ended up thinking too.

One option would be a flag to cse() to have it skip Derivative(x, t) type terms.

A flag could work. I was also thinking that we could treat the derivative (Derivative(x0, t)) like a replacement, but I think the flag would be more practical.

Maybe a pre and post optimization could be used for cse

It would work too, but what are the advantages of this compared to the flag option? (With this option we would go through the expression twice and still ignore the derivatives.)

@moorepants
Member

It would work too, but what are the advantages of this compared to the flag option? (With this option we would go through the expression twice and still ignore the derivatives.)

It is a flag that is already present, so we would not need to implement a new kwarg in cse.

@mleila1312

mleila1312 commented Apr 4, 2024

It is a flag that is already present, so we would not need to implement a new kwarg in cse.

But aren't there some cases where we want the arguments of the derivative changed accordingly as well? (Instead of Derivative(x, t), we would want Derivative(x0, t).)

I've been trying to implement the change (not with the pre/post method but with a flag indicating that cse is called inside lambdify), but I'm running into a difficulty and I don't know how to resolve it.

I've added an optional parameter to cse that indicates it is being called inside lambdify. When I made it a required parameter (placed right after expr), the function worked perfectly and the bug was resolved. However, as an optional parameter, the value of flag_lambdify (the name of the flag) isn't modified, and I've noticed the same happens with the optional parameter list.

To see this problem, add the new flag variable, add a print(list, flag_lambdify) right under the header of the cse function, and run the lines of code above that produced the bug. The print is executed twice: the first time with the right values, the second time with the default values.

If you want, you can also see the modified files in my sympy repo (although I'm not sure I left the print(list, flag_lambdify) in when I pushed).

Also, I haven't made flag_lambdify a required parameter because I think cse is called from other files, but if it isn't, that would solve the problem (though not the problem with the list variable).
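One possible explanation for the double print, offered as an assumption to check against `sympy/simplify/cse_main.py`: lambdify calls cse with `list=False`, and that path hands the work to a helper that calls `cse` a second time, forwarding only the keywords it was written to know about. Any new keyword, and `list` itself, would then fall back to its default in the nested call. The pure-Python model below uses hypothetical names (`make_cse_like`, `_helper`) to show the effect and the fix of forwarding the flag.

```python
def make_cse_like(forward_flag):
    """Build a toy cse-like function that re-enters itself through a helper."""
    calls = []  # (list, flag_lambdify) seen on each call

    def cse_like(expr, optimizations=None, list=True, flag_lambdify=False):
        calls.append((list, flag_lambdify))
        if not list:
            # The helper only receives the keywords passed here explicitly.
            extra = {"flag_lambdify": flag_lambdify} if forward_flag else {}
            return _helper(expr, optimizations=optimizations, **extra)
        return expr

    def _helper(expr, **kwargs):
        # Nested call: anything not in kwargs takes its default value.
        return cse_like([expr], **kwargs)

    return cse_like, calls


broken, broken_calls = make_cse_like(forward_flag=False)
broken("x + y", list=False, flag_lambdify=True)
print(broken_calls)  # [(False, True), (True, False)]

fixed, fixed_calls = make_cse_like(forward_flag=True)
fixed("x + y", list=False, flag_lambdify=True)
print(fixed_calls)  # [(False, True), (True, True)]
```

In this model the nested call also sees `list=True`, which would match the observation that `list` reverts to its default as well.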

4 participants
@moorepants @tjstienstra @mleila1312 and others