Skip to content

Commit

Permalink
Recommend the plugin in the CUDA installation instructions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 631524916
  • Loading branch information
jyingl3 authored and jax authors committed May 7, 2024
1 parent 5f70267 commit d84c7ad
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,8 @@ Some standouts:
| Hardware | Instructions |
|------------|-----------------------------------------------------------------------------------------------------------------|
| CPU | `pip install -U "jax[cpu]"` |
| NVIDIA GPU on x86_64 | `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html` |
| NVIDIA GPU on x86_64 | `pip install -U "jax[cuda12]"` |
| NVIDIA GPU on x86_64 legacy | `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html` |
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |
Expand Down
9 changes: 9 additions & 0 deletions docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ pip install --upgrade pip

# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12]"

# NVIDIA CUDA 12 installation legacy
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

Expand Down Expand Up @@ -202,6 +205,12 @@ pip install -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/lib

- `jaxlib` NVIDIA GPU (CUDA 12):

```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html

- `jaxlib` NVIDIA GPU (CUDA 12) legacy:

```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
```
Expand Down

0 comments on commit d84c7ad

Please sign in to comment.