Flax needs to be upgraded in the tensorflow/jax image #489

Closed
dhruvbalwada opened this issue Oct 5, 2023 · 10 comments · Fixed by #514

Comments

@dhruvbalwada
Member

Describe the bug
The current version of flax (0.6.1) on the image does not work with the jax version on the image (0.4.13).

To Reproduce
The issue can be reproduced by running from flax.training import checkpoints, which fails with the error ModuleNotFoundError: No module named 'jax.experimental.global_device_array'.
This has been discussed in google/flax#3087.
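A minimal diagnostic sketch for the reproduction step, assuming Python 3.8+ (for importlib.metadata). It reports the installed flax, jax and jaxlib versions and attempts the failing import, returning the error text instead of crashing. The helpers installed_version and try_import are hypothetical names, not part of flax or jax.

```python
import importlib
import importlib.metadata
from typing import Optional, Tuple


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def try_import(module_name: str) -> Tuple[bool, str]:
    """Try to import a module; return (success, "ExceptionType: message")."""
    try:
        importlib.import_module(module_name)
        return True, ""
    except Exception as exc:  # e.g. the ModuleNotFoundError raised inside flax
        return False, f"{type(exc).__name__}: {exc}"


if __name__ == "__main__":
    for dist in ("flax", "jax", "jaxlib"):
        print(f"{dist}: {installed_version(dist)}")
    # Equivalent to `from flax.training import checkpoints`
    ok, err = try_import("flax.training.checkpoints")
    print("import ok" if ok else f"import failed -> {err}")
```

Catching a broad Exception is deliberate here: a version mismatch can surface as ModuleNotFoundError, ImportError or RuntimeError, and the goal is only to report it.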

Expected behavior
Flax should be importable.

Infrastructure (Where you are running this image):

Solution
For now, I work around this by running:
pip install flax==0.6.10

@weiji14
Member

weiji14 commented Oct 5, 2023

The latest version of flax on conda-forge is still 0.6.1 (see https://anaconda.org/conda-forge/flax/files). The package seems to have been unmaintained for a while, but there's some effort to bump it to v0.7.x, see e.g. conda-forge/flax-feedstock#29.

We could either wait for the updates on flax-feedstock or, as a temporary measure, use pip to install flax in the ml-notebook/environment.yml file.

@weiji14
Member

weiji14 commented Nov 13, 2023

OK, looks like flax=0.7.4 and flax=0.7.5 are available on conda-forge now. However, there's a dependency conflict with tensorflow. I tried with this environment.yml file:

name: pangeo
channels:
 - conda-forge
 - nodefaults
dependencies:
 - flax>=0.7.0
 - jax
 - jupyterlab-nvdashboard
 - keras-cv
 - tensorflow>=2.13.1=*cuda112*

and got this error from conda-lock:

Locking dependencies for ['linux-64']...
INFO:conda_lock.conda_solver:linux-64 using specs ['flax >=0.7.0', 'jax', 'jupyterlab-nvdashboard', 'keras-cv', 'tensorflow >=2.13.1 *cuda112*', 'adlfs', 'argopy', 'awscli', 'boto3', 'bottleneck', 'cartopy', 'cdsapi', 'cfgrib', 'ciso', 'cmocean', 'dask-ml', 'datashader', 'descartes', 'earthaccess', 'eofs', 'erddapy', 'esmpy', 'fastjmd95', 'flox', 'fsspec', 'gcm_filters', 'gcsfs', 'gh', 'gh-scoped-creds', 'geocube', 'geopandas', 'geopy', 'geoviews-core', 'git-lfs', 'gsw', 'h5netcdf', 'h5py', 'holoviews', 'hvplot', 'intake', 'intake-esm', 'intake-geopandas', 'intake-stac', 'intake-xarray', 'ipykernel', 'ipyleaflet', 'ipytree', 'ipywidgets', 'jupyterlab-git', 'jupyter-panel-proxy', 'jupyter-resource-usage', 'kerchunk', 'lxml', 'lz4', 'matplotlib-base', 'metpy', 'nb_conda_kernels', 'nbstripout', 'nc-time-axis', 'netcdf4', 'numbagg', 'numcodecs', 'numpy', 'numpy_groupies', 'odc-stac', 'pandas', 'panel', 'parcels', 'param', 'pop-tools', 'pyarrow', 'pycamhd', 'pydap', 'pystac', 'pystac-client', 'python-blosc', 'python-gist', 'python-graphviz', 'rasterio', 'rechunker', 'rio-cogeo', 'rioxarray', 's3fs', 'satpy', 'scikit-image', 'scikit-learn', 'scipy', 'seaborn', 'sparse', 'snakeviz', 'stackstac', 'tiledb-py', 'timezonefinder', 'xarray', 'xarrayutils', 'xarray-datatree', 'xarray_leaflet', 'xarray-spatial', 'xbatcher', 'xcape', 'xclim', 'xesmf', 'xgboost', 'xgcm', 'xhistogram', 'xmip', 'xmitgcm', 'xpublish', 'xrft', 'xskillscore', 'zarr', 'python 3.11.*', 'pangeo-notebook 2023.11.11.*', 'pip']
Failed to parse json, Expecting value: line 1 column 1 (char 0)
Could not lock the environment for platform linux-64
Could not solve for environment specs
The following packages are incompatible
├─ flax >=0.7.0  is installable and it requires
│  └─ tensorstore   with the potential options
│     ├─ tensorstore [0.1.44|0.1.46|0.1.47|0.1.48] would require
│     │  └─ libprotobuf [>=4.24.3,<4.24.4.0a0 |>=4.24.4,<4.24.5.0a0 ], which requires
│     │     └─ libabseil >=20230802.1,<20230803.0a0 , which can be installed;
│     └─ tensorstore 0.1.44 would require
│        └─ libprotobuf >=4.23.4,<4.23.5.0a0  with the potential options
│           ├─ libprotobuf 4.23.4, which can be installed;
│           └─ libprotobuf 4.23.4 would require
│              └─ libabseil >=20230802.0,<20230803.0a0 , which can be installed;
└─ tensorflow >=2.13.1 *cuda112* is uninstallable because it requires
   └─ tensorflow-base [2.13.1 cuda112py310hbb601f2_1|2.13.1 cuda112py311h8bdbb6c_1|2.13.1 cuda112py38h79651c7_1|2.13.1 cuda112py39h85a252b_1], which requires
      ├─ libabseil >=20230125.3,<20230126.0a0 , which conflicts with any installable versions previously reported;
      └─ libgrpc >=1.54.3,<1.55.0a0 , which requires
         └─ libprotobuf >=3.21.12,<3.22.0a0 , which conflicts with any installable versions previously reported.

It seems that flax -> tensorstore depends on libprotobuf 4.23.x/4.24.x, whereas tensorflow-base -> libgrpc depends on libprotobuf 3.21.x, so no single libprotobuf version satisfies both. This looks like a longstanding issue (conda-forge/tensorflow-feedstock#288), and the migration in either conda-forge/tensorflow-feedstock#347 or conda-forge/tensorflow-feedstock#342 might help.
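To make the conflict concrete, here is a small sketch that intersects the libprotobuf ranges from the solver output above. It assumes dotted numeric versions, treats the .0a0 upper bounds as plain exclusive bounds (a simplification of conda's pre-release ordering), and uses hypothetical helper names; it is not how conda's solver works internally.

```python
from typing import List, Optional, Tuple

Version = Tuple[int, ...]
Range = Tuple[Version, Version]  # half-open interval [lower, upper)


def parse(v: str) -> Version:
    """Parse e.g. '4.24.4.0a0' into (4, 24, 4).

    Simplification: parsing stops at the first non-numeric component, and
    trailing zeros are dropped so that '4.24.0' and '4.24' compare equal.
    """
    nums: List[int] = []
    for piece in v.split("."):
        if not piece.isdigit():
            break
        nums.append(int(piece))
    while nums and nums[-1] == 0:
        nums.pop()
    return tuple(nums)


def rng(spec: str) -> Range:
    """Turn '>=4.24.3,<4.24.4.0a0' into ((4, 24, 3), (4, 24, 4))."""
    lower, upper = (s.strip() for s in spec.split(","))
    return parse(lower.lstrip(">=")), parse(upper.lstrip("<"))


def intersect(a: Range, b: Range) -> Optional[Range]:
    """Intersection of two half-open ranges, or None if it is empty."""
    lo, hi = max(a[0], b[0]), min(a[1], b[1])
    return (lo, hi) if lo < hi else None


# libprotobuf pins copied from the conda-lock output above
tensorstore_options = [">=4.24.3,<4.24.4.0a0", ">=4.24.4,<4.24.5.0a0", ">=4.23.4,<4.23.5.0a0"]
libgrpc_pin = ">=3.21.12,<3.22.0a0"  # pulled in via tensorflow-base 2.13.1

compatible = [s for s in tensorstore_options if intersect(rng(s), rng(libgrpc_pin))]
print(compatible)  # [] : no libprotobuf version satisfies both sides
```

Every tensorstore option needs libprotobuf 4.x while libgrpc caps it below 3.22, so each intersection is empty and the solver has nothing to pick.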

Edit: Might we need to wait for tensorflow 2.15.0 (conda-forge/tensorflow-feedstock#353)? Nope, conda-forge's tensorflow=2.15.0 doesn't help. They still need to handle the libprotobuf issue, so we have to wait for the bot to re-run after conda-forge/tensorflow-feedstock#359. See conda-forge/tensorflow-feedstock#361 🙏 (the PR was merged, but the package was later marked as broken according to conda-forge/tensorflow-feedstock#367 (comment) 😅).

@weiji14
Member

weiji14 commented Jan 26, 2024

Looks like we might need to upgrade from CUDA 11.8 to 12 to get a newer build of tensorflow=2.15.0 from conda-forge, built with libprotobuf~=4.24, that works with flax>=0.7.4, see conda-forge/tensorflow-feedstock#367 (comment).

@dhruvbalwada
Member Author

dhruvbalwada commented Feb 5, 2024

@yuvipanda and @jbusecke, just getting your attention here, since my hacky workflow seems to have stopped working today.

Recently, to use jax with a compatible version of flax, I have been running two steps on the LEAP hub:

  • mamba install cuda-nvcc==11.6.* -c nvidia
  • pip install flax==0.6.10

However, as of this morning this leads to a new error. After following these steps, running import jax fails with:
RuntimeError: jaxlib is version 0.4.12, but this version of jax requires version >= 0.4.19.
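The RuntimeError comes from jax refusing to run with a jaxlib older than the minimum it requires. A sketch of that kind of check, assuming plain numeric X.Y.Z versions; meets_minimum is a hypothetical helper, not jax's own function.

```python
from typing import Tuple


def as_tuple(version: str) -> Tuple[int, ...]:
    """'0.4.12' -> (0, 4, 12); assumes purely numeric components."""
    return tuple(int(part) for part in version.split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    """True if the installed version is >= the required minimum."""
    return as_tuple(installed) >= as_tuple(minimum)


installed_jaxlib = "0.4.12"  # values taken from the error message above
required_jaxlib = "0.4.19"
if not meets_minimum(installed_jaxlib, required_jaxlib):
    print(f"jaxlib {installed_jaxlib} is too old; need >= {required_jaxlib}")
```

Comparing integer tuples rather than raw strings matters: as strings, '0.4.9' would sort after '0.4.19'.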

@jbusecke
Collaborator

jbusecke commented Feb 5, 2024

Hey @dhruvbalwada, I suspect this is due to the recent update of the pangeo-docker-image on the LEAP hub.

To unblock you for now, I recommend manually selecting an older image (the LEAP docs provide instructions).

But I don't think this changes the core problem here. Is there anything I could help with or test, @weiji14?

@weiji14
Member

weiji14 commented Feb 5, 2024

Right, looks like we'll need to expedite the upgrade to CUDA 12 then, as mentioned in #489 (comment). Let me open a PR for that (got some free time today), and then we'll be able to upgrade to newer tensorflow/flax versions.

@weiji14
Member

weiji14 commented Feb 6, 2024

OK, it's not as simple as I thought. I tried running conda-lock to create a lockfile with a newer version of tensorflow=2.15.0 and flax>=0.8.0 with CUDA 12.0, but it fails with:

The following packages are incompatible
├─ flax >=0.8.0  is installable and it requires
│  └─ jax >=0.4.11  with the potential options
│     ├─ jax 0.4.11 would require
│     │  └─ jaxlib >=0.4.7  with the potential options
│     │     ├─ jaxlib [0.4.10|0.4.11|0.4.12|0.4.14|0.4.9] would require
│     │     │  └─ libgrpc >=1.54.2,<1.55.0a0 , which can be installed;
│     │     ├─ jaxlib 0.4.12 would require
│     │     │  └─ libgrpc >=1.56.0,<1.57.0a0 , which can be installed;
│     │     ├─ jaxlib 0.4.14 would require
│     │     │  └─ libgrpc >=1.56.2,<1.57.0a0 , which can be installed;
│     │     ├─ jaxlib [0.4.14|0.4.18|0.4.19] would require
│     │     │  └─ libgrpc >=1.58.1,<1.59.0a0 , which can be installed;
│     │     ├─ jaxlib 0.4.14 would require
│     │     │  └─ libgrpc >=1.57.0,<1.58.0a0 , which can be installed;
│     │     ├─ jaxlib [0.4.20|0.4.23] would require
│     │     │  └─ libgrpc >=1.58.2,<1.59.0a0 , which can be installed;
│     │     ├─ jaxlib 0.4.7 would require
│     │     │  └─ libgrpc >=1.52.1,<1.53.0a0 , which can be installed;
│     │     └─ jaxlib 0.4.7 would require
│     │        └─ libgrpc >=1.54.0,<1.55.0a0 , which can be installed;
│     ├─ jax [0.4.12|0.4.13|0.4.14] would require
│     │  └─ jaxlib >=0.4.11 , which can be installed (as previously explained);
│     ├─ jax [0.4.16|0.4.17|0.4.19|0.4.20] would require
│     │  └─ jaxlib >=0.4.14 , which can be installed (as previously explained);
│     └─ jax [0.4.21|0.4.23] would require
│        └─ jaxlib >=0.4.19 , which can be installed (as previously explained);
└─ tensorflow >=2.15.0 *cuda120* is not installable because it requires
   └─ tensorflow-base [2.15.0 cuda120py310heceb7ac_2|2.15.0 cuda120py310heceb7ac_3|...|2.15.0 cuda120py39hf42b710_3], which requires
      └─ libgrpc >=1.59.3,<1.60.0a0 , which conflicts with any installable versions previously reported.

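It looks like we'll need to wait for jaxlib to support CUDA 12 (conda-forge/jaxlib-feedstock#223, conda-forge/jaxlib-feedstock#218) and to be rebuilt with libprotobuf 4.24 (conda-forge/jaxlib-feedstock#221). In the output above, tensorflow-base 2.15.0 requires libgrpc 1.59.x, but every available jaxlib build pins libgrpc 1.58.x or older.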
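The libgrpc clash can be checked the same way as a range-overlap problem. A self-contained sketch, with the jaxlib build table transcribed from the solver output above (a jaxlib version listed under several ranges corresponds to separate builds); the helper names are hypothetical, and .0a0 upper bounds are simplified to plain exclusive bounds.

```python
from typing import Dict, List, Tuple

Version = Tuple[int, ...]


def parse(v: str) -> Version:
    """'1.59.0a0' -> (1, 59): stop at the first non-numeric part, drop trailing zeros."""
    nums: List[int] = []
    for piece in v.split("."):
        if not piece.isdigit():
            break
        nums.append(int(piece))
    while nums and nums[-1] == 0:
        nums.pop()
    return tuple(nums)


def bounds(spec: str) -> Tuple[Version, Version]:
    """'>=1.58.2,<1.59.0a0' -> ((1, 58, 2), (1, 59)), a half-open range."""
    lower, upper = spec.split(",")
    return parse(lower.strip().lstrip(">=")), parse(upper.strip().lstrip("<"))


def overlaps(spec_a: str, spec_b: str) -> bool:
    """Do two '>=X,<Y' specs share at least one version?"""
    (lo_a, hi_a), (lo_b, hi_b) = bounds(spec_a), bounds(spec_b)
    return max(lo_a, lo_b) < min(hi_a, hi_b)


# libgrpc pins of the conda-forge jaxlib builds, copied from the solver output above
jaxlib_libgrpc: Dict[str, List[str]] = {
    ">=1.52.1,<1.53.0a0": ["0.4.7"],
    ">=1.54.0,<1.55.0a0": ["0.4.7"],
    ">=1.54.2,<1.55.0a0": ["0.4.9", "0.4.10", "0.4.11", "0.4.12", "0.4.14"],
    ">=1.56.0,<1.57.0a0": ["0.4.12"],
    ">=1.56.2,<1.57.0a0": ["0.4.14"],
    ">=1.57.0,<1.58.0a0": ["0.4.14"],
    ">=1.58.1,<1.59.0a0": ["0.4.14", "0.4.18", "0.4.19"],
    ">=1.58.2,<1.59.0a0": ["0.4.20", "0.4.23"],
}
tensorflow_libgrpc = ">=1.59.3,<1.60.0a0"  # tensorflow-base 2.15.0 cuda120 builds

usable = sorted(
    {v for spec, versions in jaxlib_libgrpc.items()
     if overlaps(spec, tensorflow_libgrpc) for v in versions},
    key=parse,
)
print(usable)  # [] : no jaxlib build can share an environment with this tensorflow
```

The newest jaxlib ranges stop below libgrpc 1.59, while tensorflow-base starts at 1.59.3, so a jaxlib rebuild against a newer libgrpc is what would unblock this.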

@dhruvbalwada
Member Author

Do you think a slightly older version might be enough for now? Maybe flax>=0.7?

@weiji14
Member

weiji14 commented Feb 6, 2024

Nope, flax>=0.7.0 doesn't work either:

├─ flax >=0.7.0  is installable and it requires
│  └─ jax >=0.4.11  with the potential options
│     ├─ jax 0.4.11 would require
│     │  └─ jaxlib >=0.4.7  with the potential options
│     │     ├─ jaxlib [0.4.10|0.4.11|0.4.12|0.4.14|0.4.9] would require
│     │     │  └─ libgrpc >=1.54.2,<1.55.0a0 , which can be installed;
│     │     ├─ jaxlib 0.4.12 would require
│     │     │  └─ libgrpc >=1.56.0,<1.57.0a0 , which can be installed;
│     │     ├─ jaxlib 0.4.14 would require
│     │     │  └─ libgrpc >=1.56.2,<1.57.0a0 , which can be installed;
│     │     ├─ jaxlib [0.4.14|0.4.18|0.4.19] would require
│     │     │  └─ libgrpc >=1.58.1,<1.59.0a0 , which can be installed;
│     │     ├─ jaxlib 0.4.14 would require
│     │     │  └─ libgrpc >=1.57.0,<1.58.0a0 , which can be installed;
│     │     ├─ jaxlib [0.4.20|0.4.23] would require
│     │     │  └─ libgrpc >=1.58.2,<1.59.0a0 , which can be installed;
│     │     ├─ jaxlib 0.4.7 would require
│     │     │  └─ libgrpc >=1.52.1,<1.53.0a0 , which can be installed;
│     │     └─ jaxlib 0.4.7 would require
│     │        └─ libgrpc >=1.54.0,<1.55.0a0 , which can be installed;
│     ├─ jax [0.4.12|0.4.13|0.4.14] would require
│     │  └─ jaxlib >=0.4.11 , which can be installed (as previously explained);
│     ├─ jax [0.4.16|0.4.17|0.4.19|0.4.20] would require
│     │  └─ jaxlib >=0.4.14 , which can be installed (as previously explained);
│     └─ jax [0.4.21|0.4.23] would require
│        └─ jaxlib >=0.4.19 , which can be installed (as previously explained);
└─ tensorflow >=2.15.0 *cuda120* is not installable because it requires
   └─ tensorflow-base [2.15.0 cuda120py310heceb7ac_2|2.15.0 cuda120py310heceb7ac_3|...|2.15.0 cuda120py39hf42b710_3], which requires
      └─ libgrpc >=1.59.3,<1.60.0a0 , which conflicts with any installable versions previously reported.

I also tried older version combinations with CUDA 11.2 and tensorflow 2.13.x last year (see all my crossed-out links in #489 (comment)), but none of them work. We really need all the tensorflow/jax libraries on conda-forge to align on the same version of libprotobuf.

@dhruvbalwada
Member Author

The new hack that works is:

pip install 'flax==0.7.2' 'jax<=0.4.13' 'ml_dtypes==0.2.0'
mamba install cuda-nvcc==11.6.* -c nvidia
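To confirm that the pip pins above are what actually ended up in the environment (for example, in case a later install changes jax again), a small sketch that compares installed versions with the intended pins. It handles only the == and <= operators used here and assumes numeric versions; check_pins is a hypothetical helper, and the lookup parameter exists only so it can be exercised without the real packages.

```python
import importlib.metadata
from typing import Callable, Dict, List, Optional, Tuple

# The pins from the workaround above; only "==" and "<=" are handled.
PINS: Dict[str, str] = {"flax": "==0.7.2", "jax": "<=0.4.13", "ml_dtypes": "==0.2.0"}


def installed_version(dist: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is missing."""
    try:
        return importlib.metadata.version(dist)
    except importlib.metadata.PackageNotFoundError:
        return None


def as_tuple(version: str) -> Tuple[int, ...]:
    """'0.4.13' -> (0, 4, 13); stops at the first non-numeric part (e.g. 'dev1')."""
    nums: List[int] = []
    for piece in version.split("."):
        if not piece.isdigit():
            break
        nums.append(int(piece))
    return tuple(nums)


def check_pins(pins: Dict[str, str],
               lookup: Callable[[str], Optional[str]] = installed_version) -> List[str]:
    """Return one message per pin that the environment does not satisfy."""
    problems: List[str] = []
    for dist, spec in pins.items():
        op, wanted = spec[:2], spec[2:]
        if op not in ("==", "<="):
            raise ValueError(f"unsupported spec {spec!r}")
        installed = lookup(dist)
        if installed is None:
            problems.append(f"{dist}: not installed")
            continue
        have, want = as_tuple(installed), as_tuple(wanted)
        if (op == "==" and have != want) or (op == "<=" and have > want):
            problems.append(f"{dist}: {installed} does not satisfy {spec}")
    return problems


if __name__ == "__main__":
    for message in check_pins(PINS) or ["all pins satisfied"]:
        print(message)
```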

Hopefully the versions will align in the near future.


3 participants