Skip to content

Commit

Permalink
Merge pull request #23976 from oscarbenjamin/pr_lambdify_cse_111
Browse files Browse the repository at this point in the history
fix(lambdify): make cse=True work with non-lists (1.11 branch)
  • Loading branch information
oscarbenjamin committed Aug 27, 2022
2 parents 26f7bdb + b63902e commit 41d9095
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Expand Up @@ -16,7 +16,7 @@ on:
- '1.11'
env:
release_branch: '1.11'
release_version: '1.11'
release_version: '1.11.1'
previous_version: '1.10.1'

jobs:
Expand Down
6 changes: 4 additions & 2 deletions .github/workflows/runtests.yml
Expand Up @@ -197,9 +197,11 @@ jobs:
- if: ${{ contains(matrix.python-version, '3.11') }}
run: pip install git+https://github.com/matplotlib/matplotlib@main

# SciPy 1.9.0 fails to compile under PyPy
# SciPy 1.9.0 and 1.9.1 fail to compile under PyPy:
#
# https://github.com/FFY00/meson-python/issues/121
- if: ${{ matrix.python-version == 'pypy-3.8' }}
run: pip install scipy!=1.9.0
run: pip install 'scipy<1.9.0'

# dependencies to install in all Python versions:
- run: pip install mpmath numpy numexpr matplotlib ipython cython scipy \
Expand Down
2 changes: 1 addition & 1 deletion sympy/release.py
@@ -1 +1 @@
__version__ = "1.11"
__version__ = "1.11.1"
11 changes: 2 additions & 9 deletions sympy/utilities/lambdify.py
Expand Up @@ -1113,16 +1113,9 @@ def doprint(self, funcname, args, expr, *, cses=()):

if cses:
subvars, subexprs = zip(*cses)
try:
exprs = expr + list(subexprs)
except TypeError:
try:
exprs = expr + tuple(subexprs)
except TypeError:
expr = [expr]
exprs = expr + list(subexprs)
exprs = [expr] + list(subexprs)
argstrs, exprs = self._preprocess(args, exprs)
expr, subexprs = exprs[:len(expr)], exprs[len(expr):]
expr, subexprs = exprs[0], exprs[1:]
cses = zip(subvars, subexprs)
else:
argstrs, expr = self._preprocess(args, expr)
Expand Down
4 changes: 3 additions & 1 deletion sympy/utilities/tests/test_lambdify.py
Expand Up @@ -1637,11 +1637,13 @@ def test_deprecated_set():
with warns_deprecated_sympy():
lambdify({x, y}, x + y)


def test_23536_lambdify_cse_dummy():

f = Function('x')(y)
g = Function('w')(y)
expr = z + (f**4 + g**5)*(f**3 + (g*f)**3)
expr = expr.expand()
eval_expr = lambdify(((f, g), z), expr, cse=True)
eval_expr((1.0, 2.0), 3.0)
ans = eval_expr((1.0, 2.0), 3.0) # shouldn't raise NameError
assert ans == 300.0 # not a list and value is 300

0 comments on commit 41d9095

Please sign in to comment.