Dear List,
I created a v2-32 TPU on GCP (with tpu-ubuntu2204-base) and access it with gcloud via:
gcloud compute tpus tpu-vm ssh --zone "europe-west4-a" "node-2" --project "ambient-depth-413708"
I installed JAX via:
and it gives the log:
Installing collected packages: libtpu-nightly, numpy, scipy, opt-einsum, ml-dtypes, jaxlib, jax
WARNING: The script f2py is installed in '/home/dr_preethibaselios/.local/bin' which is not on PATH.
Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.
Successfully installed jax-0.4.26 jaxlib-0.4.26 libtpu-nightly-0.1.dev20240403 ml-dtypes-0.4.0 numpy-1.26.4 opt-einsum-3.3.0 scipy-1.13.0
But when I run JAX commands, they do not work correctly.
For example:
set JAX_PLATFORMS='tpu'
set OMP_NUM_THREADS=1
$ python3
Python 3.10.6 (main, Mar 10 2023, 10:55:28) [GCC 11.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as jnp
>>> from jax import random
>>> import jax
>>> import datetime
>>> N = 20000  # the dimension of the matrix
>>> x = random.normal(random.PRNGKey(0), (N, N), dtype=jnp.float32)
The log is:
/home/dr/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py:143: UserWarning: TPU backend initialization is taking more than 60.0 seconds. Did you run your code on all TPU hosts? See https://jax.readthedocs.io/en/latest/multi_process.html for more information.
warnings.warn(
It just freezes here...
The same commands ran on a v2-8 without a problem!
Am I missing something? I hope someone can help with this...
Thanks in advance
Mohan
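A possible direction, following the hint in the warning itself ("Did you run your code on all TPU hosts?"): a v2-32 is a multi-host slice, whereas a v2-8 is a single host, so a Python session started on only one worker may wait for the other hosts during TPU initialization. The sketch below is not a confirmed fix. It only builds the gcloud command that would launch the same script on every worker, using the `--worker=all` and `--command` flags of `gcloud compute tpus tpu-vm ssh`. The helper name `build_all_workers_command` and the script name `matmul_test.py` are hypothetical, and it assumes the script is already present on each host.

```python
import shlex


def build_all_workers_command(tpu_name, zone, project, remote_command):
    """Build the gcloud argv that runs remote_command on every host of a TPU VM.

    Hypothetical helper for illustration; it only assembles arguments and
    does not contact GCP.
    """
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "ssh", tpu_name,
        f"--zone={zone}",
        f"--project={project}",
        "--worker=all",                 # target all hosts, not just worker 0
        f"--command={remote_command}",  # command executed on each host
    ]


cmd = build_all_workers_command(
    tpu_name="node-2",
    zone="europe-west4-a",
    project="ambient-depth-413708",
    remote_command="python3 matmul_test.py",  # hypothetical script name
)

# Print a shell-quoted version that can be pasted into a terminal.
print(shlex.join(cmd))
```

The printed command can be pasted into a local terminal, or the list can be passed to `subprocess.run(cmd, check=True)`. The point is that the same JAX program starts on all hosts at roughly the same time instead of in one interactive session.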