Skip to content

Commit

Permalink
Allow rechunking to groups (#144)
Browse files Browse the repository at this point in the history
* allow rechunking to groups

* remove print

* install all deps

* add tests for error states

* typos [ci skip]
  • Loading branch information
rabernat committed Aug 12, 2023
1 parent ba7efc0 commit 0ac43e8
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[dev]
pip install ".[dev,complete]"
- name: Test with pytest
run: |
py.test tests -v --cov=rechunker --cov-config .coveragerc --cov-report term-missing
Expand Down
46 changes: 40 additions & 6 deletions rechunker/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,13 @@ def _encode_zarr_attributes(attrs):

def _zarr_empty(shape, store_or_group, chunks, dtype, name=None, **kwargs):
# wrapper that maybe creates the array within a group
if name is not None:
assert isinstance(store_or_group, zarr.hierarchy.Group)
if isinstance(store_or_group, zarr.Group):
assert name is not None
return store_or_group.empty(
name, shape=shape, chunks=chunks, dtype=dtype, **kwargs
)
else:
# ignore name
return zarr.empty(
shape, chunks=chunks, dtype=dtype, store=store_or_group, **kwargs
)
Expand Down Expand Up @@ -222,6 +223,7 @@ def rechunk(
temp_store=None,
temp_options=None,
executor: Union[str, CopySpecExecutor] = "dask",
array_name=None,
) -> Rechunked:
"""
Rechunk a Zarr Array or Group, a Dask Array, or an Xarray Dataset
Expand Down Expand Up @@ -287,6 +289,10 @@ def rechunk(
* python
* pywren
array_name: str, optional
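Name under which the array is stored. Required when rechunking a single
array if the target store or the temp store is a zarr Group; must not be
given when the source is a Group or an Xarray Dataset.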
Returns
-------
rechunked : :class:`Rechunked` object
Expand All @@ -302,6 +308,7 @@ def rechunk(
target_options=target_options,
temp_store=temp_store,
temp_options=temp_options,
array_name=array_name,
)
plan = executor.prepare_plan(copy_spec)
return Rechunked(executor, plan, source, intermediate, target)
Expand Down Expand Up @@ -363,6 +370,7 @@ def _setup_rechunk(
target_options=None,
temp_store=None,
temp_options=None,
array_name=None,
):
if temp_options is None:
temp_options = target_options
Expand All @@ -387,15 +395,26 @@ def _setup_rechunk(
raise ValueError(
"You must specify ``target-chunks`` as a dict when rechunking a dataset."
)
if array_name is not None:
raise ValueError(
"Can't specify `array_name` when rechunking an Xarray Dataset."
)

variables, attrs = encode_dataset_coordinates(source)
attrs = _encode_zarr_attributes(attrs)

if temp_store is not None:
temp_group = zarr.group(temp_store)
if isinstance(temp_store, zarr.Group):
temp_group = temp_store
else:
temp_group = zarr.group(temp_store)
else:
temp_group = None
target_group = zarr.group(target_store)

if isinstance(target_store, zarr.Group):
target_group = target_store
else:
target_group = zarr.group(target_store)
target_group.attrs.update(attrs)

# if ``target_chunks`` is specified per dimension (xarray ``.rechunk`` style),
Expand Down Expand Up @@ -462,12 +481,21 @@ def _setup_rechunk(
raise ValueError(
"You must specify ``target-chunks`` as a dict when rechunking a group."
)
if array_name is not None:
raise ValueError("Can't specify `array_name` when rechunking a Group.")

if temp_store is not None:
temp_group = zarr.group(temp_store)
if isinstance(temp_store, zarr.Group):
temp_group = temp_store
else:
temp_group = zarr.group(temp_store)
else:
temp_group = None
target_group = zarr.group(target_store)

if isinstance(target_store, zarr.Group):
target_group = target_store
else:
target_group = zarr.group(target_store)
_copy_group_attributes(source, target_group)
target_group.attrs.update(source.attrs)

Expand All @@ -488,6 +516,11 @@ def _setup_rechunk(
return copy_specs, temp_group, target_group

elif isinstance(source, (zarr.core.Array, dask.array.Array)):
if (
isinstance(target_store, zarr.Group) or isinstance(temp_store, zarr.Group)
) and array_name is None:
raise ValueError("Can't rechunk to a group without a name for the array.")

copy_spec = _setup_array_rechunk(
source,
target_chunks,
Expand All @@ -496,6 +529,7 @@ def _setup_rechunk(
target_options=target_options,
temp_store_or_group=temp_store,
temp_options=temp_options,
name=array_name,
)
intermediate = copy_spec.intermediate.array
target = copy_spec.write.array
Expand Down
136 changes: 105 additions & 31 deletions tests/test_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,32 @@ def test_get_dim_chunk(dask_chunks, chunk_ds, dim, target_chunks, expected):
assert chunk == expected


@pytest.fixture(params=["string_path", "mapper", "group"])
def target_store(tmp_path, request):
    """Provide the rechunk target as a string path, an FSStore, or a zarr Group.

    The fixture is parametrized, so every test requesting it runs once per
    variant. The two FSStore-backed variants are skipped when ``fsspec``
    cannot be imported.
    """
    if request.param == "mapper":
        pytest.importorskip("fsspec")
        # Join with ``/`` (not string ``+``) so the store is created inside
        # tmp_path rather than as a sibling whose name merely starts with it.
        return FSStore(str(tmp_path / "target.zarr"))
    elif request.param == "group":
        pytest.importorskip("fsspec")
        store = FSStore(str(tmp_path / "group.target.zarr"))
        return zarr.group(store)
    else:
        return str(tmp_path / "mapper.target.zarr")


@pytest.fixture(params=["string_path", "mapper", "group"])
def temp_store(tmp_path, request):
    """Provide the intermediate store as a string path, an FSStore, or a zarr Group.

    The fixture is parametrized, so every test requesting it runs once per
    variant. The two FSStore-backed variants are skipped when ``fsspec``
    cannot be imported.
    """
    if request.param == "mapper":
        pytest.importorskip("fsspec")
        # Join with ``/`` (not string ``+``) so the store is created inside
        # tmp_path rather than as a sibling whose name merely starts with it.
        return FSStore(str(tmp_path / "temp.zarr"))
    elif request.param == "group":
        pytest.importorskip("fsspec")
        store = FSStore(str(tmp_path / "group.temp.zarr"))
        return zarr.group(store)
    else:
        return str(tmp_path / "mapper.temp.zarr")


@pytest.mark.parametrize("shape", [(100, 50)])
@pytest.mark.parametrize("source_chunks", [(10, 50)])
@pytest.mark.parametrize(
Expand All @@ -194,10 +220,7 @@ def test_get_dim_chunk(dask_chunks, chunk_ds, dim, target_chunks, expected):
)
@pytest.mark.parametrize("max_mem", ["10MB"])
@pytest.mark.parametrize("executor", ["dask", "python", requires_prefect("prefect")])
@pytest.mark.parametrize("target_store", ["target.zarr", "mapper.target.zarr"])
@pytest.mark.parametrize("temp_store", ["temp.zarr", "mapper.temp.zarr"])
def test_rechunk_dataset(
tmp_path,
shape,
source_chunks,
target_chunks,
Expand All @@ -208,14 +231,6 @@ def test_rechunk_dataset(
):
xarray = pytest.importorskip("xarray")

if target_store.startswith("mapper"):
pytest.importorskip("fsspec")
target_store = FSStore(str(tmp_path) + target_store)
temp_store = FSStore(str(tmp_path) + temp_store)
else:
target_store = str(tmp_path / target_store)
temp_store = str(tmp_path / temp_store)

ds = example_dataset(shape).chunk(chunks=dict(zip(["x", "y"], source_chunks)))
options = dict(
a=dict(
Expand All @@ -238,14 +253,19 @@ def test_rechunk_dataset(
with dask.config.set(scheduler="single-threaded"):
rechunked.execute()

if isinstance(target_store, zarr.Group):
thing_to_open = target_store.store
else:
thing_to_open = target_store

# Validate encoded variables
dst = xarray.open_zarr(target_store, decode_cf=False)
dst = xarray.open_zarr(thing_to_open, decode_cf=False)
assert dst.a.dtype == options["a"]["dtype"]
assert all(dst.a.values[-1] == options["a"]["_FillValue"])
assert dst.a.encoding["compressor"] is not None

# Validate decoded variables
dst = xarray.open_zarr(target_store, decode_cf=True)
dst = xarray.open_zarr(thing_to_open, decode_cf=True)
target_chunks_expected = (
target_chunks["a"]
if isinstance(target_chunks["a"], tuple)
Expand Down Expand Up @@ -351,7 +371,16 @@ def test_rechunk_dataset_dimchunks(
],
)
def test_rechunk_array(
tmp_path, shape, source_chunks, dtype, dims, target_chunks, max_mem, executor
tmp_path,
shape,
source_chunks,
dtype,
dims,
target_chunks,
max_mem,
executor,
target_store,
temp_store,
):
### Create source array ###
store_source = str(tmp_path / "source.zarr")
Expand All @@ -363,9 +392,10 @@ def test_rechunk_array(
if dims:
source_array.attrs[_DIMENSION_KEY] = dims

### Create targets ###
target_store = str(tmp_path / "target.zarr")
temp_store = str(tmp_path / "temp.zarr")
if isinstance(target_store, zarr.Group) or isinstance(temp_store, zarr.Group):
array_name = "_temp_array"
else:
array_name = None

rechunked = api.rechunk(
source_array,
Expand All @@ -374,10 +404,14 @@ def test_rechunk_array(
target_store,
temp_store=temp_store,
executor=executor,
array_name=array_name,
)
assert isinstance(rechunked, api.Rechunked)

target_array = zarr.open(target_store)
if isinstance(target_store, zarr.Group):
target_array = target_store[array_name]
else:
target_array = zarr.open(target_store, mode="r")

if isinstance(target_chunks, dict):
target_chunks_list = [target_chunks[d] for d in dims]
Expand Down Expand Up @@ -406,21 +440,37 @@ def test_rechunk_array(
],
)
def test_rechunk_dask_array(
tmp_path, shape, source_chunks, dtype, target_chunks, max_mem
tmp_path,
shape,
source_chunks,
dtype,
target_chunks,
max_mem,
target_store,
temp_store,
):
### Create source array ###
source_array = dsa.ones(shape, chunks=source_chunks, dtype=dtype)

### Create targets ###
target_store = str(tmp_path / "target.zarr")
temp_store = str(tmp_path / "temp.zarr")
if isinstance(target_store, zarr.Group) or isinstance(temp_store, zarr.Group):
array_name = "_temp_array"
else:
array_name = None

rechunked = api.rechunk(
source_array, target_chunks, max_mem, target_store, temp_store=temp_store
source_array,
target_chunks,
max_mem,
target_store,
temp_store=temp_store,
array_name=array_name,
)
assert isinstance(rechunked, api.Rechunked)

target_array = zarr.open(target_store)
if isinstance(target_store, zarr.Group):
target_array = target_store[array_name]
else:
target_array = zarr.open(target_store, mode="r")

assert target_array.chunks == tuple(target_chunks)

Expand All @@ -440,18 +490,12 @@ def test_rechunk_dask_array(
],
)
@pytest.mark.parametrize("source_store", ["source.zarr", "mapper.source.zarr"])
@pytest.mark.parametrize("target_store", ["target.zarr", "mapper.target.zarr"])
@pytest.mark.parametrize("temp_store", ["temp.zarr", "mapper.temp.zarr"])
def test_rechunk_group(tmp_path, executor, source_store, target_store, temp_store):
if source_store.startswith("mapper"):
pytest.importorskip("fsspec")
store_source = FSStore(str(tmp_path) + source_store)
target_store = FSStore(str(tmp_path) + target_store)
temp_store = FSStore(str(tmp_path) + temp_store)
else:
store_source = str(tmp_path / source_store)
target_store = str(tmp_path / target_store)
temp_store = str(tmp_path / temp_store)

group = zarr.group(store_source, overwrite=True)
group.create_group("foo/bar/baz")
Expand Down Expand Up @@ -481,7 +525,12 @@ def test_rechunk_group(tmp_path, executor, source_store, target_store, temp_stor
)
assert isinstance(rechunked, api.Rechunked)

target_group = zarr.open(target_store)
if isinstance(target_store, zarr.Group):
thing_to_open = target_store.store
else:
thing_to_open = target_store

target_group = zarr.open(thing_to_open, mode="r")
assert "a" in target_group
assert "foo/bar/baz/b" in target_group
assert dict(group.attrs) == dict(target_group.attrs)
Expand Down Expand Up @@ -741,3 +790,28 @@ def test_no_intermediate_fused(tmp_path):
# rechunked.plan is a list of dask delayed objects
num_tasks = len([v for v in rechunked.plan[0].dask.values() if dask.core.istask(v)])
assert num_tasks < 20 # less than if no fuse


def test_rechunk_array_to_group_no_name(tmp_path):
    """Rechunking a single array into a zarr Group without ``array_name`` must fail."""
    source_array = sample_zarr_array(tmp_path)
    group_target = zarr.group(str(tmp_path) + "/group.zarr")
    expected_message = "without a name for the array"

    # Keep the source chunking; only the missing name should trigger the error.
    with pytest.raises(ValueError, match=expected_message):
        api.rechunk(source_array, source_array.chunks, "100MB", group_target)


def test_rechunk_group_to_group_with_name(tmp_path):
    """Passing ``array_name`` while rechunking a Group into a Group must fail."""
    source_group = sample_zarr_group(tmp_path)
    # Keep every array's existing chunking; only the stray name should be rejected.
    target_chunks = {aname: source_group[aname].chunks for aname in source_group}
    max_mem = "100MB"
    target_group = zarr.group(str(tmp_path) + "/group.zarr")
    with pytest.raises(ValueError, match="Can't specify `array_name`"):
        # Positional order is (source, target_chunks, max_mem, target_store);
        # max_mem is passed exactly once so target_group really is the target.
        api.rechunk(
            source_group,
            target_chunks,
            max_mem,
            target_group,
            array_name="foo",
        )

0 comments on commit 0ac43e8

Please sign in to comment.