From a3042a0dbadbc9033298e308d0c6d2689c351cce Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 31 Mar 2023 18:51:04 -0600 Subject: [PATCH] Avoid exp rewrites that evaluate back to Pow in solve() This can lead to infinite recursion. Fixes #24368 --- sympy/solvers/solvers.py | 2 ++ sympy/solvers/tests/test_solvers.py | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py index fcc7e13c3179..3a86454b242f 100644 --- a/sympy/solvers/solvers.py +++ b/sympy/solvers/solvers.py @@ -42,6 +42,7 @@ separatevars) from sympy.simplify.sqrtdenest import sqrt_depth from sympy.simplify.fu import TR1, TR2i +from sympy.strategies.rl import rebuild from sympy.matrices.common import NonInvertibleMatrixError from sympy.matrices import Matrix, zeros from sympy.polys import roots, cancel, factor, Poly @@ -2770,6 +2771,7 @@ def equal(expr1, expr2): return _vsolve(lhs.args[0] - rhs*exp(rhs), sym, **flags) rewrite = lhs.rewrite(exp) + rewrite = rebuild(rewrite) # avoid rewrites involving evaluate=False if rewrite != lhs: return _vsolve(rewrite - rhs, sym, **flags) except NotImplementedError: diff --git a/sympy/solvers/tests/test_solvers.py b/sympy/solvers/tests/test_solvers.py index c13aa051c947..1c4b2fca8fa7 100644 --- a/sympy/solvers/tests/test_solvers.py +++ b/sympy/solvers/tests/test_solvers.py @@ -2,6 +2,7 @@ from sympy.core.add import Add from sympy.core.containers import Tuple from sympy.core.function import (Derivative, Function, diff) +from sympy.core.mod import Mod from sympy.core.mul import Mul from sympy.core import (GoldenRatio, TribonacciConstant) from sympy.core.numbers import (E, Float, I, Rational, oo, pi) @@ -13,6 +14,7 @@ from sympy.functions.elementary.complexes import (Abs, arg, conjugate, im, re) from sympy.functions.elementary.exponential import (LambertW, exp, log) from sympy.functions.elementary.hyperbolic import (atanh, cosh, sinh, tanh) +from sympy.functions.elementary.integers import floor from sympy.functions.elementary.miscellaneous import (cbrt, root, sqrt) from sympy.functions.elementary.piecewise import Piecewise from sympy.functions.elementary.trigonometric import (acos, asin, atan, atan2, cos, sec, sin, tan) @@ -2646,3 +2648,11 @@ def test_solve_undetermined_coeffs_issue_23927(): phi: 2*atan((A + sqrt(A**2 + B**2))/B), r: (A**2 + A*sqrt(A**2 + B**2) + B**2)/(A + sqrt(A**2 + B**2))/-1 }] + +def test_issue_24368(): + # Ideally these would produce a solution, but for now just check that they + # don't fail with a RuntimeError + raises(NotImplementedError, lambda: solve(Mod(x**2, 49), x)) + s2 = Symbol('s2', integer=True, positive=True) + f = floor(s2/2 - S(1)/2) + raises(NotImplementedError, lambda: solve((Mod(f**2/(f + 1) + 2*f/(f + 1) + 1/(f + 1), 1))*f + Mod(f**2/(f + 1) + 2*f/(f + 1) + 1/(f + 1), 1), s2))