Skip to content

Commit

Permalink
Merge pull request #26583 from haru-44/classify_pde
Browse files Browse the repository at this point in the history
Refactoring of `classify_pde`
  • Loading branch information
sylee957 committed May 9, 2024
2 parents 2658461 + d2dbc51 commit 8f9d043
Showing 1 changed file with 17 additions and 32 deletions.
49 changes: 17 additions & 32 deletions sympy/solvers/pde.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,7 @@ def classify_pde(eq, func=None, dict=False, *, prep=True, **kwargs):
if dict:
matching_hints["default"] = None
return matching_hints
else:
return ()
return ()

eq = expand(eq)

Expand All @@ -313,31 +312,20 @@ def classify_pde(eq, func=None, dict=False, *, prep=True, **kwargs):
n = Wild('n', exclude = [x, y])
# Try removing the smallest power of f(x,y)
# from the highest partial derivatives of f(x,y)
reduced_eq = None
reduced_eq = eq
if eq.is_Add:
var = set(combinations_with_replacement((x,y), order))
dummyvar = var.copy()
power = None
for i in var:
coeff = eq.coeff(f(x,y).diff(*i))
if coeff != 1:
match = coeff.match(a*f(x,y)**n)
if match and match[a]:
power = match[n]
dummyvar.remove(i)
break
dummyvar.remove(i)
for i in dummyvar:
for i in set(combinations_with_replacement((x,y), order)):
coeff = eq.coeff(f(x,y).diff(*i))
if coeff != 1:
match = coeff.match(a*f(x,y)**n)
if match and match[a] and match[n] < power:
if coeff == 1:
continue
match = coeff.match(a*f(x,y)**n)
if match and match[a]:
if power is None or match[n] < power:
power = match[n]
if power:
den = f(x,y)**power
reduced_eq = Add(*[arg/den for arg in eq.args])
if not reduced_eq:
reduced_eq = eq

if order == 1:
reduced_eq = collect(reduced_eq, f(x, y))
Expand All @@ -348,14 +336,12 @@ def classify_pde(eq, func=None, dict=False, *, prep=True, **kwargs):
## equation with constant coefficients
r.update({'b': b, 'c': c, 'd': d})
matching_hints["1st_linear_constant_coeff_homogeneous"] = r
else:
if r[b]**2 + r[c]**2 != 0:
## Linear first-order general partial-differential
## equation with constant coefficients
r.update({'b': b, 'c': c, 'd': d, 'e': e})
matching_hints["1st_linear_constant_coeff"] = r
matching_hints[
"1st_linear_constant_coeff_Integral"] = r
elif r[b]**2 + r[c]**2 != 0:
## Linear first-order general partial-differential
## equation with constant coefficients
r.update({'b': b, 'c': c, 'd': d, 'e': e})
matching_hints["1st_linear_constant_coeff"] = r
matching_hints["1st_linear_constant_coeff_Integral"] = r

else:
b = Wild('b', exclude=[f(x, y), fx, fy])
Expand All @@ -367,20 +353,19 @@ def classify_pde(eq, func=None, dict=False, *, prep=True, **kwargs):
matching_hints["1st_linear_variable_coeff"] = r

# Order keys based on allhints.
retlist = [i for i in allhints if i in matching_hints]
rettuple = tuple(i for i in allhints if i in matching_hints)

if dict:
# Dictionaries are ordered arbitrarily, so make note of which
# hint would come first for pdsolve(). Use an ordered dict in Py 3.
matching_hints["default"] = None
matching_hints["ordered_hints"] = tuple(retlist)
matching_hints["ordered_hints"] = rettuple
for i in allhints:
if i in matching_hints:
matching_hints["default"] = i
break
return matching_hints
else:
return tuple(retlist)
return rettuple


def checkpdesol(pde, sol, func=None, solve_for_func=True):
Expand Down

0 comments on commit 8f9d043

Please sign in to comment.