Skip to content

Commit

Permalink
Merge pull request #26450 from Emilius12/emile
Browse files Browse the repository at this point in the history
Add convexity check for multivariate functions
  • Loading branch information
smichr committed May 15, 2024
2 parents 521193b + 9eb139e commit 88b3aed
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 41 deletions.
1 change: 1 addition & 0 deletions .mailmap
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,7 @@ Elias Basler <e.e.basler@protonmail.com>
Elisha Hollander <just4now666666@gmail.com> donno2048 <just4now666666@gmail.com>
Elliot Marshall <Marshall2389@gmail.com> <marshall2389@gmail.com>
Elrond der Elbenfuerst <elrond+sympy.org@samba-tng.org>
Emile Fourcini <emile.fourcin1@gmail.com> Emile <emile.fourcin1@gmail.com>
Emma Hogan <ehogan@gemini.edu>
Enric Florit <efz1005@gmail.com>
Eric Demer <demer@mailbox.org>
Expand Down
2 changes: 1 addition & 1 deletion AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -1266,4 +1266,4 @@ Augusto Borges <borges.augustoar@gmail.com>
Han Wei Ang <ang.h.w@u.nus.edu>
Pablo <48098178+PabloRuizCuevas@users.noreply.github.com>
Congxu Yang <u7189828@anu.edu.au>
Saicharan <62512681+saicharan2804@users.noreply.github.com>
Saicharan <62512681+saicharan2804@users.noreply.github.com>
51 changes: 18 additions & 33 deletions sympy/calculus/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from sympy.core.numbers import (E, I, Rational, oo, pi)
from sympy.core.relational import Eq
from sympy.core.singleton import S
from sympy.core.symbol import (Dummy, Symbol, symbols)
from sympy.core.symbol import (Dummy, Symbol)
from sympy.functions.elementary.complexes import (Abs, re)
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.integers import frac
Expand All @@ -23,12 +23,11 @@
from sympy.sets.fancysets import ImageSet
from sympy.sets.conditionset import ConditionSet
from sympy.testing.pytest import XFAIL, raises, _both_exp_pow, slow
from sympy.abc import x
from sympy.abc import x, y

a = Symbol('a', real=True)

def test_function_range():
x, y, a, b = symbols('x y a b')
assert function_range(sin(x), x, Interval(-pi/2, pi/2)
) == Interval(-1, 1)
assert function_range(sin(x), x, Interval(0, pi)
Expand Down Expand Up @@ -73,7 +72,6 @@ def test_function_range1():


def test_continuous_domain():
x = Symbol('x')
assert continuous_domain(sin(x), x, Interval(0, 2*pi)) == Interval(0, 2*pi)
assert continuous_domain(tan(x), x, Interval(0, 2*pi)) == \
Union(Interval(0, pi/2, False, True), Interval(pi/2, pi*Rational(3, 2), True, True),
Expand Down Expand Up @@ -186,10 +184,6 @@ def test_not_empty_in():

@_both_exp_pow
def test_periodicity():
x = Symbol('x')
y = Symbol('y')
z = Symbol('z', real=True)

assert periodicity(sin(2*x), x) == pi
assert periodicity((-2)*tan(4*x), x) == pi/4
assert periodicity(sin(x)**2, x) == 2*pi
Expand Down Expand Up @@ -221,14 +215,14 @@ def test_periodicity():

assert periodicity(exp(x), x) is None
assert periodicity(exp(I*x), x) == 2*pi
assert periodicity(exp(I*z), z) == 2*pi
assert periodicity(exp(z), z) is None
assert periodicity(exp(log(sin(z) + I*cos(2*z)), evaluate=False), z) == 2*pi
assert periodicity(exp(log(sin(2*z) + I*cos(z)), evaluate=False), z) == 2*pi
assert periodicity(exp(sin(z)), z) == 2*pi
assert periodicity(exp(2*I*z), z) == pi
assert periodicity(exp(z + I*sin(z)), z) is None
assert periodicity(exp(cos(z/2) + sin(z)), z) == 4*pi
assert periodicity(exp(I*a), a) == 2*pi
assert periodicity(exp(a), a) is None
assert periodicity(exp(log(sin(a) + I*cos(2*a)), evaluate=False), a) == 2*pi
assert periodicity(exp(log(sin(2*a) + I*cos(a)), evaluate=False), a) == 2*pi
assert periodicity(exp(sin(a)), a) == 2*pi
assert periodicity(exp(2*I*a), a) == pi
assert periodicity(exp(a + I*sin(a)), a) is None
assert periodicity(exp(cos(a/2) + sin(a)), a) == 4*pi
assert periodicity(log(x), x) is None
assert periodicity(exp(x)**sin(x), x) is None
assert periodicity(sin(x)**y, y) is None
Expand Down Expand Up @@ -261,9 +255,6 @@ def test_periodicity():


def test_periodicity_check():
x = Symbol('x')
y = Symbol('y')

assert periodicity(tan(x), x, check=True) == pi
assert periodicity(sin(x) + cos(x), x, check=True) == 2*pi
assert periodicity(sec(x), x) == 2*pi
Expand All @@ -285,13 +276,13 @@ def test_is_convex():
assert is_convex(x**2, x, domain=Interval(0, oo)) == True
assert is_convex(1/x**3, x, domain=Interval.Lopen(0, oo)) == True
assert is_convex(-1/x**3, x, domain=Interval.Ropen(-oo, 0)) == True
assert is_convex(log(x), x) == False
raises(NotImplementedError, lambda: is_convex(log(x), x, a))
assert is_convex(log(x) ,x) == False
assert is_convex(x**2+y**2, x, y) == True
assert is_convex(cos(x) + cos(y), x) == False
assert is_convex(8*x**2 - 2*y**2, x, y) == False


def test_stationary_points():
x, y = symbols('x y')

assert stationary_points(sin(x), x, Interval(-pi/2, pi/2)
) == {-pi/2, pi/2}
assert stationary_points(sin(x), x, Interval.Ropen(0, pi/4)
Expand Down Expand Up @@ -324,7 +315,6 @@ def test_stationary_points():


def test_maximum():
x, y = symbols('x y')
assert maximum(sin(x), x) is S.One
assert maximum(sin(x), x, Interval(0, 1)) == sin(1)
assert maximum(tan(x), x) is oo
Expand Down Expand Up @@ -357,8 +347,6 @@ def test_maximum():


def test_minimum():
x, y = symbols('x y')

assert minimum(sin(x), x) is S.NegativeOne
assert minimum(sin(x), x, Interval(1, 4)) == sin(4)
assert minimum(tan(x), x) is -oo
Expand Down Expand Up @@ -386,22 +374,19 @@ def test_minimum():


def test_issue_19869():
t = symbols('t')
assert (maximum(sqrt(3)*(t - 1)/(3*sqrt(t**2 + 1)), t)
assert (maximum(sqrt(3)*(x - 1)/(3*sqrt(x**2 + 1)), x)
) == sqrt(3)/3


def test_issue_16469():
x = Symbol("x", real=True)
f = abs(x)
assert function_range(f, x, S.Reals) == Interval(0, oo, False, True)
f = abs(a)
assert function_range(f, a, S.Reals) == Interval(0, oo, False, True)


@_both_exp_pow
def test_issue_18747():
assert periodicity(exp(pi*I*(x/4+S.Half/2)), x) == 8
assert periodicity(exp(pi*I*(x/4 + S.Half/2)), x) == 8


def test_issue_25942():
x = Symbol("x")
assert (acos(x) > pi/3).as_set() == Interval.Ropen(-1, S(1)/2)
10 changes: 3 additions & 7 deletions sympy/calculus/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sympy.sets.conditionset import ConditionSet
from sympy.utilities import filldedent
from sympy.utilities.iterables import iterable
from sympy.matrices.dense import hessian


def continuous_domain(f, symbol, domain):
Expand Down Expand Up @@ -745,18 +746,13 @@ def is_convex(f, *syms, domain=S.Reals):
.. [5] https://en.wikipedia.org/wiki/Concave_function
"""

if len(syms) > 1:
raise NotImplementedError(
"The check for the convexity of multivariate functions is not implemented yet.")

if len(syms) > 1 :
return hessian(f, syms).is_positive_semidefinite
from sympy.solvers.inequalities import solve_univariate_inequality

f = _sympify(f)
var = syms[0]
if any(s in domain for s in singularities(f, var)):
return False

condition = f.diff(var, 2) < 0
if solve_univariate_inequality(condition, var, False, domain):
return False
Expand Down

0 comments on commit 88b3aed

Please sign in to comment.