From 5aaaa8a895fd6ae4dc1c2401f841abaa82928c89 Mon Sep 17 00:00:00 2001 From: Daniel Ng Date: Fri, 26 Apr 2024 15:02:05 -0700 Subject: [PATCH] Move dtype settings out of metadata field into the root of Tensorstore spec Before, dtype used to be in the metadata field of tensorstore spec because of it was the legacy way to config the dtype. This setting doesn't understand the "str" name, hence, there was special logic to translate bfloat for example. This CL moves it out of the metadata field and put the dtype directly into the Tensorstore spec to eliminate special dtype translation logic. This will also add support of other quantized types such as int4. PiperOrigin-RevId: 628519426 --- .../array_serialization/serialization.py | 14 +++++--------- .../array_serialization/serialization_test.py | 14 ++++++++++---- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py index e94ecaeaf853..a023d886c2a6 100644 --- a/jax/experimental/array_serialization/serialization.py +++ b/jax/experimental/array_serialization/serialization.py @@ -83,19 +83,11 @@ async def create_async_array_from_callback( def _get_metadata(arr): - if arr.dtype == jnp.bfloat16: - # Tensorstore uses 'bfloat16', not '