Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't allow overwriting indexes with region writes #8877

Merged
merged 3 commits into from Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Expand Up @@ -38,6 +38,8 @@ New Features
Breaking changes
~~~~~~~~~~~~~~~~

- Don't allow overwriting index variables with ``to_zarr`` region writes. (:issue:`8589`, :pull:`8877`).
By `Deepak Cherian <https://github.com/dcherian>`_.

Deprecations
~~~~~~~~~~~~
Expand Down
18 changes: 5 additions & 13 deletions xarray/backends/api.py
Expand Up @@ -1562,24 +1562,19 @@ def _auto_detect_regions(ds, region, open_kwargs):
return region


def _validate_and_autodetect_region(
ds, region, mode, open_kwargs
) -> tuple[dict[str, slice], bool]:
def _validate_and_autodetect_region(ds, region, mode, open_kwargs) -> dict[str, slice]:
if region == "auto":
region = {dim: "auto" for dim in ds.dims}

if not isinstance(region, dict):
raise TypeError(f"``region`` must be a dict, got {type(region)}")

if any(v == "auto" for v in region.values()):
region_was_autodetected = True
if mode != "r+":
raise ValueError(
f"``mode`` must be 'r+' when using ``region='auto'``, got {mode}"
)
region = _auto_detect_regions(ds, region, open_kwargs)
else:
region_was_autodetected = False

for k, v in region.items():
if k not in ds.dims:
Expand Down Expand Up @@ -1612,7 +1607,7 @@ def _validate_and_autodetect_region(
f".drop_vars({non_matching_vars!r})"
)

return region, region_was_autodetected
return region


def _validate_datatypes_for_zarr_append(zstore, dataset):
Expand Down Expand Up @@ -1784,12 +1779,9 @@ def to_zarr(
storage_options=storage_options,
zarr_version=zarr_version,
)
region, region_was_autodetected = _validate_and_autodetect_region(
dataset, region, mode, open_kwargs
)
# drop indices to avoid potential race condition with auto region
if region_was_autodetected:
dataset = dataset.drop_vars(dataset.indexes)
region = _validate_and_autodetect_region(dataset, region, mode, open_kwargs)
# can't modify indexes with region writes
dataset = dataset.drop_vars(dataset.indexes)
if append_dim is not None and append_dim in region:
raise ValueError(
f"cannot list the same dimension in both ``append_dim`` and "
Expand Down
37 changes: 20 additions & 17 deletions xarray/tests/test_backends.py
Expand Up @@ -13,12 +13,12 @@
import tempfile
import uuid
import warnings
from collections.abc import Generator, Iterator
from collections.abc import Generator, Iterator, Mapping
from contextlib import ExitStack
from io import BytesIO
from os import listdir
from pathlib import Path
from typing import TYPE_CHECKING, Any, Final, cast
from typing import TYPE_CHECKING, Any, Final, Literal, cast
from unittest.mock import patch

import numpy as np
Expand Down Expand Up @@ -3935,7 +3935,7 @@
["combine_attrs", "attrs", "expected", "expect_error"],
(
pytest.param("drop", [{"a": 1}, {"a": 2}], {}, False, id="drop"),
pytest.param(

Check failure on line 3938 in xarray/tests/test_backends.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.12

TestOpenMFDatasetWithDataVarsAndCoordsKw.test_open_mfdataset_does_same_as_concat[inner-minimal-by_coords-None] RuntimeError: NetCDF: Not a valid ID

Check failure on line 3938 in xarray/tests/test_backends.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.12

TestOpenMFDatasetWithDataVarsAndCoordsKw.test_open_mfdataset_does_same_as_concat[inner-different-nested-t] RuntimeError: NetCDF: Not a valid ID

Check failure on line 3938 in xarray/tests/test_backends.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.12

TestOpenMFDatasetWithDataVarsAndCoordsKw.test_open_mfdataset_does_same_as_concat[left-all-by_coords-None] RuntimeError: NetCDF: Not a valid ID
"override", [{"a": 1}, {"a": 2}], {"a": 1}, False, id="override"
),
pytest.param(
Expand Down Expand Up @@ -5641,24 +5641,27 @@
}
)

ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8))
region_slice = dict(x=slice(2, 4), y=slice(6, 8))
ds_region = 1 + ds.isel(region_slice)

ds.to_zarr(tmp_path / "test.zarr")

with patch.object(
ZarrStore,
"set_variables",
side_effect=ZarrStore.set_variables,
autospec=True,
) as mock:
ds_region.to_zarr(tmp_path / "test.zarr", region="auto", mode="r+")

# should write the data vars but never the index vars with auto mode
for call in mock.call_args_list:
written_variables = call.args[1].keys()
assert "test" in written_variables
assert "x" not in written_variables
assert "y" not in written_variables
region: Mapping[str, slice] | Literal["auto"]
for region in [region_slice, "auto"]: # type: ignore
with patch.object(
ZarrStore,
"set_variables",
side_effect=ZarrStore.set_variables,
autospec=True,
) as mock:
ds_region.to_zarr(tmp_path / "test.zarr", region=region, mode="r+")

# should write the data vars but never the index vars, for both explicit and auto regions
for call in mock.call_args_list:
written_variables = call.args[1].keys()
assert "test" in written_variables
assert "x" not in written_variables
assert "y" not in written_variables

def test_zarr_region_append(self, tmp_path):
x = np.arange(0, 50, 10)
Expand Down