Skip to content

Commit

Permalink
Move dtype settings out of metadata field into the root of Tensorstor…
Browse files Browse the repository at this point in the history
…e spec

Before, dtype used to be in the metadata field of tensorstore spec because of it was the legacy way to config the dtype.  This setting doesn't understand the "str" name, hence, there was special logic to translate bfloat for example.

This CL moves it out of the metadata field and put the dtype directly into the Tensorstore spec to eliminate special dtype translation logic.  This will also add support of other quantized types such as int4.

PiperOrigin-RevId: 628519426
  • Loading branch information
ChromeHearts authored and jax authors committed May 1, 2024
1 parent e691c19 commit 5aaaa8a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
14 changes: 5 additions & 9 deletions jax/experimental/array_serialization/serialization.py
Expand Up @@ -83,19 +83,11 @@ async def create_async_array_from_callback(


def _get_metadata(arr):
if arr.dtype == jnp.bfloat16:
# Tensorstore uses 'bfloat16', not '<V2'.
dtype = 'bfloat16'
else:
dtype = np.dtype(arr.dtype).str
local_shape = arr.addressable_data(0).shape
return {
'compressor': {
'id': 'zstd'
},
'compressor': {'id': 'zstd'},
'shape': arr.shape,
'chunks': np.array(np.maximum(1, local_shape)),
'dtype': dtype,
}


Expand Down Expand Up @@ -220,6 +212,10 @@ async def async_serialize(
if not _spec_has_metadata(tensorstore_spec):
tensorstore_spec['metadata'] = _get_metadata(arr_inp)

# Set dtype if it's not in spec
if 'dtype' not in tensorstore_spec:
tensorstore_spec['dtype'] = jnp.dtype(arr_inp.dtype).name

# If primary_host is None, all hosts will checkpoint. This is used
# for checkpointing to local filesystem.
if primary_host is None or jax.process_index() == primary_host:
Expand Down
14 changes: 10 additions & 4 deletions jax/experimental/array_serialization/serialization_test.py
Expand Up @@ -205,12 +205,15 @@ def cb3(_):
self.assertArraysEqual(np.asarray(s.data), np.array([], dtype=np.float32))
self.assertEqual(m3.dtype, np.float32)

def test_checkpointing_with_bigger_shape_jax_array(self):
@parameterized.product(input_dtype=[np.int32, jax.numpy.bfloat16])
def test_checkpointing_with_bigger_shape_jax_array(self, input_dtype):
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (8, 2)
num = math.prod(global_input_shape)

global_input_data1 = np.arange(num, dtype=np.int32).reshape(global_input_shape)
global_input_data1 = np.arange(num, dtype=input_dtype).reshape(
global_input_shape
)
def cb1(index):
return global_input_data1[index]
arr = array.make_array_from_callback(
Expand Down Expand Up @@ -250,12 +253,15 @@ def cb1(index):
for l in m2.addressable_shards:
self.assertArraysEqual(l.data, global_input_data1.astype('float32'))

def test_checkpointing_with_int4(self):
@parameterized.product(input_dtype=[jax.numpy.int4, jax.numpy.int8])
def test_checkpointing_with_int4(self, input_dtype):
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (8, 2)
num = math.prod(global_input_shape)

global_input_data = np.arange(num, dtype=jax.numpy.int8).reshape(global_input_shape)
global_input_data = np.arange(num, dtype=input_dtype).reshape(
global_input_shape
)
def cb(index):
return global_input_data[index]
arr = array.make_array_from_callback(
Expand Down

0 comments on commit 5aaaa8a

Please sign in to comment.