Current code doesn't work with the latest Flax version and runs on CPU only #48

Open
ynahshan opened this issue Jul 12, 2022 · 15 comments

Comments

@ynahshan

No description provided.

@nurullahsevim

I have the same issue. Any update on this?

@ynahshan
Author

I hope the authors will fix this issue; otherwise, this repo will be useless. For now, as a workaround, I use this PyTorch implementation of LRA:
https://github.com/pkuzengqi/Skyformer

@fecet

fecet commented Nov 14, 2022

This repo uses the deprecated optim API, so it cannot work with the latest Flax, but it can run on GPU with an older Flax (in my case, 0.3.6).

@kpe

kpe commented Jan 4, 2023

@MostafaDehghani @ynahshan - any plans to update the repo to a newer jax/flax version?

@fecet

fecet commented Jan 5, 2023

> @MostafaDehghani @ynahshan - any plans to update the repo to a newer jax/flax version?

I made some modifications to Linformer so it works with the newest Flax; the remaining models could be updated in similar ways. You can pick it up if you are interested: https://github.com/fecet/long-range-arena

@DaShenZi721

@fecet Hello! Sorry to bother you. Have you ever encountered the following problem? I think it may be related to the version of flax.

Traceback (most recent call last):
  File "lra_benchmarks/listops/train.py", line 28, in <module>
    from flax.deprecated import nn
ModuleNotFoundError: No module named 'flax.deprecated'

@DaShenZi721

@fecet This is my setting:

  • cuda: 11.2
  • jax: 0.2.13
  • jaxlib: 0.1.65+cuda112
  • flax: 0.2.2
  • tensorflow: 2.7.0

@fecet

fecet commented Mar 8, 2023

@DaShenZi721 'flax.deprecated' only exists in certain versions, where it temporarily holds code that is about to be removed. I guess your version is too old; you can check whether deprecated exists in the source code of your Flax installation.

@DaShenZi721

@fecet Thanks so much! I replaced from flax.deprecated import nn with from flax import nn, and it works!
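Because flax.deprecated exists only in some Flax releases, a script can try both import locations instead of being edited by hand. This is a minimal sketch, not part of the repo: import_first is a hypothetical helper, and the two candidate locations are simply the ones reported in this thread.

```python
import importlib
from typing import Iterable, Tuple


def _from_import(module_name: str, name: str) -> object:
    """Resolve `from module_name import name`, including submodules."""
    module = importlib.import_module(module_name)
    try:
        return getattr(module, name)
    except AttributeError:
        # `from pkg import sub` also works when `sub` is a submodule that
        # pkg/__init__.py does not import itself, so fall back to that.
        return importlib.import_module(f"{module_name}.{name}")


def import_first(candidates: Iterable[Tuple[str, str]]) -> object:
    """Return the first (module, name) pair that resolves, tried in order.

    Hypothetical helper for illustration; not part of lra_benchmarks or Flax.
    """
    failures = []
    for module_name, name in candidates:
        try:
            return _from_import(module_name, name)
        except ImportError as exc:  # includes ModuleNotFoundError
            failures.append(f"from {module_name} import {name}: {exc}")
    raise ImportError("no candidate import succeeded:\n" + "\n".join(failures))


# Assumed usage at the top of lra_benchmarks/listops/train.py:
# nn = import_first([("flax.deprecated", "nn"), ("flax", "nn")])
```

Which location works depends on the installed Flax; on newer releases neither may exist, and the code has to be ported to flax.linen instead.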

@AlexKay28

@fecet
Maybe it would be a good idea to pin strict versions of the Python packages in the requirements and other tools? Or even to create a Docker image specifically for this repo?

The JAX and Flax libraries change very fast, and the current scripts become outdated quickly.

Right now I can't even launch the commands from the README =(

@fecet

fecet commented Apr 19, 2023

> @fecet Maybe it would be a good idea to pin strict versions of the Python packages in the requirements and other tools? Or even to create a Docker image specifically for this repo?
>
> The JAX and Flax libraries change very fast, and the current scripts become outdated quickly.
>
> Right now I can't even launch the commands from the README =(

Apparently, Google has already abandoned this project, and there is nothing I (or any of us) can do about it 😄

@DaShenZi721

DaShenZi721 commented Apr 20, 2023

Hello @AlexKay28! @fecet is right, Google's team has already given up on maintaining this project, but I was able to run it successfully.
I replaced from flax.deprecated import nn with from flax import nn.
My Python version is 3.8.16.
Below is my conda environment.

absl-py==1.4.0
appdirs==1.4.4
astunparse==1.6.3
attrs==22.2.0
beautifulsoup4==4.12.2
blessed==1.20.0
cachetools==4.2.4
certifi @ file:///croot/certifi_1671487769961/work/certifi
charset-normalizer==3.1.0
click==8.1.3
cloudpickle==2.2.1
conda-pack==0.6.0
contextlib2==21.6.0
cycler==0.11.0
decorator==5.1.1
dill==0.3.6
dm-tree==0.1.8
docker-pycreds==0.4.0
einops==0.6.1
filelock==3.11.0
fire==0.5.0
flatbuffers==2.0.7
flax==0.3.0
fonttools==4.39.0
future==0.18.3
gast==0.3.3
gdown==4.7.1
gin-config==0.5.0
gitdb==4.0.10
GitPython==3.1.31
google-auth==1.35.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
googleapis-common-protos==1.58.0
gpustat==1.0.0
grpcio==1.51.3
h5py==2.10.0
idna==3.4
importlib-metadata==6.0.0
importlib-resources==5.12.0
jax==0.2.16
jax-smi==1.0.3
jaxlib==0.1.67+cuda111
keras==2.11.0
Keras-Preprocessing==1.1.2
kiwisolver==1.4.4
libclang==15.0.6.1
Markdown==3.4.1
MarkupSafe==2.1.2
matplotlib==3.5.3
ml-collections==0.1.1
msgpack==1.0.4
numpy==1.21.0
nvidia-ml-py==11.495.46
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==23.0
pathtools==0.1.2
Pillow==9.4.0
promise==2.3
protobuf==3.19.6
psutil==5.9.4
pyasn1==0.4.8
pyasn1-modules==0.2.8
pyparsing==3.0.9
PySocks==1.7.1
python-dateutil==2.8.2
PyYAML==6.0
requests==2.28.2
requests-oauthlib==1.3.1
rsa==4.9
scipy==1.9.3
sentry-sdk==1.16.0
setproctitle==1.3.2
six==1.16.0
smmap==5.0.0
soupsieve==2.4.1
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-datasets==4.0.0
tensorflow-estimator==2.11.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.31.0
tensorflow-metadata==1.12.0
tensorflow-probability==0.19.0
tensorflow-text==2.11.0
termcolor==2.2.0
tqdm==4.65.0
typing_extensions==4.5.0
urllib3==1.26.14
wandb==0.13.11
wcwidth==0.2.6
Werkzeug==2.2.3
wrapt==1.15.0
zipp==3.15.0
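Pinning exact versions, as suggested earlier in the thread, is easier to rely on if the pins can be checked. This is a minimal sketch, assuming the goal is to reproduce the environment listed above: pin_mismatches and REPORTED_PINS are hypothetical names, and only a subset of the listed packages is included.

```python
from importlib import metadata
from typing import Dict, Optional, Tuple

# A subset of the versions from the environment listed above; extend as needed.
REPORTED_PINS = {
    "flax": "0.3.0",
    "jax": "0.2.16",
    "jaxlib": "0.1.67+cuda111",
    "numpy": "1.21.0",
    "tensorflow": "2.11.0",
    "tensorflow-datasets": "4.0.0",
}


def pin_mismatches(pins: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map each package whose installed version differs from its pin to
    (expected, installed); installed is None if the package is missing."""
    mismatches = {}
    for package, expected in pins.items():
        try:
            installed: Optional[str] = metadata.version(package)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[package] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for pkg, (want, have) in sorted(pin_mismatches(REPORTED_PINS).items()):
        print(f"{pkg}: expected {want}, found {have or 'not installed'}")
```

Run as a script, it prints one line per package that differs from the listed environment and nothing when everything matches.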

@arneeichholtz

arneeichholtz commented Apr 26, 2023

Yes, @DaShenZi721, I had the same problem with the flax.deprecated import while trying to run the byte-level text classification task.

However, your option from flax import nn does not work for me. What worked was from flax import linen as nn, together with installing jax==0.4.0, jaxlib==0.4.0 and flax==0.5.3. These older versions are needed because the code uses GlobalDeviceArray, which is replaced by jax.Array starting with jax==0.4.1.

After that, the dataset loads successfully and the config settings are printed, but calling create_model() (from text_classification/train.py) raises the following error:

  File "lra_benchmarks/text_classification/train.py", line 66, in _create_model
    module = flax_module.partial(**model_kwargs)
AttributeError: type object 'LinearTransformerEncoder' has no attribute 'partial'

Do you know a way around this? How far have you gotten? Any help is greatly appreciated! Thanks, Arne
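The comment above keeps jax below 0.4.1 because the code still uses GlobalDeviceArray. A script could check that constraint before training instead of failing later. This is a minimal sketch, not code from the repo: release_tuple and supports_global_device_array are hypothetical helpers, the version parsing is deliberately simplified, and the 0.4.1 cutoff is taken from this comment rather than verified against the JAX changelog.

```python
from typing import Tuple


def release_tuple(version: str) -> Tuple[int, ...]:
    """Numeric release part of a version string, e.g. '0.1.67+cuda111' -> (0, 1, 67).

    Simplified parsing: drops any local '+...' suffix and stops at the first
    non-numeric component, so tags such as 'rc1' are ignored.
    """
    release = version.split("+", 1)[0]
    parts = []
    for piece in release.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def supports_global_device_array(jax_version: str) -> bool:
    """True if jax_version predates 0.4.1, the release that (per this thread)
    replaced GlobalDeviceArray with jax.Array."""
    return release_tuple(jax_version) < (0, 4, 1)


# Assumed usage, before any training code runs:
# import jax
# if not supports_global_device_array(jax.__version__):
#     raise RuntimeError(f"jax {jax.__version__} is too new; try jax==0.4.0")
```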

@DaShenZi721

Hello, @arneeichholtz! I have encountered this error before. It is another Flax version issue.
In version 0.4.0 of Flax, the usage of Module.partial changed. You can modify your code by following the link below:
https://flax.readthedocs.io/en/latest/advanced_topics/linen_upgrade_guide.html#module-partial-inside-other-modules
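Following that guide, one way to replace flax_module.partial(**model_kwargs) is to bind the keyword arguments with functools.partial, or to construct the module directly, since Linen modules are Python dataclasses. The sketch below uses a plain dataclass as a hypothetical stand-in for LinearTransformerEncoder so it runs without Flax; its field names and values are made up, and the real model_kwargs will differ.

```python
import dataclasses
import functools


@dataclasses.dataclass
class EncoderStandIn:
    """Hypothetical stand-in for a Linen module such as LinearTransformerEncoder.

    Linen modules are dataclasses, so hyperparameters that the old API bound
    with Module.partial become constructor fields. Field names are made up.
    """
    vocab_size: int
    emb_dim: int = 64
    num_layers: int = 2


model_kwargs = {"vocab_size": 257, "num_layers": 4}  # illustrative values only

# Old pre-Linen API (raises the AttributeError above on newer Flax):
#     module = flax_module.partial(**model_kwargs)
# Option 1: bind the keyword arguments now, construct the module later.
make_module = functools.partial(EncoderStandIn, **model_kwargs)
module = make_module()
# Option 2: construct the module directly with its fields.
same_module = EncoderStandIn(**model_kwargs)

# With a real Linen module, parameters are then created explicitly, e.g.
#     params = module.init(rng, dummy_inputs)
```

Option 1 stays closest to the old code structure, since the partially applied constructor can be passed around the same way the old partial module was.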

@arneeichholtz

arneeichholtz commented Apr 27, 2023

Thanks for the response! The partial call is resolved, but now I'm trying to replace the nn.stochastic call, which is also deprecated. The link you sent explains how to do it, but I can't really figure it out. Did you run into the same problem, or have you resolved it? Thanks, Arne

Labels
None yet
Projects
None yet
Development

No branches or pull requests

7 participants