Skip to content

Commit

Permalink
Don't allow overwriting indexes with region writes (#8877)
Browse files Browse the repository at this point in the history
* Don't allow overwriting indexes with region writes

Closes #8589

* Fix typing

* one more typing fix
  • Loading branch information
dcherian committed Mar 27, 2024
1 parent 473b87f commit cf36559
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 30 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Expand Up @@ -41,6 +41,8 @@ New Features
Breaking changes
~~~~~~~~~~~~~~~~

- Don't allow overwriting index variables with ``to_zarr`` region writes. (:issue:`8589`, :pull:`8876`).
By `Deepak Cherian <https://github.com/dcherian>`_.

Deprecations
~~~~~~~~~~~~
Expand Down
18 changes: 5 additions & 13 deletions xarray/backends/api.py
Expand Up @@ -1562,24 +1562,19 @@ def _auto_detect_regions(ds, region, open_kwargs):
return region


def _validate_and_autodetect_region(
ds, region, mode, open_kwargs
) -> tuple[dict[str, slice], bool]:
def _validate_and_autodetect_region(ds, region, mode, open_kwargs) -> dict[str, slice]:
if region == "auto":
region = {dim: "auto" for dim in ds.dims}

if not isinstance(region, dict):
raise TypeError(f"``region`` must be a dict, got {type(region)}")

if any(v == "auto" for v in region.values()):
region_was_autodetected = True
if mode != "r+":
raise ValueError(
f"``mode`` must be 'r+' when using ``region='auto'``, got {mode}"
)
region = _auto_detect_regions(ds, region, open_kwargs)
else:
region_was_autodetected = False

for k, v in region.items():
if k not in ds.dims:
Expand Down Expand Up @@ -1612,7 +1607,7 @@ def _validate_and_autodetect_region(
f".drop_vars({non_matching_vars!r})"
)

return region, region_was_autodetected
return region


def _validate_datatypes_for_zarr_append(zstore, dataset):
Expand Down Expand Up @@ -1784,12 +1779,9 @@ def to_zarr(
storage_options=storage_options,
zarr_version=zarr_version,
)
region, region_was_autodetected = _validate_and_autodetect_region(
dataset, region, mode, open_kwargs
)
# drop indices to avoid potential race condition with auto region
if region_was_autodetected:
dataset = dataset.drop_vars(dataset.indexes)
region = _validate_and_autodetect_region(dataset, region, mode, open_kwargs)
# can't modify indexed with region writes
dataset = dataset.drop_vars(dataset.indexes)
if append_dim is not None and append_dim in region:
raise ValueError(
f"cannot list the same dimension in both ``append_dim`` and "
Expand Down
37 changes: 20 additions & 17 deletions xarray/tests/test_backends.py
Expand Up @@ -13,12 +13,12 @@
import tempfile
import uuid
import warnings
from collections.abc import Generator, Iterator
from collections.abc import Generator, Iterator, Mapping
from contextlib import ExitStack
from io import BytesIO
from os import listdir
from pathlib import Path
from typing import TYPE_CHECKING, Any, Final, cast
from typing import TYPE_CHECKING, Any, Final, Literal, cast
from unittest.mock import patch

import numpy as np
Expand Down Expand Up @@ -5651,24 +5651,27 @@ def test_zarr_region_index_write(self, tmp_path):
}
)

ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8))
region_slice = dict(x=slice(2, 4), y=slice(6, 8))
ds_region = 1 + ds.isel(region_slice)

ds.to_zarr(tmp_path / "test.zarr")

with patch.object(
ZarrStore,
"set_variables",
side_effect=ZarrStore.set_variables,
autospec=True,
) as mock:
ds_region.to_zarr(tmp_path / "test.zarr", region="auto", mode="r+")

# should write the data vars but never the index vars with auto mode
for call in mock.call_args_list:
written_variables = call.args[1].keys()
assert "test" in written_variables
assert "x" not in written_variables
assert "y" not in written_variables
region: Mapping[str, slice] | Literal["auto"]
for region in [region_slice, "auto"]: # type: ignore
with patch.object(
ZarrStore,
"set_variables",
side_effect=ZarrStore.set_variables,
autospec=True,
) as mock:
ds_region.to_zarr(tmp_path / "test.zarr", region=region, mode="r+")

# should write the data vars but never the index vars with auto mode
for call in mock.call_args_list:
written_variables = call.args[1].keys()
assert "test" in written_variables
assert "x" not in written_variables
assert "y" not in written_variables

def test_zarr_region_append(self, tmp_path):
x = np.arange(0, 50, 10)
Expand Down

0 comments on commit cf36559

Please sign in to comment.