diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c1bfaba8756..cdc53685895 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,6 +38,8 @@ New Features Breaking changes ~~~~~~~~~~~~~~~~ +- Don't allow overwriting index variables with ``to_zarr`` region writes (:issue:`8589`, :pull:`8876`). + By `Deepak Cherian <https://github.com/dcherian>`_. Deprecations ~~~~~~~~~~~~ diff --git a/xarray/backends/api.py b/xarray/backends/api.py index d3026a535e2..24fc8b116d3 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1562,9 +1562,7 @@ def _auto_detect_regions(ds, region, open_kwargs): return region -def _validate_and_autodetect_region( - ds, region, mode, open_kwargs -) -> tuple[dict[str, slice], bool]: +def _validate_and_autodetect_region(ds, region, mode, open_kwargs) -> dict[str, slice]: if region == "auto": region = {dim: "auto" for dim in ds.dims} @@ -1572,14 +1570,11 @@ def _validate_and_autodetect_region( raise TypeError(f"``region`` must be a dict, got {type(region)}") if any(v == "auto" for v in region.values()): - region_was_autodetected = True if mode != "r+": raise ValueError( f"``mode`` must be 'r+' when using ``region='auto'``, got {mode}" ) region = _auto_detect_regions(ds, region, open_kwargs) - else: - region_was_autodetected = False for k, v in region.items(): if k not in ds.dims: @@ -1612,7 +1607,7 @@ def _validate_and_autodetect_region( f".drop_vars({non_matching_vars!r})" ) - return region, region_was_autodetected + return region def _validate_datatypes_for_zarr_append(zstore, dataset): @@ -1784,12 +1779,9 @@ def to_zarr( storage_options=storage_options, zarr_version=zarr_version, ) - region, region_was_autodetected = _validate_and_autodetect_region( - dataset, region, mode, open_kwargs - ) - # drop indices to avoid potential race condition with auto region - if region_was_autodetected: - dataset = dataset.drop_vars(dataset.indexes) + region = _validate_and_autodetect_region(dataset, region, mode, open_kwargs) + # can't modify indexes with region 
writes + dataset = dataset.drop_vars(dataset.indexes) if append_dim is not None and append_dim in region: raise ValueError( f"cannot list the same dimension in both ``append_dim`` and " diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 3fb137977e8..5d2fefecf48 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -13,12 +13,12 @@ import tempfile import uuid import warnings -from collections.abc import Generator, Iterator +from collections.abc import Generator, Iterator, Mapping from contextlib import ExitStack from io import BytesIO from os import listdir from pathlib import Path -from typing import TYPE_CHECKING, Any, Final, cast +from typing import TYPE_CHECKING, Any, Final, Literal, cast from unittest.mock import patch import numpy as np @@ -5641,24 +5641,27 @@ def test_zarr_region_index_write(self, tmp_path): } ) - ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8)) + region_slice = dict(x=slice(2, 4), y=slice(6, 8)) + ds_region = 1 + ds.isel(region_slice) ds.to_zarr(tmp_path / "test.zarr") - with patch.object( - ZarrStore, - "set_variables", - side_effect=ZarrStore.set_variables, - autospec=True, - ) as mock: - ds_region.to_zarr(tmp_path / "test.zarr", region="auto", mode="r+") - - # should write the data vars but never the index vars with auto mode - for call in mock.call_args_list: - written_variables = call.args[1].keys() - assert "test" in written_variables - assert "x" not in written_variables - assert "y" not in written_variables + region: Mapping[str, slice] | Literal["auto"] + for region in [region_slice, "auto"]: # type: ignore + with patch.object( + ZarrStore, + "set_variables", + side_effect=ZarrStore.set_variables, + autospec=True, + ) as mock: + ds_region.to_zarr(tmp_path / "test.zarr", region=region, mode="r+") + + # should write the data vars but never the index vars, for both explicit and auto regions + for call in mock.call_args_list: + written_variables = call.args[1].keys() + assert "test" in 
written_variables + assert "x" not in written_variables + assert "y" not in written_variables def test_zarr_region_append(self, tmp_path): x = np.arange(0, 50, 10)