Skip to content

Commit

Permalink
Upgrade cuDNN version in Colabs to support JAX 0.4.3. Add missing "Op…
Browse files Browse the repository at this point in the history
…en in Colab button".

PiperOrigin-RevId: 508779352
  • Loading branch information
romanngg committed Feb 11, 2023
1 parent adaf9e2 commit 46306b6
Show file tree
Hide file tree
Showing 11 changed files with 21 additions and 13 deletions.
Expand Up @@ -43,7 +43,7 @@
"outputs": [],
"source": [
"!pip install -q --upgrade pip\n",
"!pip install -q --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q --upgrade jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q git+https://www.github.com/google/neural-tangents"
]
},
Expand Down
2 changes: 1 addition & 1 deletion notebooks/elementwise.ipynb
Expand Up @@ -66,7 +66,7 @@
],
"source": [
"!pip install -q --upgrade pip\n",
"!pip install -q jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q git+https://www.github.com/google/neural-tangents"
]
},
Expand Down
2 changes: 1 addition & 1 deletion notebooks/empirical_ntk_fcn.ipynb
Expand Up @@ -80,7 +80,7 @@
"source": [
"# We need at least jaxlib-0.1.73 to avoid certain CUDA bugs when using `implementation=auto`\n",
"!pip install -q --upgrade pip\n",
"!pip install -q jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q git+https://www.github.com/google/neural-tangents"
]
},
Expand Down
3 changes: 1 addition & 2 deletions notebooks/empirical_ntk_resnet.ipynb
Expand Up @@ -70,9 +70,8 @@
},
"outputs": [],
"source": [
"# We need at least jaxlib-0.1.73 to avoid certain CUDA bugs when using `implementation=auto`\n",
"!pip install -q --upgrade pip\n",
"!pip install -q jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q git+https://github.com/google/flax\n",
"!pip install -q git+https://www.github.com/google/neural-tangents"
]
Expand Down
13 changes: 11 additions & 2 deletions notebooks/experimental/empirical_ntk_resnet_tf.ipynb
@@ -1,5 +1,15 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/experimental/empirical_ntk_resnet_tf.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
Expand Down Expand Up @@ -204,9 +214,8 @@
}
],
"source": [
"# We need at least jaxlib-0.1.73 to avoid certain CUDA bugs when using `implementation=auto`\n",
"!pip install --upgrade pip\n",
"!pip install --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install --upgrade jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install git+https://www.github.com/google/neural-tangents.git"
]
},
Expand Down
2 changes: 1 addition & 1 deletion notebooks/function_space_linearization.ipynb
Expand Up @@ -80,7 +80,7 @@
"outputs": [],
"source": [
"!pip install -q --upgrade pip\n",
"!pip install -q --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q --upgrade jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q tensorflow-datasets\n",
"!pip install -q git+https://www.github.com/google/neural-tangents"
]
Expand Down
2 changes: 1 addition & 1 deletion notebooks/learning_curves.ipynb
Expand Up @@ -82,7 +82,7 @@
],
"source": [
"!pip install -q --upgrade pip\n",
"!pip install -q --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q --upgrade jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q git+https://www.github.com/google/neural-tangents"
]
},
Expand Down
2 changes: 1 addition & 1 deletion notebooks/myrtle_kernel_with_neural_tangents.ipynb
Expand Up @@ -43,7 +43,7 @@
"outputs": [],
"source": [
"!pip install -q --upgrade pip\n",
"!pip install -q --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q --upgrade jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q git+https://www.github.com/google/neural-tangents"
]
},
Expand Down
2 changes: 1 addition & 1 deletion notebooks/neural_tangents_cookbook.ipynb
Expand Up @@ -53,7 +53,7 @@
"outputs": [],
"source": [
"!pip install -q --upgrade pip\n",
"!pip install -q --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q --upgrade jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q git+https://www.github.com/google/neural-tangents"
]
},
Expand Down
2 changes: 1 addition & 1 deletion notebooks/phase_diagram.ipynb
Expand Up @@ -75,7 +75,7 @@
],
"source": [
"!pip install -q --upgrade pip\n",
"!pip install -q --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q --upgrade jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q git+https://www.github.com/google/neural-tangents"
]
},
Expand Down
2 changes: 1 addition & 1 deletion notebooks/weight_space_linearization.ipynb
Expand Up @@ -80,7 +80,7 @@
"outputs": [],
"source": [
"!pip install -q --upgrade pip\n",
"!pip install -q --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q --upgrade jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!pip install -q tensorflow-datasets\n",
"!pip install -q git+https://www.github.com/google/neural-tangents"
]
Expand Down

0 comments on commit 46306b6

Please sign in to comment.