diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py index e94ecaeaf853..a023d886c2a6 100644 --- a/jax/experimental/array_serialization/serialization.py +++ b/jax/experimental/array_serialization/serialization.py @@ -83,19 +83,11 @@ async def create_async_array_from_callback( def _get_metadata(arr): - if arr.dtype == jnp.bfloat16: - # Tensorstore uses 'bfloat16', not '