Skip to content

Commit

Permalink
get new commits in sympy main
Browse files Browse the repository at this point in the history
  • Loading branch information
mleila1312 committed Apr 10, 2024
2 parents f284375 + e89ee93 commit c17ad92
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 10 deletions.
23 changes: 17 additions & 6 deletions sympy/functions/elementary/exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sympy.core.numbers import Integer, Rational, pi, I
from sympy.core.parameters import global_parameters
from sympy.core.power import Pow
from sympy.core.relational import Ge
from sympy.core.singleton import S
from sympy.core.symbol import Wild, Dummy
from sympy.core.sympify import sympify
Expand Down Expand Up @@ -1161,20 +1162,30 @@ def eval(cls, x, k=None):
return S.Zero
if x is S.Exp1:
return S.One
if x == -1/S.Exp1:
return S.NegativeOne
w = Wild('w')
# W(x*log(x)) = log(x) for x >= 1/e
# e.g., W(-1/e) = -1, W(2*log(2)) = log(2)
result = x.match(w*log(w))
if result is not None and Ge(result[w]*S.Exp1, S.One) is S.true:
return log(result[w])
if x == -log(2)/2:
return -log(2)
if x == 2*log(2):
return log(2)
# W(x**(x+1)*log(x)) = x*log(x) for x > 0
# e.g., W(81*log(3)) = 3*log(3)
result = x.match(w**(w+1)*log(w))
if result is not None and result[w].is_positive is True:
return result[w]*log(result[w])
# W(e**(1/n)/n) = 1/n
# e.g., W(sqrt(e)/2) = 1/2
result = x.match(S.Exp1**(1/w)/w)
if result is not None:
return 1 / result[w]
if x == -pi/2:
return I*pi/2
if x == exp(1 + S.Exp1):
return S.Exp1
if x is S.Infinity:
return S.Infinity
if x.is_zero:
return S.Zero

if fuzzy_not(k.is_zero):
if x.is_zero:
Expand Down
4 changes: 4 additions & 0 deletions sympy/functions/elementary/tests/test_exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,10 @@ def test_lambertw():
assert LambertW(0) == 0
assert LambertW(E) == 1
assert LambertW(-1/E) == -1
assert LambertW(100*log(100)) == log(100)
assert LambertW(-log(2)/2) == -log(2)
assert LambertW(81*log(3)) == 3*log(3)
assert LambertW(sqrt(E)/2) == S.Half
assert LambertW(oo) is oo
assert LambertW(0, 1) is -oo
assert LambertW(0, 42) is -oo
Expand All @@ -627,6 +630,7 @@ def test_lambertw():
assert LambertW(2, evaluate=False).is_real
p = Symbol('p', positive=True)
assert LambertW(p, evaluate=False).is_real
assert LambertW(p**(p+1)*log(p)) == p*log(p)
assert LambertW(p - 1, evaluate=False).is_real is None
assert LambertW(-p - 2/S.Exp1, evaluate=False).is_real is False
assert LambertW(S.Half, -1, evaluate=False).is_real is False
Expand Down
25 changes: 21 additions & 4 deletions sympy/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from sympy.combinatorics.tensor_can import get_symmetric_group_sgs, \
bsgs_direct_product, canonicalize, riemann_bsgs
from sympy.core import Basic, Expr, sympify, Add, Mul, S
from sympy.core.cache import clear_cache
from sympy.core.containers import Tuple, Dict
from sympy.core.sorting import default_sort_key
from sympy.core.symbol import Symbol, symbols
Expand Down Expand Up @@ -902,6 +903,9 @@ def set_comm(self, i, j, c):
if c not in (0, 1, None):
raise ValueError('`c` can assume only the values 0, 1 or None')

i = sympify(i)
j = sympify(j)

if i not in self._comm_symbols2i:
n = len(self._comm)
self._comm.append({})
Expand All @@ -921,6 +925,14 @@ def set_comm(self, i, j, c):
self._comm[ni][nj] = c
self._comm[nj][ni] = c

"""
Cached sympy functions (e.g. expand) may have cached the results of
expressions involving tensors, but those results may not be valid after
changing the commutation properties. To stay on the safe side, we clear
the cache of all functions.
"""
clear_cache()

def set_comms(self, *args):
"""
Set the commutation group numbers ``c`` for symbols ``i, j``.
Expand Down Expand Up @@ -1805,8 +1817,7 @@ def __new__(cls, name, index_types, symmetry=None, comm=0):
else:
assert symmetry.rank == len(index_types)

obj = Basic.__new__(cls, name_symbol, Tuple(*index_types), symmetry)
obj.comm = TensorManager.comm_symbols2i(comm)
obj = Basic.__new__(cls, name_symbol, Tuple(*index_types), symmetry, sympify(comm))
return obj

@property
Expand All @@ -1821,6 +1832,10 @@ def index_types(self):
def symmetry(self):
return self.args[2]

@property
def comm(self):
return TensorManager.comm_symbols2i(self.args[3])

@property
def rank(self):
return len(self.index_types)
Expand Down Expand Up @@ -4262,11 +4277,13 @@ def __new__(cls, name, index_types=None, symmetry=None, comm=0, unordered_indic
raise NotImplementedError("Wild matching based on symmetry is not implemented.")

obj = Basic.__new__(cls, name_symbol, Tuple(*index_types), sympify(symmetry), sympify(comm), sympify(unordered_indices))
obj.comm = TensorManager.comm_symbols2i(comm)
obj.unordered_indices = unordered_indices

return obj

@property
def unordered_indices(self):
return self.args[4]

def __call__(self, *indices, **kwargs):
tensor = WildTensor(self, indices, **kwargs)
return tensor.doit()
Expand Down

0 comments on commit c17ad92

Please sign in to comment.