module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class' #269

Open
siamak-attarian opened this issue Jun 29, 2023 · 0 comments

Hi

I'm working on Pop!_OS 22.04. I first installed JAX through Anaconda (jax 0.4.13, jaxlib 0.4.13) and then installed jax-md. When I import jax-md, I get the following error:

```
Traceback (most recent call last):
  File "", line 1, in <module>
  File "/home/siamak/Downloads/jax-md/jax_md/__init__.py", line 16, in <module>
    from jax_md import energy
  File "/home/siamak/Downloads/jax-md/jax_md/energy.py", line 30, in <module>
    from jax_md import space, smap, partition, nn, quantity, interpolate, util
  File "/home/siamak/Downloads/jax-md/jax_md/nn.py", line 36, in <module>
    from ._nn import behler_parrinello
  File "/home/siamak/Downloads/jax-md/jax_md/_nn/__init__.py", line 16, in <module>
    from . import nequip
  File "/home/siamak/Downloads/jax-md/jax_md/_nn/nequip.py", line 20, in <module>
    import e3nn_jax as e3nn
  File "/home/siamak/anaconda3/lib/python3.10/site-packages/e3nn_jax/__init__.py", line 111, in <module>
    from e3nn_jax import flax, haiku
  File "/home/siamak/anaconda3/lib/python3.10/site-packages/e3nn_jax/flax.py", line 1, in <module>
    from e3nn_jax._src.linear_flax import Linear
  File "/home/siamak/anaconda3/lib/python3.10/site-packages/e3nn_jax/_src/linear_flax.py", line 3, in <module>
    import flax
  File "/home/siamak/anaconda3/lib/python3.10/site-packages/flax/__init__.py", line 22, in <module>
    from . import core
  File "/home/siamak/anaconda3/lib/python3.10/site-packages/flax/core/__init__.py", line 16, in <module>
    from .frozen_dict import (
  File "/home/siamak/anaconda3/lib/python3.10/site-packages/flax/core/frozen_dict.py", line 50, in <module>
    @jax.tree_util.register_pytree_with_keys_class
AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'. Did you mean: 'register_pytree_node_class'?
```
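The traceback shows Flax failing because the `jax.tree_util` module it imports has no `register_pytree_with_keys_class`, even though jax 0.4.13 provides that function. One possible explanation (an assumption, not confirmed in this issue) is that Python is picking up a different, older `jax` than the one conda reports. The sketch below is a minimal diagnostic that uses only the standard library to locate each package in the import chain without importing it, then checks whether the missing function exists. The helper names `describe` and `has_keys_class_api` and the `PACKAGES` list are hypothetical, chosen for illustration.

```python
import importlib
import importlib.metadata
import importlib.util
from typing import Optional

# Packages from the traceback: (import name, distribution name used by pip/conda).
PACKAGES = [
    ("jax", "jax"),
    ("jaxlib", "jaxlib"),
    ("flax", "flax"),
    ("e3nn_jax", "e3nn-jax"),
    ("jax_md", "jax-md"),
]


def describe(import_name: str, dist_name: Optional[str] = None) -> str:
    """Report installed version and on-disk location of a top-level package without importing it."""
    spec = importlib.util.find_spec(import_name)
    if spec is None:
        return f"{import_name}: not installed"
    try:
        version = importlib.metadata.version(dist_name or import_name)
    except importlib.metadata.PackageNotFoundError:
        # e.g. a source checkout on sys.path that has no installed metadata
        version = "unknown version"
    return f"{import_name} {version} at {spec.origin}"


def has_keys_class_api() -> Optional[bool]:
    """True/False if jax imports cleanly; None if jax itself cannot be imported."""
    try:
        import jax.tree_util as tree_util
    except (ImportError, RuntimeError):
        return None
    return hasattr(tree_util, "register_pytree_with_keys_class")


if __name__ == "__main__":
    for import_name, dist_name in PACKAGES:
        print(describe(import_name, dist_name))
    print("register_pytree_with_keys_class available:", has_keys_class_api())
```

If the printed `jax` version or path differs from what `conda list jax` shows, the environment likely contains more than one JAX installation, and upgrading `jax` and `jaxlib` in the same environment that Flax is installed in would be a reasonable first thing to try.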

Any ideas how to fix it?

Thanks
