Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Crashes with imaginary numbers #6

Open
thomasaarholt opened this issue Aug 17, 2022 · 2 comments
Open

Crashes with imaginary numbers #6

thomasaarholt opened this issue Aug 17, 2022 · 2 comments

Comments

@thomasaarholt
Copy link

I tried converting my complex-valued SymPy expression to JAX and got the following error.

I wrote a minimal working example. `I` is SymPy's imaginary unit (√−1), and `1j` is Python's built-in equivalent. Using either one produces the same error.

from sympy import symbols, I
import sympy2jax

x = symbols("x")

expr = x*I # or x*1j

sympy2jax.SymbolicModule(expr)
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:213, in _sympy_to_node(expr, memodict, func_lookup, make_array)
    212 try:
--> 213     return memodict[expr]
    214 except KeyError:

KeyError: I*x

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:213, in _sympy_to_node(expr, memodict, func_lookup, make_array)
    212 try:
--> 213     return memodict[expr]
    214 except KeyError:

KeyError: I

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:180, in _Func.__init__(self, expr, memodict, func_lookup, make_array)
    179 try:
--> 180     self._func = func_lookup[expr.func]
    181 except KeyError as e:

KeyError: <class 'sympy.core.numbers.ImaginaryUnit'>

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
/Users/thomas/Documents/vilde.ipynb Cell 6 in <cell line: 8>()
      [4](vscode-notebook-cell:/Users/thomas/Documents/vilde.ipynb#Y103sZmlsZQ%3D%3D?line=3) x = symbols("x")
      [6](vscode-notebook-cell:/Users/thomas/Documents/vilde.ipynb#Y103sZmlsZQ%3D%3D?line=5) expr = x*I # or x*1j
----> [8](vscode-notebook-cell:/Users/thomas/Documents/vilde.ipynb#Y103sZmlsZQ%3D%3D?line=7) sympy2jax.SymbolicModule(expr)

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/equinox/module.py:131, in _ModuleMeta.__call__(cls, *args, **kwargs)
    129 object.__setattr__(self, "__class__", initable_cls)
    130 try:
--> 131     cls.__init__(self, *args, **kwargs)
    132 finally:
    133     object.__setattr__(self, "__class__", cls)

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:257, in SymbolicModule.__init__(self, expressions, extra_funcs, make_array, **kwargs)
    250     self.has_extra_funcs = True
    251 _convert = ft.partial(
    252     _sympy_to_node,
    253     memodict=dict(),
    254     func_lookup=lookup,
    255     make_array=make_array,
    256 )
--> 257 self.nodes = jax.tree_map(_convert, expressions)

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/jax/_src/tree_util.py:205, in tree_map(f, tree, is_leaf, *rest)
    203 leaves, treedef = tree_flatten(tree, is_leaf)
    204 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 205 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/jax/_src/tree_util.py:205, in <genexpr>(.0)
    203 leaves, treedef = tree_flatten(tree, is_leaf)
    204 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 205 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:224, in _sympy_to_node(expr, memodict, func_lookup, make_array)
    222     out = _Rational(expr, make_array)
    223 else:
--> 224     out = _Func(expr, memodict, func_lookup, make_array)
    225 memodict[expr] = out
    226 return out

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/equinox/module.py:131, in _ModuleMeta.__call__(cls, *args, **kwargs)
    129 object.__setattr__(self, "__class__", initable_cls)
    130 try:
--> 131     cls.__init__(self, *args, **kwargs)
    132 finally:
    133     object.__setattr__(self, "__class__", cls)

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:183, in _Func.__init__(self, expr, memodict, func_lookup, make_array)
    181 except KeyError as e:
    182     raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
--> 183 self._args = [
    184     _sympy_to_node(arg, memodict, func_lookup, make_array) for arg in expr.args
    185 ]

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:184, in <listcomp>(.0)
    181 except KeyError as e:
    182     raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
    183 self._args = [
--> 184     _sympy_to_node(arg, memodict, func_lookup, make_array) for arg in expr.args
    185 ]

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:224, in _sympy_to_node(expr, memodict, func_lookup, make_array)
    222     out = _Rational(expr, make_array)
    223 else:
--> 224     out = _Func(expr, memodict, func_lookup, make_array)
    225 memodict[expr] = out
    226 return out

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/equinox/module.py:131, in _ModuleMeta.__call__(cls, *args, **kwargs)
    129 object.__setattr__(self, "__class__", initable_cls)
    130 try:
--> 131     cls.__init__(self, *args, **kwargs)
    132 finally:
    133     object.__setattr__(self, "__class__", cls)

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:182, in _Func.__init__(self, expr, memodict, func_lookup, make_array)
    180     self._func = func_lookup[expr.func]
    181 except KeyError as e:
--> 182     raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
    183 self._args = [
    184     _sympy_to_node(arg, memodict, func_lookup, make_array) for arg in expr.args
    185 ]

KeyError: "Unsupported Sympy type <class 'sympy.core.numbers.ImaginaryUnit'>"
@thomasaarholt
Copy link
Author

I should add: it's absolutely fine if this isn't fixed. I only tried sympy2jax out of curiosity, since I had already written my (rather large) expression in SymPy.

@patrick-kidger
Copy link
Owner

It should be pretty straightforward to add support. I'd be happy to accept a pull request for this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants