diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index 2be128b72b0..2e615f3d429 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -127,7 +127,7 @@ jobs: python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report xarray/ - name: Upload mypy coverage to Codecov - uses: codecov/codecov-action@v4.1.0 + uses: codecov/codecov-action@v4.1.1 with: file: mypy_report/cobertura.xml flags: mypy @@ -181,7 +181,7 @@ jobs: python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report xarray/ - name: Upload mypy coverage to Codecov - uses: codecov/codecov-action@v4.1.0 + uses: codecov/codecov-action@v4.1.1 with: file: mypy_report/cobertura.xml flags: mypy39 @@ -242,7 +242,7 @@ jobs: python -m pyright xarray/ - name: Upload pyright coverage to Codecov - uses: codecov/codecov-action@v4.1.0 + uses: codecov/codecov-action@v4.1.1 with: file: pyright_report/cobertura.xml flags: pyright @@ -301,7 +301,7 @@ jobs: python -m pyright xarray/ - name: Upload pyright coverage to Codecov - uses: codecov/codecov-action@v4.1.0 + uses: codecov/codecov-action@v4.1.1 with: file: pyright_report/cobertura.xml flags: pyright39 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 459660e2bfa..da7402a0708 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -151,7 +151,7 @@ jobs: path: pytest.xml - name: Upload code coverage to Codecov - uses: codecov/codecov-action@v4.1.0 + uses: codecov/codecov-action@v4.1.1 with: file: ./coverage.xml flags: unittests diff --git a/.github/workflows/upstream-dev-ci.yaml b/.github/workflows/upstream-dev-ci.yaml index 872b2d865fb..f7a98206167 100644 --- a/.github/workflows/upstream-dev-ci.yaml +++ b/.github/workflows/upstream-dev-ci.yaml @@ -143,7 +143,7 @@ jobs: run: | python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report - name: Upload mypy coverage to Codecov - uses: codecov/codecov-action@v4.1.0 + uses: codecov/codecov-action@v4.1.1 with: file: mypy_report/cobertura.xml flags: mypy diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 74d77e2f2ca..970b2e5e8ca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,13 +13,13 @@ repos: - id: mixed-line-ending - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: 'v0.2.0' + rev: 'v0.3.4' hooks: - id: ruff args: ["--fix", "--show-fixes"] # https://github.com/python/black#version-control-integration - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.1.1 + rev: 24.3.0 hooks: - id: black-jupyter - repo: https://github.com/keewis/blackdoc @@ -27,10 +27,10 @@ repos: hooks: - id: blackdoc exclude: "generate_aggregations.py" - additional_dependencies: ["black==24.1.1"] + additional_dependencies: ["black==24.3.0"] - id: blackdoc-autoupdate-black - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.8.0 + rev: v1.9.0 hooks: - id: mypy # Copied from setup.cfg diff --git a/doc/user-guide/indexing.rst b/doc/user-guide/indexing.rst index fba9dd585ab..0f575160113 100644 --- a/doc/user-guide/indexing.rst +++ b/doc/user-guide/indexing.rst @@ -748,7 +748,7 @@ Whether array indexing returns a view or a copy of the underlying data depends on the nature of the labels. For positional (integer) -indexing, xarray follows the same rules as NumPy: +indexing, xarray follows the same `rules`_ as NumPy: * Positional indexing with only integers and slices returns a view. * Positional indexing with arrays or lists returns a copy. @@ -765,8 +765,10 @@ Whether data is a copy or a view is more predictable in xarray than in pandas, s unlike pandas, xarray does not produce `SettingWithCopy warnings`_. However, you should still avoid assignment with chained indexing. -.. _SettingWithCopy warnings: https://pandas.pydata.org/pandas-docs/stable/indexing.html#returning-a-view-versus-a-copy +Note that other operations (such as :py:meth:`~xarray.DataArray.values`) may also return views rather than copies. +.. _SettingWithCopy warnings: https://pandas.pydata.org/pandas-docs/stable/indexing.html#returning-a-view-versus-a-copy +.. _rules: https://numpy.org/doc/stable/user/basics.copies.html .. _multi-level indexing: diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b81be3c0192..e421eeb3a7f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -15,32 +15,65 @@ What's New np.random.seed(123456) -.. _whats-new.2024.03.0: +.. _whats-new.2024.04.0: -v2024.03.0 (unreleased) +v2024.04.0 (unreleased) ----------------------- New Features ~~~~~~~~~~~~ + +Breaking changes +~~~~~~~~~~~~~~~~ + + +Bug fixes +~~~~~~~~~ + + +Internal Changes +~~~~~~~~~~~~~~~~ + + +.. _whats-new.2024.03.0: + +v2024.03.0 (Mar 29, 2024) +------------------------- + +This release brings performance improvements for grouped and resampled quantile calculations, CF decoding improvements, +minor optimizations to distributed Zarr writes, and compatibility fixes for Numpy 2.0 and Pandas 3.0. + +Thanks to the 18 contributors to this release: +Anderson Banihirwe, Christoph Hasse, Deepak Cherian, Etienne Schalk, Justus Magin, Kai Mühlbauer, Kevin Schwarzwald, Mark Harfouche, Martin, Matt Savoie, Maximilian Roos, Ray Bell, Roberto Chang, Spencer Clark, Tom Nicholas, crusaderky, owenlittlejohns, saschahofmann + +New Features +~~~~~~~~~~~~ +- Partial writes to existing chunks with ``region`` or ``append_dim`` will now raise an error + (unless ``safe_chunks=False``); previously an error would only be raised on + new variables. (:pull:`8459`, :issue:`8371`, :issue:`8882`) + By `Maximilian Roos `_. +- Grouped and resampling quantile calculations now use the vectorized algorithm in ``flox>=0.9.4`` if present. + By `Deepak Cherian `_. - Do not broadcast in arithmetic operations when global option ``arithmetic_broadcast=False`` (:issue:`6806`, :pull:`8784`). By `Etienne Schalk `_ and `Deepak Cherian `_. - Add the ``.oindex`` property to Explicitly Indexed Arrays for orthogonal indexing functionality. (:issue:`8238`, :pull:`8750`) By `Anderson Banihirwe `_. - - Add the ``.vindex`` property to Explicitly Indexed Arrays for vectorized indexing functionality. (:issue:`8238`, :pull:`8780`) By `Anderson Banihirwe `_. - - Expand use of ``.oindex`` and ``.vindex`` properties. (:pull: `8790`) By `Anderson Banihirwe `_ and `Deepak Cherian `_. +- Allow creating :py:class:`xr.Coordinates` objects with no indexes (:pull:`8711`) + By `Benoit Bovy `_ and `Tom Nicholas + `_. +- Enable plotting of ``datetime.dates``. (:issue:`8866`, :pull:`8873`) + By `Sascha Hofmann `_. Breaking changes ~~~~~~~~~~~~~~~~ - - -Deprecations -~~~~~~~~~~~~ +- Don't allow overwriting index variables with ``to_zarr`` region writes. (:issue:`8589`, :pull:`8876`). + By `Deepak Cherian `_. Bug fixes @@ -57,16 +90,29 @@ Bug fixes `CFMaskCoder`/`CFScaleOffsetCoder` (:issue:`2304`, :issue:`5597`, :issue:`7691`, :pull:`8713`, see also discussion in :pull:`7654`). By `Kai Mühlbauer `_. - -Documentation -~~~~~~~~~~~~~ - +- Do not cast `_FillValue`/`missing_value` in `CFMaskCoder` if `_Unsigned` is provided + (:issue:`8844`, :pull:`8852`). +- Adapt handling of copy keyword argument for numpy >= 2.0dev + (:issue:`8844`, :pull:`8851`, :pull:`8865`). + By `Kai Mühlbauer `_. +- Import trapz/trapezoid depending on numpy version + (:issue:`8844`, :pull:`8865`). + By `Kai Mühlbauer `_. +- Warn and return bytes undecoded in case of UnicodeDecodeError in h5netcdf-backend + (:issue:`5563`, :pull:`8874`). + By `Kai Mühlbauer `_. +- Fix bug incorrectly disallowing creation of a dataset with a multidimensional coordinate variable with the same name as one of its dims. + (:issue:`8884`, :pull:`8886`) + By `Tom Nicholas `_. Internal Changes ~~~~~~~~~~~~~~~~ - Migrates ``treenode`` functionality into ``xarray/core`` (:pull:`8757`) By `Matt Savoie `_ and `Tom Nicholas `_. +- Migrates ``datatree`` functionality into ``xarray/core``. (:pull: `8789`) + By `Owen Littlejohns `_, `Matt Savoie + `_ and `Tom Nicholas `_. .. _whats-new.2024.02.0: diff --git a/pyproject.toml b/pyproject.toml index 532dc40e859..751c9085ec8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -171,7 +171,6 @@ module = [ "xarray.tests.test_dask", "xarray.tests.test_dataarray", "xarray.tests.test_duck_array_ops", - "xarray.tests.test_groupby", "xarray.tests.test_indexing", "xarray.tests.test_merge", "xarray.tests.test_missing", diff --git a/xarray/backends/api.py b/xarray/backends/api.py index d3026a535e2..2f73c38341b 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -69,7 +69,7 @@ T_NetcdfTypes = Literal[ "NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC" ] - from xarray.datatree_.datatree import DataTree + from xarray.core.datatree import DataTree DATAARRAY_NAME = "__xarray_dataarray_name__" DATAARRAY_VARIABLE = "__xarray_dataarray_variable__" @@ -1562,9 +1562,7 @@ def _auto_detect_regions(ds, region, open_kwargs): return region -def _validate_and_autodetect_region( - ds, region, mode, open_kwargs -) -> tuple[dict[str, slice], bool]: +def _validate_and_autodetect_region(ds, region, mode, open_kwargs) -> dict[str, slice]: if region == "auto": region = {dim: "auto" for dim in ds.dims} @@ -1572,14 +1570,11 @@ def _validate_and_autodetect_region( raise TypeError(f"``region`` must be a dict, got {type(region)}") if any(v == "auto" for v in region.values()): - region_was_autodetected = True if mode != "r+": raise ValueError( f"``mode`` must be 'r+' when using ``region='auto'``, got {mode}" ) region = _auto_detect_regions(ds, region, open_kwargs) - else: - region_was_autodetected = False for k, v in region.items(): if k not in ds.dims: @@ -1612,7 +1607,7 @@ def _validate_and_autodetect_region( f".drop_vars({non_matching_vars!r})" ) - return region, region_was_autodetected + return region def _validate_datatypes_for_zarr_append(zstore, dataset): @@ -1784,12 +1779,9 @@ def to_zarr( storage_options=storage_options, zarr_version=zarr_version, ) - region, region_was_autodetected = _validate_and_autodetect_region( - dataset, region, mode, open_kwargs - ) - # drop indices to avoid potential race condition with auto region - if region_was_autodetected: - dataset = dataset.drop_vars(dataset.indexes) + region = _validate_and_autodetect_region(dataset, region, mode, open_kwargs) + # can't modify indexed with region writes + dataset = dataset.drop_vars(dataset.indexes) if append_dim is not None and append_dim in region: raise ValueError( f"cannot list the same dimension in both ``append_dim`` and " diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 7d3cc00a52d..f318b4dd42f 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -23,8 +23,8 @@ from netCDF4 import Dataset as ncDataset from xarray.core.dataset import Dataset + from xarray.core.datatree import DataTree from xarray.core.types import NestedSequence - from xarray.datatree_.datatree import DataTree # Create a logger object, but don't add any handlers. Leave that to user code. logger = logging.getLogger(__name__) @@ -137,8 +137,8 @@ def _open_datatree_netcdf( **kwargs, ) -> DataTree: from xarray.backends.api import open_dataset + from xarray.core.datatree import DataTree from xarray.core.treenode import NodePath - from xarray.datatree_.datatree import DataTree ds = open_dataset(filename_or_obj, **kwargs) tree_root = DataTree.from_dict({"/": ds}) diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index b7c1b2a5f03..71463193939 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -28,6 +28,7 @@ from xarray.core import indexing from xarray.core.utils import ( FrozenDict, + emit_user_level_warning, is_remote_uri, read_magic_number_from_file, try_read_magic_number_from_file_or_path, @@ -39,7 +40,7 @@ from xarray.backends.common import AbstractDataStore from xarray.core.dataset import Dataset - from xarray.datatree_.datatree import DataTree + from xarray.core.datatree import DataTree class H5NetCDFArrayWrapper(BaseNetCDF4Array): @@ -58,13 +59,6 @@ def _getitem(self, key): return array[key] -def maybe_decode_bytes(txt): - if isinstance(txt, bytes): - return txt.decode("utf-8") - else: - return txt - - def _read_attributes(h5netcdf_var): # GH451 # to ensure conventions decoding works properly on Python 3, decode all @@ -72,7 +66,16 @@ def _read_attributes(h5netcdf_var): attrs = {} for k, v in h5netcdf_var.attrs.items(): if k not in ["_FillValue", "missing_value"]: - v = maybe_decode_bytes(v) + if isinstance(v, bytes): + try: + v = v.decode("utf-8") + except UnicodeDecodeError: + emit_user_level_warning( + f"'utf-8' codec can't decode bytes for attribute " + f"{k!r} of h5netcdf object {h5netcdf_var.name!r}, " + f"returning bytes undecoded.", + UnicodeWarning, + ) attrs[k] = v return attrs diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 6720a67ae2f..ae86c4ce384 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -45,7 +45,7 @@ from xarray.backends.common import AbstractDataStore from xarray.core.dataset import Dataset - from xarray.datatree_.datatree import DataTree + from xarray.core.datatree import DataTree # This lookup table maps from dtype.byteorder to a readable endian # string used by netCDF4. diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 154d82bb871..f8c486e512c 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -28,6 +28,7 @@ Frozen, FrozenDict, close_on_error, + module_available, try_read_magic_number_from_file_or_path, ) from xarray.core.variable import Variable @@ -39,6 +40,9 @@ from xarray.core.dataset import Dataset +HAS_NUMPY_2_0 = module_available("numpy", minversion="2.0.0.dev0") + + def _decode_string(s): if isinstance(s, bytes): return s.decode("utf-8", "replace") @@ -76,6 +80,12 @@ def __getitem__(self, key): # with the netCDF4 library by ensuring we can safely read arrays even # after closing associated files. copy = self.datastore.ds.use_mmap + + # adapt handling of copy-kwarg to numpy 2.0 + # see https://github.com/numpy/numpy/issues/25916 + # and https://github.com/numpy/numpy/pull/25922 + copy = None if HAS_NUMPY_2_0 and copy is False else copy + return np.array(data, dtype=self.dtype, copy=copy) def __setitem__(self, key, value): diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index e9465dc0ba0..3d6baeefe01 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -34,7 +34,7 @@ from xarray.backends.common import AbstractDataStore from xarray.core.dataset import Dataset - from xarray.datatree_.datatree import DataTree + from xarray.core.datatree import DataTree # need some special secret attributes to tell us the dimensions @@ -195,7 +195,7 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks): f"Writing this array in parallel with dask could lead to corrupted data." ) if safe_chunks: - raise NotImplementedError( + raise ValueError( base_error + " Consider either rechunking using `chunk()`, deleting " "or modifying `encoding['chunks']`, or specify `safe_chunks=False`." @@ -623,7 +623,12 @@ def store( # avoid needing to load index variables into memory. # TODO: consider making loading indexes lazy again? existing_vars, _, _ = conventions.decode_cf_variables( - self.get_variables(), self.get_attrs() + { + k: v + for k, v in self.get_variables().items() + if k in existing_variable_names + }, + self.get_attrs(), ) # Modified variables must use the same encoding as the store. vars_with_encoding = {} @@ -702,6 +707,17 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No if v.encoding == {"_FillValue": None} and fill_value is None: v.encoding = {} + # We need to do this for both new and existing variables to ensure we're not + # writing to a partial chunk, even though we don't use the `encoding` value + # when writing to an existing variable. See + # https://github.com/pydata/xarray/issues/8371 for details. + encoding = extract_zarr_variable_encoding( + v, + raise_on_invalid=check, + name=vn, + safe_chunks=self._safe_chunks, + ) + if name in existing_keys: # existing variable # TODO: if mode="a", consider overriding the existing variable @@ -732,9 +748,6 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No zarr_array = self.zarr_group[name] else: # new variable - encoding = extract_zarr_variable_encoding( - v, raise_on_invalid=check, name=vn, safe_chunks=self._safe_chunks - ) encoded_attrs = {} # the magic for storing the hidden dimension data encoded_attrs[DIMENSION_KEY] = dims @@ -1048,8 +1061,8 @@ def open_datatree( import zarr from xarray.backends.api import open_dataset + from xarray.core.datatree import DataTree from xarray.core.treenode import NodePath - from xarray.datatree_.datatree import DataTree zds = zarr.open_group(filename_or_obj, mode="r") ds = open_dataset(filename_or_obj, engine="zarr", **kwargs) diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index b3b9d8d1041..db95286f6aa 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -15,10 +15,13 @@ unpack_for_encoding, ) from xarray.core import indexing +from xarray.core.utils import module_available from xarray.core.variable import Variable from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.pycompat import is_chunked_array +HAS_NUMPY_2_0 = module_available("numpy", minversion="2.0.0.dev0") + def create_vlen_dtype(element_type): if element_type not in (str, bytes): @@ -156,8 +159,12 @@ def bytes_to_char(arr): def _numpy_bytes_to_char(arr): """Like netCDF4.stringtochar, but faster and more flexible.""" + # adapt handling of copy-kwarg to numpy 2.0 + # see https://github.com/numpy/numpy/issues/25916 + # and https://github.com/numpy/numpy/pull/25922 + copy = None if HAS_NUMPY_2_0 else False # ensure the array is contiguous - arr = np.array(arr, copy=False, order="C", dtype=np.bytes_) + arr = np.array(arr, copy=copy, order="C", dtype=np.bytes_) return arr.reshape(arr.shape + (1,)).view("S1") @@ -199,8 +206,12 @@ def char_to_bytes(arr): def _numpy_char_to_bytes(arr): """Like netCDF4.chartostring, but faster and more flexible.""" + # adapt handling of copy-kwarg to numpy 2.0 + # see https://github.com/numpy/numpy/issues/25916 + # and https://github.com/numpy/numpy/pull/25922 + copy = None if HAS_NUMPY_2_0 else False # based on: http://stackoverflow.com/a/10984878/809705 - arr = np.array(arr, copy=False, order="C") + arr = np.array(arr, copy=copy, order="C") dtype = "S" + str(arr.shape[-1]) return arr.view(dtype).reshape(arr.shape[:-1]) diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 92bce0abeaa..466e847e003 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -446,15 +446,7 @@ def format_cftime_datetime(date) -> str: """Converts a cftime.datetime object to a string with the format: YYYY-MM-DD HH:MM:SS.UUUUUU """ - return "{:04d}-{:02d}-{:02d} {:02d}:{:02d}:{:02d}.{:06d}".format( - date.year, - date.month, - date.day, - date.hour, - date.minute, - date.second, - date.microsecond, - ) + return f"{date.year:04d}-{date.month:02d}-{date.day:02d} {date.hour:02d}:{date.minute:02d}:{date.second:02d}.{date.microsecond:06d}" def infer_timedelta_units(deltas) -> str: diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 3b11e7bfa02..d31cb6e626a 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -81,9 +81,7 @@ def get_duck_array(self): return self.func(self.array.get_duck_array()) def __repr__(self) -> str: - return "{}({!r}, func={!r}, dtype={!r})".format( - type(self).__name__, self.array, self.func, self.dtype - ) + return f"{type(self).__name__}({self.array!r}, func={self.func!r}, dtype={self.dtype!r})" class NativeEndiannessArray(indexing.ExplicitlyIndexedNDArrayMixin): @@ -309,6 +307,9 @@ def encode(self, variable: Variable, name: T_Name = None): dtype = np.dtype(encoding.get("dtype", data.dtype)) fv = encoding.get("_FillValue") mv = encoding.get("missing_value") + # to properly handle _FillValue/missing_value below [a], [b] + # we need to check if unsigned data is written as signed data + unsigned = encoding.get("_Unsigned") is not None fv_exists = fv is not None mv_exists = mv is not None @@ -323,13 +324,19 @@ def encode(self, variable: Variable, name: T_Name = None): if fv_exists: # Ensure _FillValue is cast to same dtype as data's - encoding["_FillValue"] = dtype.type(fv) + # [a] need to skip this if _Unsigned is available + if not unsigned: + encoding["_FillValue"] = dtype.type(fv) fill_value = pop_to(encoding, attrs, "_FillValue", name=name) if mv_exists: # try to use _FillValue, if it exists to align both values # or use missing_value and ensure it's cast to same dtype as data's - encoding["missing_value"] = attrs.get("_FillValue", dtype.type(mv)) + # [b] need to provide mv verbatim if _Unsigned is available + encoding["missing_value"] = attrs.get( + "_FillValue", + (dtype.type(mv) if not unsigned else mv), + ) fill_value = pop_to(encoding, attrs, "missing_value", name=name) # apply fillna @@ -522,7 +529,6 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable: def decode(self, variable: Variable, name: T_Name = None) -> Variable: if "_Unsigned" in variable.attrs: dims, data, attrs, encoding = unpack_for_decoding(variable) - unsigned = pop_to(attrs, encoding, "_Unsigned") if data.dtype.kind == "i": diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index bee6afd5a19..96f860b3209 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -2392,8 +2392,6 @@ def count( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -2490,8 +2488,6 @@ def all( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -2588,8 +2584,6 @@ def any( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -2692,8 +2686,6 @@ def max( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -2808,8 +2800,6 @@ def min( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -2924,8 +2914,6 @@ def mean( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -3049,8 +3037,6 @@ def prod( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -3186,8 +3172,6 @@ def sum( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -3320,8 +3304,6 @@ def std( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -3454,8 +3436,6 @@ def var( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -3584,8 +3564,6 @@ def median( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -3687,8 +3665,6 @@ def cumsum( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -3788,8 +3764,6 @@ def cumprod( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -3919,8 +3893,6 @@ def count( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -4017,8 +3989,6 @@ def all( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -4115,8 +4085,6 @@ def any( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -4219,8 +4187,6 @@ def max( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -4335,8 +4301,6 @@ def min( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -4451,8 +4415,6 @@ def mean( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -4576,8 +4538,6 @@ def prod( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -4713,8 +4673,6 @@ def sum( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -4847,8 +4805,6 @@ def std( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -4981,8 +4937,6 @@ def var( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -5111,8 +5065,6 @@ def median( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -5214,8 +5166,6 @@ def cumsum( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -5315,8 +5265,6 @@ def cumprod( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -5446,8 +5394,6 @@ def count( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -5537,8 +5483,6 @@ def all( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -5628,8 +5572,6 @@ def any( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -5725,8 +5667,6 @@ def max( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -5832,8 +5772,6 @@ def min( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -5939,8 +5877,6 @@ def mean( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -6055,8 +5991,6 @@ def prod( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -6181,8 +6115,6 @@ def sum( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -6304,8 +6236,6 @@ def std( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -6427,8 +6357,6 @@ def var( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -6546,8 +6474,6 @@ def median( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -6641,8 +6567,6 @@ def cumsum( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -6738,8 +6662,6 @@ def cumprod( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -6865,8 +6787,6 @@ def count( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -6956,8 +6876,6 @@ def all( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -7047,8 +6965,6 @@ def any( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -7144,8 +7060,6 @@ def max( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -7251,8 +7165,6 @@ def min( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -7358,8 +7270,6 @@ def mean( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -7474,8 +7384,6 @@ def prod( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -7600,8 +7508,6 @@ def sum( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -7723,8 +7629,6 @@ def std( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -7846,8 +7750,6 @@ def var( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -7965,8 +7867,6 @@ def median( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -8060,8 +7960,6 @@ def cumsum( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -8157,8 +8055,6 @@ def cumprod( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. diff --git a/xarray/core/arithmetic.py b/xarray/core/arithmetic.py index 452c7115b75..734d7b328de 100644 --- a/xarray/core/arithmetic.py +++ b/xarray/core/arithmetic.py @@ -62,10 +62,10 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): if method != "__call__": # TODO: support other methods, e.g., reduce and accumulate. raise NotImplementedError( - "{} method for ufunc {} is not implemented on xarray objects, " + f"{method} method for ufunc {ufunc} is not implemented on xarray objects, " "which currently only support the __call__ method. As an " "alternative, consider explicitly converting xarray objects " - "to NumPy arrays (e.g., with `.values`).".format(method, ufunc) + "to NumPy arrays (e.g., with `.values`)." ) if any(isinstance(o, SupportsArithmetic) for o in out): diff --git a/xarray/core/computation.py b/xarray/core/computation.py index f29f6c4dd35..f09b04b7765 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -133,11 +133,7 @@ def __ne__(self, other): return not self == other def __repr__(self): - return "{}({!r}, {!r})".format( - type(self).__name__, - list(self.input_core_dims), - list(self.output_core_dims), - ) + return f"{type(self).__name__}({list(self.input_core_dims)!r}, {list(self.output_core_dims)!r})" def __str__(self): lhs = ",".join("({})".format(",".join(dims)) for dims in self.input_core_dims) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 2adc4527285..251edd1fc6f 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -298,7 +298,7 @@ def __init__( else: variables = {} for name, data in coords.items(): - var = as_variable(data, name=name) + var = as_variable(data, name=name, auto_convert=False) if var.dims == (name,) and indexes is None: index, index_vars = create_default_index_implicit(var, list(coords)) default_indexes.update({k: index for k in index_vars}) @@ -998,9 +998,12 @@ def create_coords_with_default_indexes( if isinstance(obj, DataArray): dataarray_coords.append(obj.coords) - variable = as_variable(obj, name=name) + variable = as_variable(obj, name=name, auto_convert=False) if variable.dims == (name,): + # still needed to convert to IndexVariable first due to some + # pandas multi-index edge cases. + variable = variable.to_index_variable() idx, idx_vars = create_default_index_implicit(variable, all_variables) indexes.update({k: idx for k in idx_vars}) variables.update(idx_vars) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 7a0bdbc4d4c..80dcfe1302c 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -159,7 +159,9 @@ def _infer_coords_and_dims( dims = list(coords.keys()) else: for n, (dim, coord) in enumerate(zip(dims, coords)): - coord = as_variable(coord, name=dims[n]).to_index_variable() + coord = as_variable( + coord, name=dims[n], auto_convert=False + ).to_index_variable() dims[n] = coord.name dims_tuple = tuple(dims) if len(dims_tuple) != len(shape): @@ -179,10 +181,12 @@ def _infer_coords_and_dims( new_coords = {} if utils.is_dict_like(coords): for k, v in coords.items(): - new_coords[k] = as_variable(v, name=k) + new_coords[k] = as_variable(v, name=k, auto_convert=False) + if new_coords[k].dims == (k,): + new_coords[k] = new_coords[k].to_index_variable() elif coords is not None: for dim, coord in zip(dims_tuple, coords): - var = as_variable(coord, name=dim) + var = as_variable(coord, name=dim, auto_convert=False) var.dims = (dim,) new_coords[dim] = var.to_index_variable() @@ -204,11 +208,17 @@ def _check_data_shape( return data else: data_shape = tuple( - as_variable(coords[k], k).size if k in coords.keys() else 1 + ( + as_variable(coords[k], k, auto_convert=False).size + if k in coords.keys() + else 1 + ) for k in dims ) else: - data_shape = tuple(as_variable(coord, "foo").size for coord in coords) + data_shape = tuple( + as_variable(coord, "foo", auto_convert=False).size for coord in coords + ) data = np.full(data_shape, data) return data @@ -761,11 +771,15 @@ def data(self, value: Any) -> None: @property def values(self) -> np.ndarray: """ - The array's data as a numpy.ndarray. + The array's data converted to numpy.ndarray. + + This will attempt to convert the array naively using np.array(), + which will raise an error if the array type does not support + coercion like this (e.g. cupy). - If the array's data is not a numpy.ndarray this will attempt to convert - it naively using np.array(), which will raise an error if the array - type does not support coercion like this (e.g. cupy). + Note that this array is not copied; operations on it follow + numpy's rules of what generates a view vs. a copy, and changes + to this array may be reflected in the DataArray as well. """ return self.variable.values @@ -4106,7 +4120,7 @@ def to_zarr( compute: Literal[True] = True, consolidated: bool | None = None, append_dim: Hashable | None = None, - region: Mapping[str, slice] | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, safe_chunks: bool = True, storage_options: dict[str, str] | None = None, zarr_version: int | None = None, @@ -4126,7 +4140,7 @@ def to_zarr( compute: Literal[False], consolidated: bool | None = None, append_dim: Hashable | None = None, - region: Mapping[str, slice] | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, safe_chunks: bool = True, storage_options: dict[str, str] | None = None, zarr_version: int | None = None, @@ -4144,7 +4158,7 @@ def to_zarr( compute: bool = True, consolidated: bool | None = None, append_dim: Hashable | None = None, - region: Mapping[str, slice] | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, safe_chunks: bool = True, storage_options: dict[str, str] | None = None, zarr_version: int | None = None, @@ -4223,6 +4237,12 @@ def to_zarr( in with ``region``, use a separate call to ``to_zarr()`` with ``compute=False``. See "Appending to existing Zarr stores" in the reference documentation for full details. + + Users are expected to ensure that the specified region aligns with + Zarr chunk boundaries, and that dask chunks are also aligned. + Xarray makes limited checks that these multiple chunk boundaries line up. + It is possible to write incomplete chunks and corrupt the data with this + option if you are not careful. safe_chunks : bool, default: True If True, only allow writes to when there is a many-to-one relationship between Zarr chunks (specified in encoding) and Dask chunks. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b4c00b66ed8..2c0b3e89722 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2452,6 +2452,12 @@ def to_zarr( in with ``region``, use a separate call to ``to_zarr()`` with ``compute=False``. See "Appending to existing Zarr stores" in the reference documentation for full details. + + Users are expected to ensure that the specified region aligns with + Zarr chunk boundaries, and that dask chunks are also aligned. + Xarray makes limited checks that these multiple chunk boundaries line up. + It is possible to write incomplete chunks and corrupt the data with this + option if you are not careful. safe_chunks : bool, default: True If True, only allow writes to when there is a many-to-one relationship between Zarr chunks (specified in encoding) and Dask chunks. @@ -5867,7 +5873,7 @@ def drop_vars( for var in names_set: maybe_midx = self._indexes.get(var, None) if isinstance(maybe_midx, PandasMultiIndex): - idx_coord_names = set(maybe_midx.index.names + [maybe_midx.dim]) + idx_coord_names = set(list(maybe_midx.index.names) + [maybe_midx.dim]) idx_other_names = idx_coord_names - set(names_set) other_names.update(idx_other_names) if other_names: diff --git a/xarray/datatree_/datatree/datatree.py b/xarray/core/datatree.py similarity index 89% rename from xarray/datatree_/datatree/datatree.py rename to xarray/core/datatree.py index 10133052185..1b06d87c9fb 100644 --- a/xarray/datatree_/datatree/datatree.py +++ b/xarray/core/datatree.py @@ -2,24 +2,14 @@ import copy import itertools -from collections import OrderedDict +from collections.abc import Hashable, Iterable, Iterator, Mapping, MutableMapping from html import escape from typing import ( TYPE_CHECKING, Any, Callable, - Dict, Generic, - Hashable, - Iterable, - Iterator, - List, - Mapping, - MutableMapping, NoReturn, - Optional, - Set, - Tuple, Union, overload, ) @@ -31,6 +21,7 @@ from xarray.core.indexes import Index, Indexes from xarray.core.merge import dataset_update_method from xarray.core.options import OPTIONS as XR_OPTS +from xarray.core.treenode import NamedNode, NodePath, Tree from xarray.core.utils import ( Default, Frozen, @@ -40,17 +31,22 @@ maybe_wrap_array, ) from xarray.core.variable import Variable - -from . import formatting, formatting_html -from .common import TreeAttrAccessMixin -from .mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree -from .ops import ( +from xarray.datatree_.datatree.common import TreeAttrAccessMixin +from xarray.datatree_.datatree.formatting import datatree_repr +from xarray.datatree_.datatree.formatting_html import ( + datatree_repr as datatree_repr_html, +) +from xarray.datatree_.datatree.mapping import ( + TreeIsomorphismError, + check_isomorphic, + map_over_subtree, +) +from xarray.datatree_.datatree.ops import ( DataTreeArithmeticMixin, MappedDatasetMethodsMixin, MappedDataWithCoords, ) -from .render import RenderTree -from xarray.core.treenode import NamedNode, NodePath, Tree +from xarray.datatree_.datatree.render import RenderTree try: from xarray.core.variable import calculate_dimensions @@ -60,6 +56,7 @@ if TYPE_CHECKING: import pandas as pd + from xarray.core.merge import CoercibleValue from xarray.core.types import ErrorOptions @@ -130,9 +127,9 @@ class DatasetView(Dataset): def __init__( self, - data_vars: Optional[Mapping[Any, Any]] = None, - coords: Optional[Mapping[Any, Any]] = None, - attrs: Optional[Mapping[Any, Any]] = None, + data_vars: Mapping[Any, Any] | None = None, + coords: Mapping[Any, Any] | None = None, + attrs: Mapping[Any, Any] | None = None, ): raise AttributeError("DatasetView objects are not to be initialized directly") @@ -169,33 +166,33 @@ def update(self, other) -> NoReturn: ) # FIXME https://github.com/python/mypy/issues/7328 - @overload - def __getitem__(self, key: Mapping) -> Dataset: # type: ignore[misc] + @overload # type: ignore[override] + def __getitem__(self, key: Mapping) -> Dataset: # type: ignore[overload-overlap] ... @overload - def __getitem__(self, key: Hashable) -> DataArray: # type: ignore[misc] + def __getitem__(self, key: Hashable) -> DataArray: # type: ignore[overload-overlap] ... + # See: https://github.com/pydata/xarray/issues/8855 @overload - def __getitem__(self, key: Any) -> Dataset: - ... + def __getitem__(self, key: Any) -> Dataset: ... - def __getitem__(self, key) -> DataArray: + def __getitem__(self, key) -> DataArray | Dataset: # TODO call the `_get_item` method of DataTree to allow path-like access to contents of other nodes # For now just call Dataset.__getitem__ return Dataset.__getitem__(self, key) @classmethod - def _construct_direct( + def _construct_direct( # type: ignore[override] cls, variables: dict[Any, Variable], coord_names: set[Hashable], - dims: Optional[dict[Any, int]] = None, - attrs: Optional[dict] = None, - indexes: Optional[dict[Any, Index]] = None, - encoding: Optional[dict] = None, - close: Optional[Callable[[], None]] = None, + dims: dict[Any, int] | None = None, + attrs: dict | None = None, + indexes: dict[Any, Index] | None = None, + encoding: dict | None = None, + close: Callable[[], None] | None = None, ) -> Dataset: """ Overriding this method (along with ._replace) and modifying it to return a Dataset object @@ -215,13 +212,13 @@ def _construct_direct( obj._encoding = encoding return obj - def _replace( + def _replace( # type: ignore[override] self, - variables: Optional[dict[Hashable, Variable]] = None, - coord_names: Optional[set[Hashable]] = None, - dims: Optional[dict[Any, int]] = None, + variables: dict[Hashable, Variable] | None = None, + coord_names: set[Hashable] | None = None, + dims: dict[Any, int] | None = None, attrs: dict[Hashable, Any] | None | Default = _default, - indexes: Optional[dict[Hashable, Index]] = None, + indexes: dict[Hashable, Index] | None = None, encoding: dict | None | Default = _default, inplace: bool = False, ) -> Dataset: @@ -244,7 +241,7 @@ def _replace( inplace=inplace, ) - def map( + def map( # type: ignore[override] self, func: Callable, keep_attrs: bool | None = None, @@ -259,7 +256,7 @@ def map( Function which can be called in the form `func(x, *args, **kwargs)` to transform each DataArray `x` in this dataset into another DataArray. - keep_attrs : bool or None, optional + keep_attrs : bool | None, optional If True, both the dataset's and variables' attributes (`attrs`) will be copied from the original objects to the new ones. If False, the new dataset and variables will be returned without copying the attributes. @@ -293,7 +290,7 @@ def map( bar (x) float64 16B 1.0 2.0 """ - # Copied from xarray.Dataset so as not to call type(self), which causes problems (see datatree GH188). + # Copied from xarray.Dataset so as not to call type(self), which causes problems (see https://github.com/xarray-contrib/datatree/issues/188). # TODO Refactor xarray upstream to avoid needing to overwrite this. # TODO This copied version will drop all attrs - the keep_attrs stuff should be re-instated variables = { @@ -333,21 +330,19 @@ class DataTree( # TODO a lot of properties like .variables could be defined in a DataMapping class which both Dataset and DataTree inherit from - # TODO __slots__ - # TODO all groupby classes - _name: Optional[str] - _parent: Optional[DataTree] - _children: OrderedDict[str, DataTree] - _attrs: Optional[Dict[Hashable, Any]] - _cache: Dict[str, Any] - _coord_names: Set[Hashable] - _dims: Dict[Hashable, int] - _encoding: Optional[Dict[Hashable, Any]] - _close: Optional[Callable[[], None]] - _indexes: Dict[Hashable, Index] - _variables: Dict[Hashable, Variable] + _name: str | None + _parent: DataTree | None + _children: dict[str, DataTree] + _attrs: dict[Hashable, Any] | None + _cache: dict[str, Any] + _coord_names: set[Hashable] + _dims: dict[Hashable, int] + _encoding: dict[Hashable, Any] | None + _close: Callable[[], None] | None + _indexes: dict[Hashable, Index] + _variables: dict[Hashable, Variable] __slots__ = ( "_name", @@ -365,10 +360,10 @@ class DataTree( def __init__( self, - data: Optional[Dataset | DataArray] = None, - parent: Optional[DataTree] = None, - children: Optional[Mapping[str, DataTree]] = None, - name: Optional[str] = None, + data: Dataset | DataArray | None = None, + parent: DataTree | None = None, + children: Mapping[str, DataTree] | None = None, + name: str | None = None, ): """ Create a single node of a DataTree. @@ -446,7 +441,9 @@ def ds(self) -> DatasetView: return DatasetView._from_node(self) @ds.setter - def ds(self, data: Optional[Union[Dataset, DataArray]] = None) -> None: + def ds(self, data: Dataset | DataArray | None = None) -> None: + # Known mypy issue for setters with different type to property: + # https://github.com/python/mypy/issues/3004 ds = _coerce_to_dataset(data) _check_for_name_collisions(self.children, ds.variables) @@ -515,15 +512,14 @@ def is_hollow(self) -> bool: def variables(self) -> Mapping[Hashable, Variable]: """Low level interface to node contents as dict of Variable objects. - This ordered dictionary is frozen to prevent mutation that could - violate Dataset invariants. It contains all variable objects - constituting this DataTree node, including both data variables and - coordinates. + This dictionary is frozen to prevent mutation that could violate + Dataset invariants. It contains all variable objects constituting this + DataTree node, including both data variables and coordinates. """ return Frozen(self._variables) @property - def attrs(self) -> Dict[Hashable, Any]: + def attrs(self) -> dict[Hashable, Any]: """Dictionary of global attributes on this node object.""" if self._attrs is None: self._attrs = {} @@ -534,7 +530,7 @@ def attrs(self, value: Mapping[Any, Any]) -> None: self._attrs = dict(value) @property - def encoding(self) -> Dict: + def encoding(self) -> dict: """Dictionary of global encoding attributes on this node object.""" if self._encoding is None: self._encoding = {} @@ -589,7 +585,7 @@ def _item_sources(self) -> Iterable[Mapping[Any, Any]]: # immediate child nodes yield self.children - def _ipython_key_completions_(self) -> List[str]: + def _ipython_key_completions_(self) -> list[str]: """Provide method for the key-autocompletions in IPython. See http://ipython.readthedocs.io/en/stable/config/integrating.html#tab-completion For the details. @@ -636,31 +632,31 @@ def __array__(self, dtype=None): "invoking the `to_array()` method." ) - def __repr__(self) -> str: - return formatting.datatree_repr(self) + def __repr__(self) -> str: # type: ignore[override] + return datatree_repr(self) def __str__(self) -> str: - return formatting.datatree_repr(self) + return datatree_repr(self) def _repr_html_(self): """Make html representation of datatree object""" if XR_OPTS["display_style"] == "text": return f"
{escape(repr(self))}
" - return formatting_html.datatree_repr(self) + return datatree_repr_html(self) @classmethod def _construct_direct( cls, variables: dict[Any, Variable], coord_names: set[Hashable], - dims: Optional[dict[Any, int]] = None, - attrs: Optional[dict] = None, - indexes: Optional[dict[Any, Index]] = None, - encoding: Optional[dict] = None, + dims: dict[Any, int] | None = None, + attrs: dict | None = None, + indexes: dict[Any, Index] | None = None, + encoding: dict | None = None, name: str | None = None, parent: DataTree | None = None, - children: Optional[OrderedDict[str, DataTree]] = None, - close: Optional[Callable[[], None]] = None, + children: dict[str, DataTree] | None = None, + close: Callable[[], None] | None = None, ) -> DataTree: """Shortcut around __init__ for internal use when we want to skip costly validation.""" @@ -670,7 +666,7 @@ def _construct_direct( if indexes is None: indexes = {} if children is None: - children = OrderedDict() + children = dict() obj: DataTree = object.__new__(cls) obj._variables = variables @@ -690,15 +686,15 @@ def _construct_direct( def _replace( self: DataTree, - variables: Optional[dict[Hashable, Variable]] = None, - coord_names: Optional[set[Hashable]] = None, - dims: Optional[dict[Any, int]] = None, + variables: dict[Hashable, Variable] | None = None, + coord_names: set[Hashable] | None = None, + dims: dict[Any, int] | None = None, attrs: dict[Hashable, Any] | None | Default = _default, - indexes: Optional[dict[Hashable, Index]] = None, + indexes: dict[Hashable, Index] | None = None, encoding: dict | None | Default = _default, name: str | None | Default = _default, - parent: DataTree | None = _default, - children: Optional[OrderedDict[str, DataTree]] = None, + parent: DataTree | None | Default = _default, + children: dict[str, DataTree] | None = None, inplace: bool = False, ) -> DataTree: """ @@ -817,7 +813,7 @@ def _copy_node( """Copy just one node of a tree""" new_node: DataTree = DataTree() new_node.name = self.name - new_node.ds = self.to_dataset().copy(deep=deep) + new_node.ds = self.to_dataset().copy(deep=deep) # type: ignore[assignment] return new_node def __copy__(self: DataTree) -> DataTree: @@ -826,9 +822,9 @@ def __copy__(self: DataTree) -> DataTree: def __deepcopy__(self: DataTree, memo: dict[int, Any] | None = None) -> DataTree: return self._copy_subtree(deep=True, memo=memo) - def get( - self: DataTree, key: str, default: Optional[DataTree | DataArray] = None - ) -> Optional[DataTree | DataArray]: + def get( # type: ignore[override] + self: DataTree, key: str, default: DataTree | DataArray | None = None + ) -> DataTree | DataArray | None: """ Access child nodes, variables, or coordinates stored in this node. @@ -839,7 +835,7 @@ def get( ---------- key : str Name of variable / child within this node. Must lie in this immediate node (not elsewhere in the tree). - default : DataTree | DataArray, optional + default : DataTree | DataArray | None, optional A value to return if the specified key does not exist. Default return value is None. """ if key in self.children: @@ -863,7 +859,7 @@ def __getitem__(self: DataTree, key: str) -> DataTree | DataArray: Returns ------- - Union[DataTree, DataArray] + DataTree | DataArray """ # Either: @@ -926,21 +922,38 @@ def __setitem__( else: raise ValueError("Invalid format for key") - def update(self, other: Dataset | Mapping[str, DataTree | DataArray]) -> None: + @overload + def update(self, other: Dataset) -> None: ... + + @overload + def update(self, other: Mapping[Hashable, DataArray | Variable]) -> None: ... + + @overload + def update(self, other: Mapping[str, DataTree | DataArray | Variable]) -> None: ... + + def update( + self, + other: ( + Dataset + | Mapping[Hashable, DataArray | Variable] + | Mapping[str, DataTree | DataArray | Variable] + ), + ) -> None: """ Update this node's children and / or variables. Just like `dict.update` this is an in-place operation. """ # TODO separate by type - new_children = {} + new_children: dict[str, DataTree] = {} new_variables = {} for k, v in other.items(): if isinstance(v, DataTree): # avoid named node being stored under inconsistent key - new_child = v.copy() - new_child.name = k - new_children[k] = new_child + new_child: DataTree = v.copy() + # Datatree's name is always a string until we fix that (#8836) + new_child.name = str(k) + new_children[str(k)] = new_child elif isinstance(v, (DataArray, Variable)): # TODO this should also accommodate other types that can be coerced into Variables new_variables[k] = v @@ -949,7 +962,7 @@ def update(self, other: Dataset | Mapping[str, DataTree | DataArray]) -> None: vars_merge_result = dataset_update_method(self.to_dataset(), new_variables) # TODO are there any subtleties with preserving order of children like this? - merged_children = OrderedDict({**self.children, **new_children}) + merged_children = {**self.children, **new_children} self._replace( inplace=True, children=merged_children, **vars_merge_result._asdict() ) @@ -1027,16 +1040,16 @@ def drop_nodes( if extra: raise KeyError(f"Cannot drop all nodes - nodes {extra} not present") - children_to_keep = OrderedDict( - {name: child for name, child in self.children.items() if name not in names} - ) + children_to_keep = { + name: child for name, child in self.children.items() if name not in names + } return self._replace(children=children_to_keep) @classmethod def from_dict( cls, d: MutableMapping[str, Dataset | DataArray | DataTree | None], - name: Optional[str] = None, + name: str | None = None, ) -> DataTree: """ Create a datatree from a dictionary of data objects, organised by paths into the tree. @@ -1050,7 +1063,7 @@ def from_dict( tree nodes will be constructed as necessary. To assign data to the root node of the tree use "/" as the path. - name : Hashable, optional + name : Hashable | None, optional Name for the root node of the tree. Default is None. Returns @@ -1064,14 +1077,18 @@ def from_dict( # First create the root node root_data = d.pop("/", None) - obj = cls(name=name, data=root_data, parent=None, children=None) + if isinstance(root_data, DataTree): + obj = root_data.copy() + obj.orphan() + else: + obj = cls(name=name, data=root_data, parent=None, children=None) if d: # Populate tree with children determined from data_objects mapping for path, data in d.items(): # Create and set new node node_name = NodePath(path).name - if isinstance(data, cls): + if isinstance(data, DataTree): new_node = data.copy() new_node.orphan() else: @@ -1085,13 +1102,13 @@ def from_dict( return obj - def to_dict(self) -> Dict[str, Dataset]: + def to_dict(self) -> dict[str, Dataset]: """ Create a dictionary mapping of absolute node paths to the data contained in those nodes. Returns ------- - Dict[str, Dataset] + dict[str, Dataset] """ return {node.path: node.to_dataset() for node in self.subtree} @@ -1313,7 +1330,7 @@ def map_over_subtree( func: Callable, *args: Iterable[Any], **kwargs: Any, - ) -> DataTree | Tuple[DataTree]: + ) -> DataTree | tuple[DataTree]: """ Apply a function to every dataset in this subtree, returning a new tree which stores the results. @@ -1336,13 +1353,13 @@ def map_over_subtree( Returns ------- - subtrees : DataTree, Tuple of DataTrees + subtrees : DataTree, tuple of DataTrees One or more subtrees containing results from applying ``func`` to the data at each node. """ # TODO this signature means that func has no way to know which node it is being called upon - change? # TODO fix this typing error - return map_over_subtree(func)(self, *args, **kwargs) # type: ignore[operator] + return map_over_subtree(func)(self, *args, **kwargs) def map_over_subtree_inplace( self, @@ -1449,8 +1466,8 @@ def merge_child_nodes(self, *paths, new_path: T_Path) -> DataTree: # TODO some kind of .collapse() or .flatten() method to merge a subtree - def as_array(self) -> DataArray: - return self.ds.as_dataarray() + def to_dataarray(self) -> DataArray: + return self.ds.to_dataarray() @property def groups(self): @@ -1485,7 +1502,7 @@ def to_netcdf( kwargs : Addional keyword arguments to be passed to ``xarray.Dataset.to_netcdf`` """ - from .io import _datatree_to_netcdf + from xarray.datatree_.datatree.io import _datatree_to_netcdf _datatree_to_netcdf( self, @@ -1515,7 +1532,7 @@ def to_zarr( Persistence mode: “w” means create (overwrite if exists); “w-” means create (fail if exists); “a” means override existing variables (create if does not exist); “r+” means modify existing array values only (raise an error if any metadata or shapes would change). The default mode - is “a” if append_dim is set. Otherwise, it is “r+” if region is set and w- otherwise. + is “w-”. encoding : dict, optional Nested dictionary with variable names as keys and dictionaries of variable specific encodings as values, e.g., @@ -1527,7 +1544,7 @@ def to_zarr( kwargs : Additional keyword arguments to be passed to ``xarray.Dataset.to_zarr`` """ - from .io import _datatree_to_zarr + from xarray.datatree_.datatree.io import _datatree_to_zarr _datatree_to_zarr( self, diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 260dabd9d31..3eed7d02a2e 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -289,8 +289,8 @@ def inline_sparse_repr(array): """Similar to sparse.COO.__repr__, but without the redundant shape/dtype.""" sparse_array_type = array_type("sparse") assert isinstance(array, sparse_array_type), array - return "<{}: nnz={:d}, fill_value={!s}>".format( - type(array).__name__, array.nnz, array.fill_value + return ( + f"<{type(array).__name__}: nnz={array.nnz:d}, fill_value={array.fill_value!s}>" ) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 3fbfb74d985..5966c32df92 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -28,7 +28,7 @@ filter_indexes_from_coords, safe_cast_to_index, ) -from xarray.core.options import _get_keep_attrs +from xarray.core.options import OPTIONS, _get_keep_attrs from xarray.core.types import ( Dims, QuantileMethods, @@ -38,11 +38,13 @@ ) from xarray.core.utils import ( FrozenMappingWarningOnValuesAccess, + contains_only_chunked_or_numpy, either_dict_or_kwargs, emit_user_level_warning, hashable, is_scalar, maybe_wrap_array, + module_available, peek_at, ) from xarray.core.variable import IndexVariable, Variable @@ -1075,6 +1077,9 @@ def _binary_op(self, other, f, reflexive=False): result[var] = result[var].transpose(d, ...) return result + def _restore_dim_order(self, stacked): + raise NotImplementedError + def _maybe_restore_empty_groups(self, combined): """Our index contained empty groups (e.g., from a resampling or binning). If we reduced on that dimension, we want to restore the full index. @@ -1209,13 +1214,9 @@ def _flox_reduce( (result.sizes[grouper.name],) + var.shape, ) - if isbin: - # Fix dimension order when binning a dimension coordinate - # Needed as long as we do a separate code path for pint; - # For some reason Datasets and DataArrays behave differently! - (group_dim,) = grouper.dims - if isinstance(self._obj, Dataset) and group_dim in self._obj.dims: - result = result.transpose(grouper.name, ...) + if not isinstance(result, Dataset): + # only restore dimension order for arrays + result = self._restore_dim_order(result) return result @@ -1376,16 +1377,30 @@ def quantile( (grouper,) = self.groupers dim = grouper.group1d.dims - return self.map( - self._obj.__class__.quantile, - shortcut=False, - q=q, - dim=dim, - method=method, - keep_attrs=keep_attrs, - skipna=skipna, - interpolation=interpolation, - ) + # Dataset.quantile does this, do it for flox to ensure same output. + q = np.asarray(q, dtype=np.float64) + + if ( + method == "linear" + and OPTIONS["use_flox"] + and contains_only_chunked_or_numpy(self._obj) + and module_available("flox", minversion="0.9.4") + ): + result = self._flox_reduce( + func="quantile", q=q, dim=dim, keep_attrs=keep_attrs, skipna=skipna + ) + return result + else: + return self.map( + self._obj.__class__.quantile, + shortcut=False, + q=q, + dim=dim, + method=method, + keep_attrs=keep_attrs, + skipna=skipna, + interpolation=interpolation, + ) def where(self, cond, other=dtypes.NA) -> T_Xarray: """Return elements from `self` or `other` depending on `cond`. diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index ea8ae44bb4d..e26c50c8b90 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -4,7 +4,7 @@ import functools import operator from collections import Counter, defaultdict -from collections.abc import Hashable, Mapping +from collections.abc import Hashable, Iterable, Mapping from contextlib import suppress from dataclasses import dataclass, field from datetime import timedelta @@ -35,6 +35,8 @@ from xarray.core.indexes import Index from xarray.core.variable import Variable + from xarray.namedarray._typing import _Shape, duckarray + from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint @dataclass @@ -163,7 +165,7 @@ def map_index_queries( obj: T_Xarray, indexers: Mapping[Any, Any], method=None, - tolerance=None, + tolerance: int | float | Iterable[int | float] | None = None, **indexers_kwargs: Any, ) -> IndexSelResult: """Execute index queries from a DataArray / Dataset and label-based indexers @@ -234,17 +236,17 @@ def expanded_indexer(key, ndim): return tuple(new_key) -def _expand_slice(slice_, size): +def _expand_slice(slice_, size: int) -> np.ndarray: return np.arange(*slice_.indices(size)) -def _normalize_slice(sl, size): +def _normalize_slice(sl: slice, size) -> slice: """Ensure that given slice only contains positive start and stop values (stop can be -1 for full-size slices with negative steps, e.g. [-10::-1])""" return slice(*sl.indices(size)) -def slice_slice(old_slice, applied_slice, size): +def slice_slice(old_slice: slice, applied_slice: slice, size: int) -> slice: """Given a slice and the size of the dimension to which it will be applied, index it with another slice to return a new slice equivalent to applying the slices sequentially @@ -273,7 +275,7 @@ def slice_slice(old_slice, applied_slice, size): return slice(start, stop, step) -def _index_indexer_1d(old_indexer, applied_indexer, size): +def _index_indexer_1d(old_indexer, applied_indexer, size: int): assert isinstance(applied_indexer, integer_types + (slice, np.ndarray)) if isinstance(applied_indexer, slice) and applied_indexer == slice(None): # shortcut for the usual case @@ -282,7 +284,7 @@ def _index_indexer_1d(old_indexer, applied_indexer, size): if isinstance(applied_indexer, slice): indexer = slice_slice(old_indexer, applied_indexer, size) else: - indexer = _expand_slice(old_indexer, size)[applied_indexer] + indexer = _expand_slice(old_indexer, size)[applied_indexer] # type: ignore[assignment] else: indexer = old_indexer[applied_indexer] return indexer @@ -301,16 +303,16 @@ class ExplicitIndexer: __slots__ = ("_key",) - def __init__(self, key): + def __init__(self, key: tuple[Any, ...]): if type(self) is ExplicitIndexer: raise TypeError("cannot instantiate base ExplicitIndexer objects") self._key = tuple(key) @property - def tuple(self): + def tuple(self) -> tuple[Any, ...]: return self._key - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}({self.tuple})" @@ -326,18 +328,25 @@ def as_integer_slice(value): class IndexCallable: - """Provide getitem syntax for a callable object.""" + """Provide getitem and setitem syntax for callable objects.""" - __slots__ = ("func",) + __slots__ = ("getter", "setter") - def __init__(self, func): - self.func = func + def __init__( + self, getter: Callable[..., Any], setter: Callable[..., Any] | None = None + ): + self.getter = getter + self.setter = setter - def __getitem__(self, key): - return self.func(key) + def __getitem__(self, key: Any) -> Any: + return self.getter(key) - def __setitem__(self, key, value): - raise NotImplementedError + def __setitem__(self, key: Any, value: Any) -> None: + if self.setter is None: + raise NotImplementedError( + "Setting values is not supported for this indexer." + ) + self.setter(key, value) class BasicIndexer(ExplicitIndexer): @@ -350,7 +359,7 @@ class BasicIndexer(ExplicitIndexer): __slots__ = () - def __init__(self, key): + def __init__(self, key: tuple[int | np.integer | slice, ...]): if not isinstance(key, tuple): raise TypeError(f"key must be a tuple: {key!r}") @@ -366,7 +375,7 @@ def __init__(self, key): ) new_key.append(k) - super().__init__(new_key) + super().__init__(tuple(new_key)) class OuterIndexer(ExplicitIndexer): @@ -380,7 +389,12 @@ class OuterIndexer(ExplicitIndexer): __slots__ = () - def __init__(self, key): + def __init__( + self, + key: tuple[ + int | np.integer | slice | np.ndarray[Any, np.dtype[np.generic]], ... + ], + ): if not isinstance(key, tuple): raise TypeError(f"key must be a tuple: {key!r}") @@ -395,19 +409,19 @@ def __init__(self, key): raise TypeError( f"invalid indexer array, does not have integer dtype: {k!r}" ) - if k.ndim > 1: + if k.ndim > 1: # type: ignore[union-attr] raise TypeError( f"invalid indexer array for {type(self).__name__}; must be scalar " f"or have 1 dimension: {k!r}" ) - k = k.astype(np.int64) + k = k.astype(np.int64) # type: ignore[union-attr] else: raise TypeError( f"unexpected indexer type for {type(self).__name__}: {k!r}" ) new_key.append(k) - super().__init__(new_key) + super().__init__(tuple(new_key)) class VectorizedIndexer(ExplicitIndexer): @@ -422,7 +436,7 @@ class VectorizedIndexer(ExplicitIndexer): __slots__ = () - def __init__(self, key): + def __init__(self, key: tuple[slice | np.ndarray[Any, np.dtype[np.generic]], ...]): if not isinstance(key, tuple): raise TypeError(f"key must be a tuple: {key!r}") @@ -443,21 +457,21 @@ def __init__(self, key): f"invalid indexer array, does not have integer dtype: {k!r}" ) if ndim is None: - ndim = k.ndim + ndim = k.ndim # type: ignore[union-attr] elif ndim != k.ndim: ndims = [k.ndim for k in key if isinstance(k, np.ndarray)] raise ValueError( "invalid indexer key: ndarray arguments " f"have different numbers of dimensions: {ndims}" ) - k = k.astype(np.int64) + k = k.astype(np.int64) # type: ignore[union-attr] else: raise TypeError( f"unexpected indexer type for {type(self).__name__}: {k!r}" ) new_key.append(k) - super().__init__(new_key) + super().__init__(tuple(new_key)) class ExplicitlyIndexed: @@ -485,26 +499,40 @@ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: # Note this is the base class for all lazy indexing classes return np.asarray(self.get_duck_array(), dtype=dtype) - def _oindex_get(self, key): - raise NotImplementedError("This method should be overridden") + def _oindex_get(self, indexer: OuterIndexer): + raise NotImplementedError( + f"{self.__class__.__name__}._oindex_get method should be overridden" + ) - def _vindex_get(self, key): - raise NotImplementedError("This method should be overridden") + def _vindex_get(self, indexer: VectorizedIndexer): + raise NotImplementedError( + f"{self.__class__.__name__}._vindex_get method should be overridden" + ) + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + raise NotImplementedError( + f"{self.__class__.__name__}._oindex_set method should be overridden" + ) + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + raise NotImplementedError( + f"{self.__class__.__name__}._vindex_set method should be overridden" + ) - def _check_and_raise_if_non_basic_indexer(self, key): - if isinstance(key, (VectorizedIndexer, OuterIndexer)): + def _check_and_raise_if_non_basic_indexer(self, indexer: ExplicitIndexer) -> None: + if isinstance(indexer, (VectorizedIndexer, OuterIndexer)): raise TypeError( "Vectorized indexing with vectorized or outer indexers is not supported. " "Please use .vindex and .oindex properties to index the array." ) @property - def oindex(self): - return IndexCallable(self._oindex_get) + def oindex(self) -> IndexCallable: + return IndexCallable(self._oindex_get, self._oindex_set) @property - def vindex(self): - return IndexCallable(self._vindex_get) + def vindex(self) -> IndexCallable: + return IndexCallable(self._vindex_get, self._vindex_set) class ImplicitToExplicitIndexingAdapter(NDArrayMixin): @@ -512,7 +540,7 @@ class ImplicitToExplicitIndexingAdapter(NDArrayMixin): __slots__ = ("array", "indexer_cls") - def __init__(self, array, indexer_cls=BasicIndexer): + def __init__(self, array, indexer_cls: type[ExplicitIndexer] = BasicIndexer): self.array = as_indexable(array) self.indexer_cls = indexer_cls @@ -522,7 +550,7 @@ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: def get_duck_array(self): return self.array.get_duck_array() - def __getitem__(self, key): + def __getitem__(self, key: Any): key = expanded_indexer(key, self.ndim) indexer = self.indexer_cls(key) @@ -541,7 +569,7 @@ class LazilyIndexedArray(ExplicitlyIndexedNDArrayMixin): __slots__ = ("array", "key") - def __init__(self, array, key=None): + def __init__(self, array: Any, key: ExplicitIndexer | None = None): """ Parameters ---------- @@ -553,8 +581,8 @@ def __init__(self, array, key=None): """ if isinstance(array, type(self)) and key is None: # unwrap - key = array.key - array = array.array + key = array.key # type: ignore[has-type] + array = array.array # type: ignore[has-type] if key is None: key = BasicIndexer((slice(None),) * array.ndim) @@ -562,7 +590,7 @@ def __init__(self, array, key=None): self.array = as_indexable(array) self.key = key - def _updated_key(self, new_key): + def _updated_key(self, new_key: ExplicitIndexer) -> BasicIndexer | OuterIndexer: iter_new_key = iter(expanded_indexer(new_key.tuple, self.ndim)) full_key = [] for size, k in zip(self.array.shape, self.key.tuple): @@ -570,14 +598,14 @@ def _updated_key(self, new_key): full_key.append(k) else: full_key.append(_index_indexer_1d(k, next(iter_new_key), size)) - full_key = tuple(full_key) + full_key_tuple = tuple(full_key) - if all(isinstance(k, integer_types + (slice,)) for k in full_key): - return BasicIndexer(full_key) - return OuterIndexer(full_key) + if all(isinstance(k, integer_types + (slice,)) for k in full_key_tuple): + return BasicIndexer(full_key_tuple) + return OuterIndexer(full_key_tuple) @property - def shape(self) -> tuple[int, ...]: + def shape(self) -> _Shape: shape = [] for size, k in zip(self.array.shape, self.key.tuple): if isinstance(k, slice): @@ -605,27 +633,33 @@ def get_duck_array(self): def transpose(self, order): return LazilyVectorizedIndexedArray(self.array, self.key).transpose(order) - def _oindex_get(self, indexer): + def _oindex_get(self, indexer: OuterIndexer): return type(self)(self.array, self._updated_key(indexer)) - def _vindex_get(self, indexer): + def _vindex_get(self, indexer: VectorizedIndexer): array = LazilyVectorizedIndexedArray(self.array, self.key) return array.vindex[indexer] - def __getitem__(self, indexer): + def __getitem__(self, indexer: ExplicitIndexer): self._check_and_raise_if_non_basic_indexer(indexer) return type(self)(self.array, self._updated_key(indexer)) - def __setitem__(self, key, value): - if isinstance(key, VectorizedIndexer): - raise NotImplementedError( - "Lazy item assignment with the vectorized indexer is not yet " - "implemented. Load your data first by .load() or compute()." - ) + def _vindex_set(self, key: VectorizedIndexer, value: Any) -> None: + raise NotImplementedError( + "Lazy item assignment with the vectorized indexer is not yet " + "implemented. Load your data first by .load() or compute()." + ) + + def _oindex_set(self, key: OuterIndexer, value: Any) -> None: + full_key = self._updated_key(key) + self.array.oindex[full_key] = value + + def __setitem__(self, key: BasicIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(key) full_key = self._updated_key(key) self.array[full_key] = value - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}(array={self.array!r}, key={self.key!r})" @@ -638,7 +672,7 @@ class LazilyVectorizedIndexedArray(ExplicitlyIndexedNDArrayMixin): __slots__ = ("array", "key") - def __init__(self, array, key): + def __init__(self, array: duckarray[Any, Any], key: ExplicitIndexer): """ Parameters ---------- @@ -648,16 +682,15 @@ def __init__(self, array, key): """ if isinstance(key, (BasicIndexer, OuterIndexer)): self.key = _outer_to_vectorized_indexer(key, array.shape) - else: + elif isinstance(key, VectorizedIndexer): self.key = _arrayize_vectorized_indexer(key, array.shape) self.array = as_indexable(array) @property - def shape(self) -> tuple[int, ...]: + def shape(self) -> _Shape: return np.broadcast(*self.key.tuple).shape def get_duck_array(self): - if isinstance(self.array, ExplicitlyIndexedNDArrayMixin): array = apply_indexer(self.array, self.key) else: @@ -672,16 +705,16 @@ def get_duck_array(self): array = array.get_duck_array() return _wrap_numpy_scalars(array) - def _updated_key(self, new_key): + def _updated_key(self, new_key: ExplicitIndexer): return _combine_indexers(self.key, self.shape, new_key) - def _oindex_get(self, indexer): + def _oindex_get(self, indexer: OuterIndexer): return type(self)(self.array, self._updated_key(indexer)) - def _vindex_get(self, indexer): + def _vindex_get(self, indexer: VectorizedIndexer): return type(self)(self.array, self._updated_key(indexer)) - def __getitem__(self, indexer): + def __getitem__(self, indexer: ExplicitIndexer): self._check_and_raise_if_non_basic_indexer(indexer) # If the indexed array becomes a scalar, return LazilyIndexedArray if all(isinstance(ind, integer_types) for ind in indexer.tuple): @@ -693,13 +726,13 @@ def transpose(self, order): key = VectorizedIndexer(tuple(k.transpose(order) for k in self.key.tuple)) return type(self)(self.array, key) - def __setitem__(self, key, value): + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: raise NotImplementedError( "Lazy item assignment with the vectorized indexer is not yet " "implemented. Load your data first by .load() or compute()." ) - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}(array={self.array!r}, key={self.key!r})" @@ -714,7 +747,7 @@ def _wrap_numpy_scalars(array): class CopyOnWriteArray(ExplicitlyIndexedNDArrayMixin): __slots__ = ("array", "_copied") - def __init__(self, array): + def __init__(self, array: duckarray[Any, Any]): self.array = as_indexable(array) self._copied = False @@ -726,22 +759,32 @@ def _ensure_copied(self): def get_duck_array(self): return self.array.get_duck_array() - def _oindex_get(self, key): - return type(self)(_wrap_numpy_scalars(self.array.oindex[key])) + def _oindex_get(self, indexer: OuterIndexer): + return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer])) - def _vindex_get(self, key): - return type(self)(_wrap_numpy_scalars(self.array.vindex[key])) + def _vindex_get(self, indexer: VectorizedIndexer): + return type(self)(_wrap_numpy_scalars(self.array.vindex[indexer])) - def __getitem__(self, key): - self._check_and_raise_if_non_basic_indexer(key) - return type(self)(_wrap_numpy_scalars(self.array[key])) + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + return type(self)(_wrap_numpy_scalars(self.array[indexer])) def transpose(self, order): return self.array.transpose(order) - def __setitem__(self, key, value): + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + self._ensure_copied() + self.array.vindex[indexer] = value + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + self._ensure_copied() + self.array.oindex[indexer] = value + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(indexer) self._ensure_copied() - self.array[key] = value + + self.array[indexer] = value def __deepcopy__(self, memo): # CopyOnWriteArray is used to wrap backend array objects, which might @@ -766,21 +809,28 @@ def get_duck_array(self): self._ensure_cached() return self.array.get_duck_array() - def _oindex_get(self, key): - return type(self)(_wrap_numpy_scalars(self.array.oindex[key])) + def _oindex_get(self, indexer: OuterIndexer): + return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer])) - def _vindex_get(self, key): - return type(self)(_wrap_numpy_scalars(self.array.vindex[key])) + def _vindex_get(self, indexer: VectorizedIndexer): + return type(self)(_wrap_numpy_scalars(self.array.vindex[indexer])) - def __getitem__(self, key): - self._check_and_raise_if_non_basic_indexer(key) - return type(self)(_wrap_numpy_scalars(self.array[key])) + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + return type(self)(_wrap_numpy_scalars(self.array[indexer])) def transpose(self, order): return self.array.transpose(order) - def __setitem__(self, key, value): - self.array[key] = value + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + self.array.vindex[indexer] = value + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + self.array.oindex[indexer] = value + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(indexer) + self.array[indexer] = value def as_indexable(array): @@ -805,12 +855,14 @@ def as_indexable(array): raise TypeError(f"Invalid array type: {type(array)}") -def _outer_to_vectorized_indexer(key, shape): +def _outer_to_vectorized_indexer( + indexer: BasicIndexer | OuterIndexer, shape: _Shape +) -> VectorizedIndexer: """Convert an OuterIndexer into an vectorized indexer. Parameters ---------- - key : Outer/Basic Indexer + indexer : Outer/Basic Indexer An indexer to convert. shape : tuple Shape of the array subject to the indexing. @@ -822,7 +874,7 @@ def _outer_to_vectorized_indexer(key, shape): Each element is an array: broadcasting them together gives the shape of the result. """ - key = key.tuple + key = indexer.tuple n_dim = len([k for k in key if not isinstance(k, integer_types)]) i_dim = 0 @@ -834,18 +886,18 @@ def _outer_to_vectorized_indexer(key, shape): if isinstance(k, slice): k = np.arange(*k.indices(size)) assert k.dtype.kind in {"i", "u"} - shape = [(1,) * i_dim + (k.size,) + (1,) * (n_dim - i_dim - 1)] - new_key.append(k.reshape(*shape)) + new_shape = [(1,) * i_dim + (k.size,) + (1,) * (n_dim - i_dim - 1)] + new_key.append(k.reshape(*new_shape)) i_dim += 1 return VectorizedIndexer(tuple(new_key)) -def _outer_to_numpy_indexer(key, shape): +def _outer_to_numpy_indexer(indexer: BasicIndexer | OuterIndexer, shape: _Shape): """Convert an OuterIndexer into an indexer for NumPy. Parameters ---------- - key : Basic/OuterIndexer + indexer : Basic/OuterIndexer An indexer to convert. shape : tuple Shape of the array subject to the indexing. @@ -855,16 +907,16 @@ def _outer_to_numpy_indexer(key, shape): tuple Tuple suitable for use to index a NumPy array. """ - if len([k for k in key.tuple if not isinstance(k, slice)]) <= 1: + if len([k for k in indexer.tuple if not isinstance(k, slice)]) <= 1: # If there is only one vector and all others are slice, # it can be safely used in mixed basic/advanced indexing. # Boolean index should already be converted to integer array. - return key.tuple + return indexer.tuple else: - return _outer_to_vectorized_indexer(key, shape).tuple + return _outer_to_vectorized_indexer(indexer, shape).tuple -def _combine_indexers(old_key, shape, new_key): +def _combine_indexers(old_key, shape: _Shape, new_key) -> VectorizedIndexer: """Combine two indexers. Parameters @@ -906,9 +958,9 @@ class IndexingSupport(enum.Enum): def explicit_indexing_adapter( key: ExplicitIndexer, - shape: tuple[int, ...], + shape: _Shape, indexing_support: IndexingSupport, - raw_indexing_method: Callable, + raw_indexing_method: Callable[..., Any], ) -> Any: """Support explicit indexing by delegating to a raw indexing method. @@ -940,7 +992,7 @@ def explicit_indexing_adapter( return result -def apply_indexer(indexable, indexer): +def apply_indexer(indexable, indexer: ExplicitIndexer): """Apply an indexer to an indexable object.""" if isinstance(indexer, VectorizedIndexer): return indexable.vindex[indexer] @@ -950,8 +1002,18 @@ def apply_indexer(indexable, indexer): return indexable[indexer] +def set_with_indexer(indexable, indexer: ExplicitIndexer, value: Any) -> None: + """Set values in an indexable object using an indexer.""" + if isinstance(indexer, VectorizedIndexer): + indexable.vindex[indexer] = value + elif isinstance(indexer, OuterIndexer): + indexable.oindex[indexer] = value + else: + indexable[indexer] = value + + def decompose_indexer( - indexer: ExplicitIndexer, shape: tuple[int, ...], indexing_support: IndexingSupport + indexer: ExplicitIndexer, shape: _Shape, indexing_support: IndexingSupport ) -> tuple[ExplicitIndexer, ExplicitIndexer]: if isinstance(indexer, VectorizedIndexer): return _decompose_vectorized_indexer(indexer, shape, indexing_support) @@ -990,7 +1052,7 @@ def _decompose_slice(key: slice, size: int) -> tuple[slice, slice]: def _decompose_vectorized_indexer( indexer: VectorizedIndexer, - shape: tuple[int, ...], + shape: _Shape, indexing_support: IndexingSupport, ) -> tuple[ExplicitIndexer, ExplicitIndexer]: """ @@ -1072,7 +1134,7 @@ def _decompose_vectorized_indexer( def _decompose_outer_indexer( indexer: BasicIndexer | OuterIndexer, - shape: tuple[int, ...], + shape: _Shape, indexing_support: IndexingSupport, ) -> tuple[ExplicitIndexer, ExplicitIndexer]: """ @@ -1213,7 +1275,9 @@ def _decompose_outer_indexer( return (BasicIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer))) -def _arrayize_vectorized_indexer(indexer, shape): +def _arrayize_vectorized_indexer( + indexer: VectorizedIndexer, shape: _Shape +) -> VectorizedIndexer: """Return an identical vindex but slices are replaced by arrays""" slices = [v for v in indexer.tuple if isinstance(v, slice)] if len(slices) == 0: @@ -1233,7 +1297,9 @@ def _arrayize_vectorized_indexer(indexer, shape): return VectorizedIndexer(tuple(new_key)) -def _chunked_array_with_chunks_hint(array, chunks, chunkmanager): +def _chunked_array_with_chunks_hint( + array, chunks, chunkmanager: ChunkManagerEntrypoint[Any] +): """Create a chunked array using the chunks hint for dimensions of size > 1.""" if len(chunks) < array.ndim: @@ -1241,21 +1307,21 @@ def _chunked_array_with_chunks_hint(array, chunks, chunkmanager): new_chunks = [] for chunk, size in zip(chunks, array.shape): new_chunks.append(chunk if size > 1 else (1,)) - return chunkmanager.from_array(array, new_chunks) + return chunkmanager.from_array(array, new_chunks) # type: ignore[arg-type] def _logical_any(args): return functools.reduce(operator.or_, args) -def _masked_result_drop_slice(key, data=None): +def _masked_result_drop_slice(key, data: duckarray[Any, Any] | None = None): key = (k for k in key if not isinstance(k, slice)) chunks_hint = getattr(data, "chunks", None) new_keys = [] for k in key: if isinstance(k, np.ndarray): - if is_chunked_array(data): + if is_chunked_array(data): # type: ignore[arg-type] chunkmanager = get_chunked_array_type(data) new_keys.append( _chunked_array_with_chunks_hint(k, chunks_hint, chunkmanager) @@ -1273,7 +1339,9 @@ def _masked_result_drop_slice(key, data=None): return mask -def create_mask(indexer, shape, data=None): +def create_mask( + indexer: ExplicitIndexer, shape: _Shape, data: duckarray[Any, Any] | None = None +): """Create a mask for indexing with a fill-value. Parameters @@ -1318,7 +1386,9 @@ def create_mask(indexer, shape, data=None): return mask -def _posify_mask_subindexer(index): +def _posify_mask_subindexer( + index: np.ndarray[Any, np.dtype[np.generic]], +) -> np.ndarray[Any, np.dtype[np.generic]]: """Convert masked indices in a flat array to the nearest unmasked index. Parameters @@ -1344,7 +1414,7 @@ def _posify_mask_subindexer(index): return new_index -def posify_mask_indexer(indexer): +def posify_mask_indexer(indexer: ExplicitIndexer) -> ExplicitIndexer: """Convert masked values (-1) in an indexer to nearest unmasked values. This routine is useful for dask, where it can be much faster to index @@ -1399,45 +1469,31 @@ def __init__(self, array): ) self.array = array - def _indexing_array_and_key(self, key): - if isinstance(key, OuterIndexer): - array = self.array - key = _outer_to_numpy_indexer(key, self.array.shape) - elif isinstance(key, VectorizedIndexer): - array = NumpyVIndexAdapter(self.array) - key = key.tuple - elif isinstance(key, BasicIndexer): - array = self.array - # We want 0d slices rather than scalars. This is achieved by - # appending an ellipsis (see - # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes). - key = key.tuple + (Ellipsis,) - else: - raise TypeError(f"unexpected key type: {type(key)}") - - return array, key - def transpose(self, order): return self.array.transpose(order) - def _oindex_get(self, key): - key = _outer_to_numpy_indexer(key, self.array.shape) + def _oindex_get(self, indexer: OuterIndexer): + key = _outer_to_numpy_indexer(indexer, self.array.shape) return self.array[key] - def _vindex_get(self, key): + def _vindex_get(self, indexer: VectorizedIndexer): array = NumpyVIndexAdapter(self.array) - return array[key.tuple] + return array[indexer.tuple] - def __getitem__(self, key): - self._check_and_raise_if_non_basic_indexer(key) - array, key = self._indexing_array_and_key(key) + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + + array = self.array + # We want 0d slices rather than scalars. This is achieved by + # appending an ellipsis (see + # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes). + key = indexer.tuple + (Ellipsis,) return array[key] - def __setitem__(self, key, value): - array, key = self._indexing_array_and_key(key) + def _safe_setitem(self, array, key: tuple[Any, ...], value: Any) -> None: try: array[key] = value - except ValueError: + except ValueError as exc: # More informative exception if read-only view if not array.flags.writeable and not array.flags.owndata: raise ValueError( @@ -1445,7 +1501,24 @@ def __setitem__(self, key, value): "Do you want to .copy() array first?" ) else: - raise + raise exc + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + key = _outer_to_numpy_indexer(indexer, self.array.shape) + self._safe_setitem(self.array, key, value) + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + array = NumpyVIndexAdapter(self.array) + self._safe_setitem(array, indexer.tuple, value) + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(indexer) + array = self.array + # We want 0d slices rather than scalars. This is achieved by + # appending an ellipsis (see + # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes). + key = indexer.tuple + (Ellipsis,) + self._safe_setitem(array, key, value) class NdArrayLikeIndexingAdapter(NumpyIndexingAdapter): @@ -1473,28 +1546,30 @@ def __init__(self, array): ) self.array = array - def _oindex_get(self, key): + def _oindex_get(self, indexer: OuterIndexer): # manual orthogonal indexing (implemented like DaskIndexingAdapter) - key = key.tuple + key = indexer.tuple value = self.array for axis, subkey in reversed(list(enumerate(key))): value = value[(slice(None),) * axis + (subkey, Ellipsis)] return value - def _vindex_get(self, key): + def _vindex_get(self, indexer: VectorizedIndexer): raise TypeError("Vectorized indexing is not supported") - def __getitem__(self, key): - self._check_and_raise_if_non_basic_indexer(key) - return self.array[key.tuple] + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + return self.array[indexer.tuple] - def __setitem__(self, key, value): - if isinstance(key, (BasicIndexer, OuterIndexer)): - self.array[key.tuple] = value - elif isinstance(key, VectorizedIndexer): - raise TypeError("Vectorized indexing is not supported") - else: - raise TypeError(f"Unrecognized indexer: {key}") + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + self.array[indexer.tuple] = value + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + raise TypeError("Vectorized indexing is not supported") + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(indexer) + self.array[indexer.tuple] = value def transpose(self, order): xp = self.array.__array_namespace__() @@ -1512,8 +1587,8 @@ def __init__(self, array): """ self.array = array - def _oindex_get(self, key): - key = key.tuple + def _oindex_get(self, indexer: OuterIndexer): + key = indexer.tuple try: return self.array[key] except NotImplementedError: @@ -1523,26 +1598,27 @@ def _oindex_get(self, key): value = value[(slice(None),) * axis + (subkey,)] return value - def _vindex_get(self, key): - return self.array.vindex[key.tuple] + def _vindex_get(self, indexer: VectorizedIndexer): + return self.array.vindex[indexer.tuple] - def __getitem__(self, key): - self._check_and_raise_if_non_basic_indexer(key) - return self.array[key.tuple] + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + return self.array[indexer.tuple] - def __setitem__(self, key, value): - if isinstance(key, BasicIndexer): - self.array[key.tuple] = value - elif isinstance(key, VectorizedIndexer): - self.array.vindex[key.tuple] = value - elif isinstance(key, OuterIndexer): - num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in key.tuple) - if num_non_slices > 1: - raise NotImplementedError( - "xarray can't set arrays with multiple " - "array indices to dask yet." - ) - self.array[key.tuple] = value + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in indexer.tuple) + if num_non_slices > 1: + raise NotImplementedError( + "xarray can't set arrays with multiple " "array indices to dask yet." + ) + self.array[indexer.tuple] = value + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + self.array.vindex[indexer.tuple] = value + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(indexer) + self.array[indexer.tuple] = value def transpose(self, order): return self.array.transpose(order) @@ -1581,7 +1657,7 @@ def get_duck_array(self) -> np.ndarray: return np.asarray(self) @property - def shape(self) -> tuple[int, ...]: + def shape(self) -> _Shape: return (len(self.array),) def _convert_scalar(self, item): @@ -1604,14 +1680,68 @@ def _convert_scalar(self, item): # a NumPy array. return to_0d_array(item) - def _oindex_get(self, key): - return self.__getitem__(key) + def _prepare_key(self, key: tuple[Any, ...]) -> tuple[Any, ...]: + if isinstance(key, tuple) and len(key) == 1: + # unpack key so it can index a pandas.Index object (pandas.Index + # objects don't like tuples) + (key,) = key + + return key + + def _handle_result( + self, result: Any + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + if isinstance(result, pd.Index): + return type(self)(result, dtype=self.dtype) + else: + return self._convert_scalar(result) + + def _oindex_get( + self, indexer: OuterIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + key = self._prepare_key(indexer.tuple) + + if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional + indexable = NumpyIndexingAdapter(np.asarray(self)) + return indexable.oindex[indexer] + + result = self.array[key] + + return self._handle_result(result) + + def _vindex_get( + self, indexer: VectorizedIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + key = self._prepare_key(indexer.tuple) + + if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional + indexable = NumpyIndexingAdapter(np.asarray(self)) + return indexable.vindex[indexer] - def _vindex_get(self, key): - return self.__getitem__(key) + result = self.array[key] + + return self._handle_result(result) def __getitem__( - self, indexer + self, indexer: ExplicitIndexer ) -> ( PandasIndexingAdapter | NumpyIndexingAdapter @@ -1619,22 +1749,15 @@ def __getitem__( | np.datetime64 | np.timedelta64 ): - key = indexer.tuple - if isinstance(key, tuple) and len(key) == 1: - # unpack key so it can index a pandas.Index object (pandas.Index - # objects don't like tuples) - (key,) = key + key = self._prepare_key(indexer.tuple) if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional indexable = NumpyIndexingAdapter(np.asarray(self)) - return apply_indexer(indexable, indexer) + return indexable[indexer] result = self.array[key] - if isinstance(result, pd.Index): - return type(self)(result, dtype=self.dtype) - else: - return self._convert_scalar(result) + return self._handle_result(result) def transpose(self, order) -> pd.Index: return self.array # self.array should be always one-dimensional @@ -1690,7 +1813,35 @@ def _convert_scalar(self, item): item = item[idx] return super()._convert_scalar(item) - def __getitem__(self, indexer): + def _oindex_get( + self, indexer: OuterIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + result = super()._oindex_get(indexer) + if isinstance(result, type(self)): + result.level = self.level + return result + + def _vindex_get( + self, indexer: VectorizedIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + result = super()._vindex_get(indexer) + if isinstance(result, type(self)): + result.level = self.level + return result + + def __getitem__(self, indexer: ExplicitIndexer): result = super().__getitem__(indexer) if isinstance(result, type(self)): result.level = self.level diff --git a/xarray/core/merge.py b/xarray/core/merge.py index a689620e524..a90e59e7c0b 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -355,7 +355,7 @@ def append_all(variables, indexes): indexes_.pop(name, None) append_all(coords_, indexes_) - variable = as_variable(variable, name=name) + variable = as_variable(variable, name=name, auto_convert=False) if name in indexes: append(name, variable, indexes[name]) elif variable.dims == (name,): @@ -562,25 +562,6 @@ def merge_coords( return variables, out_indexes -def assert_valid_explicit_coords( - variables: Mapping[Any, Any], - dims: Mapping[Any, int], - explicit_coords: Iterable[Hashable], -) -> None: - """Validate explicit coordinate names/dims. - - Raise a MergeError if an explicit coord shares a name with a dimension - but is comprised of arbitrary dimensions. - """ - for coord_name in explicit_coords: - if coord_name in dims and variables[coord_name].dims != (coord_name,): - raise MergeError( - f"coordinate {coord_name} shares a name with a dataset dimension, but is " - "not a 1D variable along that dimension. This is disallowed " - "by the xarray data model." - ) - - def merge_attrs(variable_attrs, combine_attrs, context=None): """Combine attributes from different variables according to combine_attrs""" if not variable_attrs: @@ -728,7 +709,6 @@ def merge_core( # coordinates may be dropped in merged results coord_names.intersection_update(variables) if explicit_coords is not None: - assert_valid_explicit_coords(variables, dims, explicit_coords) coord_names.update(explicit_coords) for dim, size in dims.items(): if dim in variables: diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index b3e6e43f306..8cee3f69d70 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -230,7 +230,7 @@ def _post_attach_children(self: Tree, children: Mapping[str, Tree]) -> None: pass def _iter_parents(self: Tree) -> Iterator[Tree]: - """Iterate up the tree, starting from the current node.""" + """Iterate up the tree, starting from the current node's parent.""" node: Tree | None = self.parent while node is not None: yield node diff --git a/xarray/core/variable.py b/xarray/core/variable.py index a03e93ac699..ec284e411fc 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -33,6 +33,7 @@ decode_numpy_dict_values, drop_dims_from_indexers, either_dict_or_kwargs, + emit_user_level_warning, ensure_us_time_resolution, infix_dims, is_dict_like, @@ -80,7 +81,9 @@ class MissingDimensionsError(ValueError): # TODO: move this to an xarray.exceptions module? -def as_variable(obj: T_DuckArray | Any, name=None) -> Variable | IndexVariable: +def as_variable( + obj: T_DuckArray | Any, name=None, auto_convert: bool = True +) -> Variable | IndexVariable: """Convert an object into a Variable. Parameters @@ -100,6 +103,9 @@ def as_variable(obj: T_DuckArray | Any, name=None) -> Variable | IndexVariable: along a dimension of this given name. - Variables with name matching one of their dimensions are converted into `IndexVariable` objects. + auto_convert : bool, optional + For internal use only! If True, convert a "dimension" variable into + an IndexVariable object (deprecated). Returns ------- @@ -150,9 +156,15 @@ def as_variable(obj: T_DuckArray | Any, name=None) -> Variable | IndexVariable: f"explicit list of dimensions: {obj!r}" ) - if name is not None and name in obj.dims and obj.ndim == 1: - # automatically convert the Variable into an Index - obj = obj.to_index_variable() + if auto_convert: + if name is not None and name in obj.dims and obj.ndim == 1: + # automatically convert the Variable into an Index + emit_user_level_warning( + f"variable {name!r} with name matching its dimension will not be " + "automatically converted into an `IndexVariable` object in the future.", + FutureWarning, + ) + obj = obj.to_index_variable() return obj @@ -209,7 +221,14 @@ def _possibly_convert_objects(values): as_series = pd.Series(values.ravel(), copy=False) if as_series.dtype.kind in "mM": as_series = _as_nanosecond_precision(as_series) - return np.asarray(as_series).reshape(values.shape) + result = np.asarray(as_series).reshape(values.shape) + if not result.flags.writeable: + # GH8843, pandas copy-on-write mode creates read-only arrays by default + try: + result.flags.writeable = True + except ValueError: + result = result.copy() + return result def _possibly_convert_datetime_or_timedelta_index(data): @@ -699,8 +718,10 @@ def _broadcast_indexes_vectorized(self, key): variable = ( value if isinstance(value, Variable) - else as_variable(value, name=dim) + else as_variable(value, name=dim, auto_convert=False) ) + if variable.dims == (dim,): + variable = variable.to_index_variable() if variable.dtype.kind == "b": # boolean indexing case (variable,) = variable._nonzero() @@ -842,7 +863,7 @@ def __setitem__(self, key, value): value = np.moveaxis(value, new_order, range(len(new_order))) indexable = as_indexable(self._data) - indexable[index_tuple] = value + indexing.set_with_indexer(indexable, index_tuple, value) @property def encoding(self) -> dict[Any, Any]: diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index ae9521309e0..8cb90ac1b2b 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -84,7 +84,7 @@ method supported by this weighted version corresponds to the default "linear" option of ``numpy.quantile``. This is "Type 7" option, described in Hyndman and Fan (1996) [2]_. The implementation is largely inspired by a blog post - from A. Akinshin's [3]_. + from A. Akinshin's (2023) [3]_. Parameters ---------- @@ -122,7 +122,8 @@ .. [1] https://notstatschat.rbind.io/2020/08/04/weights-in-statistics/ .. [2] Hyndman, R. J. & Fan, Y. (1996). Sample Quantiles in Statistical Packages. The American Statistician, 50(4), 361–365. https://doi.org/10.2307/2684934 - .. [3] https://aakinshin.net/posts/weighted-quantiles + .. [3] Akinshin, A. (2023) "Weighted quantile estimators" arXiv:2304.07265 [stat.ME] + https://arxiv.org/abs/2304.07265 """ diff --git a/xarray/datatree_/datatree/__init__.py b/xarray/datatree_/datatree/__init__.py index 071dcbecf8c..f2603b64641 100644 --- a/xarray/datatree_/datatree/__init__.py +++ b/xarray/datatree_/datatree/__init__.py @@ -1,15 +1,11 @@ # import public API -from .datatree import DataTree -from .extensions import register_datatree_accessor from .mapping import TreeIsomorphismError, map_over_subtree from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError __all__ = ( - "DataTree", "TreeIsomorphismError", "InvalidTreeError", "NotFoundInTreeError", "map_over_subtree", - "register_datatree_accessor", ) diff --git a/xarray/datatree_/datatree/extensions.py b/xarray/datatree_/datatree/extensions.py index f6f4e985a79..bf888fc4484 100644 --- a/xarray/datatree_/datatree/extensions.py +++ b/xarray/datatree_/datatree/extensions.py @@ -1,6 +1,6 @@ from xarray.core.extensions import _register_accessor -from .datatree import DataTree +from xarray.core.datatree import DataTree def register_datatree_accessor(name): diff --git a/xarray/datatree_/datatree/formatting.py b/xarray/datatree_/datatree/formatting.py index deba57eb09d..9ebee72d4ef 100644 --- a/xarray/datatree_/datatree/formatting.py +++ b/xarray/datatree_/datatree/formatting.py @@ -2,11 +2,11 @@ from xarray.core.formatting import _compat_to_str, diff_dataset_repr -from .mapping import diff_treestructure -from .render import RenderTree +from xarray.datatree_.datatree.mapping import diff_treestructure +from xarray.datatree_.datatree.render import RenderTree if TYPE_CHECKING: - from .datatree import DataTree + from xarray.core.datatree import DataTree def diff_nodewise_summary(a, b, compat): diff --git a/xarray/datatree_/datatree/io.py b/xarray/datatree_/datatree/io.py index d3d533ee71e..48335ddca70 100644 --- a/xarray/datatree_/datatree/io.py +++ b/xarray/datatree_/datatree/io.py @@ -1,4 +1,4 @@ -from xarray.datatree_.datatree import DataTree +from xarray.core.datatree import DataTree def _get_nc_dataset_class(engine): diff --git a/xarray/datatree_/datatree/mapping.py b/xarray/datatree_/datatree/mapping.py index 355149060a9..7742ece9738 100644 --- a/xarray/datatree_/datatree/mapping.py +++ b/xarray/datatree_/datatree/mapping.py @@ -156,7 +156,7 @@ def map_over_subtree(func: Callable) -> Callable: @functools.wraps(func) def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]: """Internal function which maps func over every node in tree, returning a tree of the results.""" - from .datatree import DataTree + from xarray.core.datatree import DataTree all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [ a for a in kwargs.values() if isinstance(a, DataTree) diff --git a/xarray/datatree_/datatree/render.py b/xarray/datatree_/datatree/render.py index aef327c5c47..e6af9c85ee8 100644 --- a/xarray/datatree_/datatree/render.py +++ b/xarray/datatree_/datatree/render.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from .datatree import DataTree + from xarray.core.datatree import DataTree Row = collections.namedtuple("Row", ("pre", "fill", "node")) diff --git a/xarray/datatree_/datatree/testing.py b/xarray/datatree_/datatree/testing.py index 1cbcdf2d4e3..bf54116725a 100644 --- a/xarray/datatree_/datatree/testing.py +++ b/xarray/datatree_/datatree/testing.py @@ -1,6 +1,6 @@ from xarray.testing.assertions import ensure_warnings -from .datatree import DataTree +from xarray.core.datatree import DataTree from .formatting import diff_tree_repr diff --git a/xarray/datatree_/datatree/tests/conftest.py b/xarray/datatree_/datatree/tests/conftest.py index bd2e7ba3247..53a9a72239d 100644 --- a/xarray/datatree_/datatree/tests/conftest.py +++ b/xarray/datatree_/datatree/tests/conftest.py @@ -1,7 +1,7 @@ import pytest import xarray as xr -from xarray.datatree_.datatree import DataTree +from xarray.core.datatree import DataTree @pytest.fixture(scope="module") diff --git a/xarray/datatree_/datatree/tests/test_dataset_api.py b/xarray/datatree_/datatree/tests/test_dataset_api.py index c3eb74451a6..4ca532ebba4 100644 --- a/xarray/datatree_/datatree/tests/test_dataset_api.py +++ b/xarray/datatree_/datatree/tests/test_dataset_api.py @@ -1,7 +1,7 @@ import numpy as np import xarray as xr -from xarray.datatree_.datatree import DataTree +from xarray.core.datatree import DataTree from xarray.datatree_.datatree.testing import assert_equal diff --git a/xarray/datatree_/datatree/tests/test_extensions.py b/xarray/datatree_/datatree/tests/test_extensions.py index 0241e496abf..fb2e82453ec 100644 --- a/xarray/datatree_/datatree/tests/test_extensions.py +++ b/xarray/datatree_/datatree/tests/test_extensions.py @@ -1,6 +1,7 @@ import pytest -from xarray.datatree_.datatree import DataTree, register_datatree_accessor +from xarray.core.datatree import DataTree +from xarray.datatree_.datatree.extensions import register_datatree_accessor class TestAccessor: diff --git a/xarray/datatree_/datatree/tests/test_formatting.py b/xarray/datatree_/datatree/tests/test_formatting.py index b58c02282e7..77f8346ae72 100644 --- a/xarray/datatree_/datatree/tests/test_formatting.py +++ b/xarray/datatree_/datatree/tests/test_formatting.py @@ -2,7 +2,7 @@ from xarray import Dataset -from xarray.datatree_.datatree import DataTree +from xarray.core.datatree import DataTree from xarray.datatree_.datatree.formatting import diff_tree_repr diff --git a/xarray/datatree_/datatree/tests/test_formatting_html.py b/xarray/datatree_/datatree/tests/test_formatting_html.py index 943bbab4154..98cdf02bff4 100644 --- a/xarray/datatree_/datatree/tests/test_formatting_html.py +++ b/xarray/datatree_/datatree/tests/test_formatting_html.py @@ -1,7 +1,8 @@ import pytest import xarray as xr -from xarray.datatree_.datatree import DataTree, formatting_html +from xarray.core.datatree import DataTree +from xarray.datatree_.datatree import formatting_html @pytest.fixture(scope="module", params=["some html", "some other html"]) diff --git a/xarray/datatree_/datatree/tests/test_mapping.py b/xarray/datatree_/datatree/tests/test_mapping.py index 53d6e085440..c6cd04887c0 100644 --- a/xarray/datatree_/datatree/tests/test_mapping.py +++ b/xarray/datatree_/datatree/tests/test_mapping.py @@ -2,7 +2,7 @@ import pytest import xarray as xr -from xarray.datatree_.datatree.datatree import DataTree +from xarray.core.datatree import DataTree from xarray.datatree_.datatree.mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree from xarray.datatree_.datatree.testing import assert_equal diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index fd209bc273f..135dabc0656 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -813,7 +813,7 @@ def chunk( # Using OuterIndexer is a pragmatic choice: dask does not yet handle # different indexing types in an explicit way: # https://github.com/dask/dask/issues/2883 - ndata = ImplicitToExplicitIndexingAdapter(data_old, OuterIndexer) # type: ignore[no-untyped-call, assignment] + ndata = ImplicitToExplicitIndexingAdapter(data_old, OuterIndexer) # type: ignore[assignment] if is_dict_like(chunks): chunks = tuple(chunks.get(n, s) for n, s in enumerate(ndata.shape)) # type: ignore[assignment] diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index 8386161bf29..ed752d3461f 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -1848,9 +1848,10 @@ def _center_pixels(x): # missing data transparent. We therefore add an alpha channel if # there isn't one, and set it to transparent where data is masked. if z.shape[-1] == 3: - alpha = np.ma.ones(z.shape[:2] + (1,), dtype=z.dtype) + safe_dtype = np.promote_types(z.dtype, np.uint8) + alpha = np.ma.ones(z.shape[:2] + (1,), dtype=safe_dtype) if np.issubdtype(z.dtype, np.integer): - alpha *= 255 + alpha[:] = 255 z = np.ma.concatenate((z, alpha), axis=2) else: z = z.copy() diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 804e1cfd795..8789bc2f9c2 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -4,7 +4,7 @@ import textwrap import warnings from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence -from datetime import datetime +from datetime import date, datetime from inspect import getfullargspec from typing import TYPE_CHECKING, Any, Callable, Literal, overload @@ -672,7 +672,7 @@ def _ensure_plottable(*args) -> None: np.bool_, np.str_, ) - other_types: tuple[type[object], ...] = (datetime,) + other_types: tuple[type[object], ...] = (datetime, date) cftime_datetime_types: tuple[type[object], ...] = ( () if cftime is None else (cftime.datetime,) ) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 2e6e638f5b1..5007db9eeb2 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -20,6 +20,7 @@ from xarray.core.duck_array_ops import allclose_or_equiv # noqa: F401 from xarray.core.indexing import ExplicitlyIndexed from xarray.core.options import set_options +from xarray.core.variable import IndexVariable from xarray.testing import ( # noqa: F401 assert_chunks_equal, assert_duckarray_allclose, @@ -47,6 +48,15 @@ ) +def assert_writeable(ds): + readonly = [ + name + for name, var in ds.variables.items() + if not isinstance(var, IndexVariable) and not var.data.flags.writeable + ] + assert not readonly, readonly + + def _importorskip( modname: str, minversion: str | None = None ) -> tuple[bool, pytest.MarkDecorator]: @@ -326,7 +336,7 @@ def create_test_data( numbers_values = np.random.randint(0, 3, _dims["dim3"], dtype="int64") obj.coords["numbers"] = ("dim3", numbers_values) obj.encoding = {"foo": "bar"} - assert all(obj.data.flags.writeable for obj in obj.variables.values()) + assert_writeable(obj) return obj diff --git a/xarray/tests/conftest.py b/xarray/tests/conftest.py index 8590c9fb4e7..a32b0e08bea 100644 --- a/xarray/tests/conftest.py +++ b/xarray/tests/conftest.py @@ -6,7 +6,7 @@ import xarray as xr from xarray import DataArray, Dataset -from xarray.datatree_.datatree import DataTree +from xarray.core.datatree import DataTree from xarray.tests import create_test_data, requires_dask diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index b97d5ced938..be9b3ef0422 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -13,12 +13,12 @@ import tempfile import uuid import warnings -from collections.abc import Generator, Iterator +from collections.abc import Generator, Iterator, Mapping from contextlib import ExitStack from io import BytesIO from os import listdir from pathlib import Path -from typing import TYPE_CHECKING, Any, Final, cast +from typing import TYPE_CHECKING, Any, Final, Literal, cast from unittest.mock import patch import numpy as np @@ -2261,7 +2261,6 @@ def test_write_uneven_dask_chunks(self) -> None: original = create_test_data().chunk({"dim1": 3, "dim2": 4, "dim3": 3}) with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual: for k, v in actual.data_vars.items(): - print(k) assert v.chunks == actual[k].chunks def test_chunk_encoding(self) -> None: @@ -2305,7 +2304,7 @@ def test_chunk_encoding_with_dask(self) -> None: # should fail if encoding["chunks"] clashes with dask_chunks badenc = ds.chunk({"x": 4}) badenc.var1.encoding["chunks"] = (6,) - with pytest.raises(NotImplementedError, match=r"named 'var1' would overlap"): + with pytest.raises(ValueError, match=r"named 'var1' would overlap"): with self.roundtrip(badenc) as actual: pass @@ -2343,9 +2342,7 @@ def test_chunk_encoding_with_dask(self) -> None: # but itermediate unaligned chunks are bad badenc = ds.chunk({"x": (3, 5, 3, 1)}) badenc.var1.encoding["chunks"] = (3,) - with pytest.raises( - NotImplementedError, match=r"would overlap multiple dask chunks" - ): + with pytest.raises(ValueError, match=r"would overlap multiple dask chunks"): with self.roundtrip(badenc) as actual: pass @@ -2359,7 +2356,7 @@ def test_chunk_encoding_with_dask(self) -> None: # TODO: remove this failure once synchronized overlapping writes are # supported by xarray ds_chunk4["var1"].encoding.update({"chunks": 5}) - with pytest.raises(NotImplementedError, match=r"named 'var1' would overlap"): + with pytest.raises(ValueError, match=r"named 'var1' would overlap"): with self.roundtrip(ds_chunk4) as actual: pass # override option @@ -2468,6 +2465,27 @@ def test_group(self) -> None: ) as actual: assert_identical(original, actual) + def test_zarr_mode_w_overwrites_encoding(self) -> None: + import zarr + + data = Dataset({"foo": ("x", [1.0, 1.0, 1.0])}) + with self.create_zarr_target() as store: + data.to_zarr( + store, **self.version_kwargs, encoding={"foo": {"add_offset": 1}} + ) + np.testing.assert_equal( + zarr.open_group(store, **self.version_kwargs)["foo"], data.foo.data - 1 + ) + data.to_zarr( + store, + **self.version_kwargs, + encoding={"foo": {"add_offset": 0}}, + mode="w", + ) + np.testing.assert_equal( + zarr.open_group(store, **self.version_kwargs)["foo"], data.foo.data + ) + def test_encoding_kwarg_fixed_width_string(self) -> None: # not relevant for zarr, since we don't use EncodedStringCoder pass @@ -2605,7 +2623,9 @@ def test_append_with_append_dim_no_overwrite(self) -> None: # overwrite a coordinate; # for mode='a-', this will not get written to the store # because it does not have the append_dim as a dim - ds_to_append.lon.data[:] = -999 + lon = ds_to_append.lon.to_numpy().copy() + lon[:] = -999 + ds_to_append["lon"] = lon ds_to_append.to_zarr( store_target, mode="a-", append_dim="time", **self.version_kwargs ) @@ -2615,7 +2635,9 @@ def test_append_with_append_dim_no_overwrite(self) -> None: # by default, mode="a" will overwrite all coordinates. ds_to_append.to_zarr(store_target, append_dim="time", **self.version_kwargs) actual = xr.open_dataset(store_target, engine="zarr", **self.version_kwargs) - original2.lon.data[:] = -999 + lon = original2.lon.to_numpy().copy() + lon[:] = -999 + original2["lon"] = lon assert_identical(original2, actual) @requires_dask @@ -3556,6 +3578,16 @@ def test_dump_encodings_h5py(self) -> None: assert actual.x.encoding["compression"] == "lzf" assert actual.x.encoding["compression_opts"] is None + def test_decode_utf8_warning(self) -> None: + title = b"\xc3" + with create_tmp_file() as tmp_file: + with nc4.Dataset(tmp_file, "w") as f: + f.title = title + with pytest.warns(UnicodeWarning, match="returning bytes undecoded") as w: + ds = xr.load_dataset(tmp_file, engine="h5netcdf") + assert ds.title == title + assert "attribute 'title' of h5netcdf object '/'" in str(w[0].message) + @requires_h5netcdf @requires_netCDF4 @@ -5637,24 +5669,27 @@ def test_zarr_region_index_write(self, tmp_path): } ) - ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8)) + region_slice = dict(x=slice(2, 4), y=slice(6, 8)) + ds_region = 1 + ds.isel(region_slice) ds.to_zarr(tmp_path / "test.zarr") - with patch.object( - ZarrStore, - "set_variables", - side_effect=ZarrStore.set_variables, - autospec=True, - ) as mock: - ds_region.to_zarr(tmp_path / "test.zarr", region="auto", mode="r+") - - # should write the data vars but never the index vars with auto mode - for call in mock.call_args_list: - written_variables = call.args[1].keys() - assert "test" in written_variables - assert "x" not in written_variables - assert "y" not in written_variables + region: Mapping[str, slice] | Literal["auto"] + for region in [region_slice, "auto"]: # type: ignore + with patch.object( + ZarrStore, + "set_variables", + side_effect=ZarrStore.set_variables, + autospec=True, + ) as mock: + ds_region.to_zarr(tmp_path / "test.zarr", region=region, mode="r+") + + # should write the data vars but never the index vars with auto mode + for call in mock.call_args_list: + written_variables = call.args[1].keys() + assert "test" in written_variables + assert "x" not in written_variables + assert "y" not in written_variables def test_zarr_region_append(self, tmp_path): x = np.arange(0, 50, 10) @@ -5716,3 +5751,80 @@ def test_zarr_region(tmp_path): # Write without region ds_transposed.to_zarr(tmp_path / "test.zarr", mode="r+") + + +@requires_zarr +@requires_dask +def test_zarr_region_chunk_partial(tmp_path): + """ + Check that writing to partial chunks with `region` fails, assuming `safe_chunks=False`. + """ + ds = ( + xr.DataArray(np.arange(120).reshape(4, 3, -1), dims=list("abc")) + .rename("var1") + .to_dataset() + ) + + ds.chunk(5).to_zarr(tmp_path / "foo.zarr", compute=False, mode="w") + with pytest.raises(ValueError): + for r in range(ds.sizes["a"]): + ds.chunk(3).isel(a=[r]).to_zarr( + tmp_path / "foo.zarr", region=dict(a=slice(r, r + 1)) + ) + + +@requires_zarr +@requires_dask +def test_zarr_append_chunk_partial(tmp_path): + t_coords = np.array([np.datetime64("2020-01-01").astype("datetime64[ns]")]) + data = np.ones((10, 10)) + + da = xr.DataArray( + data.reshape((-1, 10, 10)), + dims=["time", "x", "y"], + coords={"time": t_coords}, + name="foo", + ) + da.to_zarr(tmp_path / "foo.zarr", mode="w", encoding={"foo": {"chunks": (5, 5, 1)}}) + + new_time = np.array([np.datetime64("2021-01-01").astype("datetime64[ns]")]) + + da2 = xr.DataArray( + data.reshape((-1, 10, 10)), + dims=["time", "x", "y"], + coords={"time": new_time}, + name="foo", + ) + with pytest.raises(ValueError, match="encoding was provided"): + da2.to_zarr( + tmp_path / "foo.zarr", + append_dim="time", + mode="a", + encoding={"foo": {"chunks": (1, 1, 1)}}, + ) + + # chunking with dask sidesteps the encoding check, so we need a different check + with pytest.raises(ValueError, match="Specified zarr chunks"): + da2.chunk({"x": 1, "y": 1, "time": 1}).to_zarr( + tmp_path / "foo.zarr", append_dim="time", mode="a" + ) + + +@requires_zarr +@requires_dask +def test_zarr_region_chunk_partial_offset(tmp_path): + # https://github.com/pydata/xarray/pull/8459#issuecomment-1819417545 + store = tmp_path / "foo.zarr" + data = np.ones((30,)) + da = xr.DataArray(data, dims=["x"], coords={"x": range(30)}, name="foo").chunk(x=10) + da.to_zarr(store, compute=False) + + da.isel(x=slice(10)).chunk(x=(10,)).to_zarr(store, region="auto") + + da.isel(x=slice(5, 25)).chunk(x=(10, 10)).to_zarr( + store, safe_chunks=False, region="auto" + ) + + # This write is unsafe, and should raise an error, but does not. + # with pytest.raises(ValueError): + # da.isel(x=slice(5, 25)).chunk(x=(10, 10)).to_zarr(store, region="auto") diff --git a/xarray/tests/test_coordinates.py b/xarray/tests/test_coordinates.py index 68ce55b05da..f88e554d333 100644 --- a/xarray/tests/test_coordinates.py +++ b/xarray/tests/test_coordinates.py @@ -1,5 +1,6 @@ from __future__ import annotations +import numpy as np import pandas as pd import pytest @@ -8,6 +9,7 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset from xarray.core.indexes import PandasIndex, PandasMultiIndex +from xarray.core.variable import IndexVariable, Variable from xarray.tests import assert_identical, source_ndarray @@ -23,10 +25,12 @@ def test_init_default_index(self) -> None: assert_identical(coords.to_dataset(), expected) assert "x" in coords.xindexes + @pytest.mark.filterwarnings("error:IndexVariable") def test_init_no_default_index(self) -> None: # dimension coordinate with no default index (explicit) coords = Coordinates(coords={"x": [1, 2]}, indexes={}) assert "x" not in coords.xindexes + assert not isinstance(coords["x"], IndexVariable) def test_init_from_coords(self) -> None: expected = Dataset(coords={"foo": ("x", [0, 1, 2])}) @@ -171,3 +175,10 @@ def test_align(self) -> None: left2, right2 = align(left, right, join="override") assert_identical(left2, left) assert_identical(left2, right2) + + def test_dataset_from_coords_with_multidim_var_same_name(self): + # regression test for GH #8883 + var = Variable(data=np.arange(6).reshape(2, 3), dims=["x", "y"]) + coords = Coordinates(coords={"x": var}, indexes={}) + ds = Dataset(coords=coords) + assert ds.coords["x"].dims == ("x", "y") diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 4937fc5f3a3..e2a64964775 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -51,6 +51,7 @@ assert_equal, assert_identical, assert_no_warnings, + assert_writeable, create_test_data, has_cftime, has_dask, @@ -79,6 +80,13 @@ except ImportError: pass +# from numpy version 2.0 trapz is deprecated and renamed to trapezoid +# remove once numpy 2.0 is the oldest supported version +try: + from numpy import trapezoid # type: ignore[attr-defined,unused-ignore] +except ImportError: + from numpy import trapz as trapezoid + sparse_array_type = array_type("sparse") pytestmark = [ @@ -96,11 +104,11 @@ def create_append_test_data(seed=None) -> tuple[Dataset, Dataset, Dataset]: nt2 = 2 time1 = pd.date_range("2000-01-01", periods=nt1) time2 = pd.date_range("2000-02-01", periods=nt2) - string_var = np.array(["ae", "bc", "df"], dtype=object) + string_var = np.array(["a", "bc", "def"], dtype=object) string_var_to_append = np.array(["asdf", "asdfg"], dtype=object) string_var_fixed_length = np.array(["aa", "bb", "cc"], dtype="|S2") string_var_fixed_length_to_append = np.array(["dd", "ee"], dtype="|S2") - unicode_var = ["áó", "áó", "áó"] + unicode_var = np.array(["áó", "áó", "áó"]) datetime_var = np.array( ["2019-01-01", "2019-01-02", "2019-01-03"], dtype="datetime64[s]" ) @@ -119,17 +127,11 @@ def create_append_test_data(seed=None) -> tuple[Dataset, Dataset, Dataset]: coords=[lat, lon, time1], dims=["lat", "lon", "time"], ), - "string_var": xr.DataArray(string_var, coords=[time1], dims=["time"]), - "string_var_fixed_length": xr.DataArray( - string_var_fixed_length, coords=[time1], dims=["time"] - ), - "unicode_var": xr.DataArray( - unicode_var, coords=[time1], dims=["time"] - ).astype(np.str_), - "datetime_var": xr.DataArray( - datetime_var, coords=[time1], dims=["time"] - ), - "bool_var": xr.DataArray(bool_var, coords=[time1], dims=["time"]), + "string_var": ("time", string_var), + "string_var_fixed_length": ("time", string_var_fixed_length), + "unicode_var": ("time", unicode_var), + "datetime_var": ("time", datetime_var), + "bool_var": ("time", bool_var), } ) @@ -140,21 +142,11 @@ def create_append_test_data(seed=None) -> tuple[Dataset, Dataset, Dataset]: coords=[lat, lon, time2], dims=["lat", "lon", "time"], ), - "string_var": xr.DataArray( - string_var_to_append, coords=[time2], dims=["time"] - ), - "string_var_fixed_length": xr.DataArray( - string_var_fixed_length_to_append, coords=[time2], dims=["time"] - ), - "unicode_var": xr.DataArray( - unicode_var[:nt2], coords=[time2], dims=["time"] - ).astype(np.str_), - "datetime_var": xr.DataArray( - datetime_var_to_append, coords=[time2], dims=["time"] - ), - "bool_var": xr.DataArray( - bool_var_to_append, coords=[time2], dims=["time"] - ), + "string_var": ("time", string_var_to_append), + "string_var_fixed_length": ("time", string_var_fixed_length_to_append), + "unicode_var": ("time", unicode_var[:nt2]), + "datetime_var": ("time", datetime_var_to_append), + "bool_var": ("time", bool_var_to_append), } ) @@ -168,8 +160,9 @@ def create_append_test_data(seed=None) -> tuple[Dataset, Dataset, Dataset]: } ) - assert all(objp.data.flags.writeable for objp in ds.variables.values()) - assert all(objp.data.flags.writeable for objp in ds_to_append.variables.values()) + assert_writeable(ds) + assert_writeable(ds_to_append) + assert_writeable(ds_with_new_var) return ds, ds_to_append, ds_with_new_var @@ -182,10 +175,8 @@ def make_datasets(data, data_to_append) -> tuple[Dataset, Dataset]: ds_to_append = xr.Dataset( {"temperature": (["time"], data_to_append)}, coords={"time": [0, 1, 2]} ) - assert all(objp.data.flags.writeable for objp in ds.variables.values()) - assert all( - objp.data.flags.writeable for objp in ds_to_append.variables.values() - ) + assert_writeable(ds) + assert_writeable(ds_to_append) return ds, ds_to_append u2_strings = ["ab", "cd", "ef"] @@ -495,14 +486,6 @@ def test_constructor(self) -> None: actual = Dataset({"z": expected["z"]}) assert_identical(expected, actual) - def test_constructor_invalid_dims(self) -> None: - # regression for GH1120 - with pytest.raises(MergeError): - Dataset( - data_vars=dict(v=("y", [1, 2, 3, 4])), - coords=dict(y=DataArray([0.1, 0.2, 0.3, 0.4], dims="x")), - ) - def test_constructor_1d(self) -> None: expected = Dataset({"x": (["x"], 5.0 + np.arange(5))}) actual = Dataset({"x": 5.0 + np.arange(5)}) @@ -2964,10 +2947,11 @@ def test_copy_coords(self, deep, expected_orig) -> None: name="value", ).to_dataset() ds_cp = ds.copy(deep=deep) - ds_cp.coords["a"].data[0] = 999 + new_a = np.array([999, 2]) + ds_cp.coords["a"] = ds_cp.a.copy(data=new_a) expected_cp = xr.DataArray( - xr.IndexVariable("a", np.array([999, 2])), + xr.IndexVariable("a", new_a), coords={"a": [999, 2]}, dims=["a"], ) @@ -7014,7 +6998,7 @@ def test_integrate(dask) -> None: actual = da.integrate("x") # coordinate that contains x should be dropped. expected_x = xr.DataArray( - np.trapz(da.compute(), da["x"], axis=0), + trapezoid(da.compute(), da["x"], axis=0), dims=["y"], coords={k: v for k, v in da.coords.items() if "x" not in v.dims}, ) @@ -7027,7 +7011,7 @@ def test_integrate(dask) -> None: # along y actual = da.integrate("y") expected_y = xr.DataArray( - np.trapz(da, da["y"], axis=1), + trapezoid(da, da["y"], axis=1), dims=["x"], coords={k: v for k, v in da.coords.items() if "y" not in v.dims}, ) @@ -7067,12 +7051,10 @@ def test_cumulative_integrate(dask) -> None: # along x actual = da.cumulative_integrate("x") - # From scipy-1.6.0 cumtrapz is renamed to cumulative_trapezoid, but cumtrapz is - # still provided for backward compatibility - from scipy.integrate import cumtrapz + from scipy.integrate import cumulative_trapezoid expected_x = xr.DataArray( - cumtrapz(da.compute(), da["x"], axis=0, initial=0.0), + cumulative_trapezoid(da.compute(), da["x"], axis=0, initial=0.0), dims=["x", "y"], coords=da.coords, ) @@ -7088,7 +7070,7 @@ def test_cumulative_integrate(dask) -> None: # along y actual = da.cumulative_integrate("y") expected_y = xr.DataArray( - cumtrapz(da, da["y"], axis=1, initial=0.0), + cumulative_trapezoid(da, da["y"], axis=1, initial=0.0), dims=["x", "y"], coords=da.coords, ) @@ -7110,7 +7092,7 @@ def test_cumulative_integrate(dask) -> None: @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") @pytest.mark.parametrize("dask", [True, False]) @pytest.mark.parametrize("which_datetime", ["np", "cftime"]) -def test_trapz_datetime(dask, which_datetime) -> None: +def test_trapezoid_datetime(dask, which_datetime) -> None: rs = np.random.RandomState(42) if which_datetime == "np": coord = np.array( @@ -7141,7 +7123,7 @@ def test_trapz_datetime(dask, which_datetime) -> None: da = da.chunk({"time": 4}) actual = da.integrate("time", datetime_unit="D") - expected_data = np.trapz( + expected_data = trapezoid( da.compute().data, duck_array_ops.datetime_to_numeric(da["time"].data, datetime_unit="D"), axis=0, diff --git a/xarray/datatree_/datatree/tests/test_datatree.py b/xarray/tests/test_datatree.py similarity index 80% rename from xarray/datatree_/datatree/tests/test_datatree.py rename to xarray/tests/test_datatree.py index cfb57470651..c7359b3929e 100644 --- a/xarray/datatree_/datatree/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -2,29 +2,30 @@ import numpy as np import pytest + import xarray as xr +import xarray.datatree_.datatree.testing as dtt import xarray.testing as xrt +from xarray.core.datatree import DataTree +from xarray.core.treenode import NotFoundInTreeError from xarray.tests import create_test_data, source_ndarray -import xarray.datatree_.datatree.testing as dtt -from xarray.datatree_.datatree import DataTree, NotFoundInTreeError - class TestTreeCreation: def test_empty(self): - dt = DataTree(name="root") + dt: DataTree = DataTree(name="root") assert dt.name == "root" assert dt.parent is None assert dt.children == {} xrt.assert_identical(dt.to_dataset(), xr.Dataset()) def test_unnamed(self): - dt = DataTree() + dt: DataTree = DataTree() assert dt.name is None def test_bad_names(self): with pytest.raises(TypeError): - DataTree(name=5) + DataTree(name=5) # type: ignore[arg-type] with pytest.raises(ValueError): DataTree(name="folder/data") @@ -32,7 +33,7 @@ def test_bad_names(self): class TestFamilyTree: def test_setparent_unnamed_child_node_fails(self): - john = DataTree(name="john") + john: DataTree = DataTree(name="john") with pytest.raises(ValueError, match="unnamed"): DataTree(parent=john) @@ -40,8 +41,8 @@ def test_create_two_children(self): root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) set1_data = xr.Dataset({"a": 0, "b": 1}) - root = DataTree(data=root_data) - set1 = DataTree(name="set1", parent=root, data=set1_data) + root: DataTree = DataTree(data=root_data) + set1: DataTree = DataTree(name="set1", parent=root, data=set1_data) DataTree(name="set1", parent=root) DataTree(name="set2", parent=set1) @@ -50,11 +51,11 @@ def test_create_full_tree(self, simple_datatree): set1_data = xr.Dataset({"a": 0, "b": 1}) set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}) - root = DataTree(data=root_data) - set1 = DataTree(name="set1", parent=root, data=set1_data) + root: DataTree = DataTree(data=root_data) + set1: DataTree = DataTree(name="set1", parent=root, data=set1_data) DataTree(name="set1", parent=set1) DataTree(name="set2", parent=set1) - set2 = DataTree(name="set2", parent=root, data=set2_data) + set2: DataTree = DataTree(name="set2", parent=root, data=set2_data) DataTree(name="set1", parent=set2) DataTree(name="set3", parent=root) @@ -64,36 +65,36 @@ def test_create_full_tree(self, simple_datatree): class TestNames: def test_child_gets_named_on_attach(self): - sue = DataTree() - mary = DataTree(children={"Sue": sue}) # noqa + sue: DataTree = DataTree() + mary: DataTree = DataTree(children={"Sue": sue}) # noqa assert sue.name == "Sue" class TestPaths: def test_path_property(self): - sue = DataTree() - mary = DataTree(children={"Sue": sue}) - john = DataTree(children={"Mary": mary}) # noqa + sue: DataTree = DataTree() + mary: DataTree = DataTree(children={"Sue": sue}) + john: DataTree = DataTree(children={"Mary": mary}) assert sue.path == "/Mary/Sue" assert john.path == "/" def test_path_roundtrip(self): - sue = DataTree() - mary = DataTree(children={"Sue": sue}) - john = DataTree(children={"Mary": mary}) # noqa + sue: DataTree = DataTree() + mary: DataTree = DataTree(children={"Sue": sue}) + john: DataTree = DataTree(children={"Mary": mary}) assert john[sue.path] is sue def test_same_tree(self): - mary = DataTree() - kate = DataTree() - john = DataTree(children={"Mary": mary, "Kate": kate}) # noqa + mary: DataTree = DataTree() + kate: DataTree = DataTree() + john: DataTree = DataTree(children={"Mary": mary, "Kate": kate}) # noqa assert mary.same_tree(kate) def test_relative_paths(self): - sue = DataTree() - mary = DataTree(children={"Sue": sue}) - annie = DataTree() - john = DataTree(children={"Mary": mary, "Annie": annie}) + sue: DataTree = DataTree() + mary: DataTree = DataTree(children={"Sue": sue}) + annie: DataTree = DataTree() + john: DataTree = DataTree(children={"Mary": mary, "Annie": annie}) result = sue.relative_to(john) assert result == "Mary/Sue" @@ -102,7 +103,7 @@ def test_relative_paths(self): assert sue.relative_to(annie) == "../Mary/Sue" assert sue.relative_to(sue) == "." - evil_kate = DataTree() + evil_kate: DataTree = DataTree() with pytest.raises( NotFoundInTreeError, match="nodes do not lie within the same tree" ): @@ -112,116 +113,117 @@ def test_relative_paths(self): class TestStoreDatasets: def test_create_with_data(self): dat = xr.Dataset({"a": 0}) - john = DataTree(name="john", data=dat) + john: DataTree = DataTree(name="john", data=dat) xrt.assert_identical(john.to_dataset(), dat) with pytest.raises(TypeError): - DataTree(name="mary", parent=john, data="junk") # noqa + DataTree(name="mary", parent=john, data="junk") # type: ignore[arg-type] def test_set_data(self): - john = DataTree(name="john") + john: DataTree = DataTree(name="john") dat = xr.Dataset({"a": 0}) - john.ds = dat + john.ds = dat # type: ignore[assignment] xrt.assert_identical(john.to_dataset(), dat) with pytest.raises(TypeError): - john.ds = "junk" + john.ds = "junk" # type: ignore[assignment] def test_has_data(self): - john = DataTree(name="john", data=xr.Dataset({"a": 0})) + john: DataTree = DataTree(name="john", data=xr.Dataset({"a": 0})) assert john.has_data - john = DataTree(name="john", data=None) - assert not john.has_data + john_no_data: DataTree = DataTree(name="john", data=None) + assert not john_no_data.has_data def test_is_hollow(self): - john = DataTree(data=xr.Dataset({"a": 0})) + john: DataTree = DataTree(data=xr.Dataset({"a": 0})) assert john.is_hollow - eve = DataTree(children={"john": john}) + eve: DataTree = DataTree(children={"john": john}) assert eve.is_hollow - eve.ds = xr.Dataset({"a": 1}) + eve.ds = xr.Dataset({"a": 1}) # type: ignore[assignment] assert not eve.is_hollow class TestVariablesChildrenNameCollisions: def test_parent_already_has_variable_with_childs_name(self): - dt = DataTree(data=xr.Dataset({"a": [0], "b": 1})) + dt: DataTree = DataTree(data=xr.Dataset({"a": [0], "b": 1})) with pytest.raises(KeyError, match="already contains a data variable named a"): DataTree(name="a", data=None, parent=dt) def test_assign_when_already_child_with_variables_name(self): - dt = DataTree(data=None) + dt: DataTree = DataTree(data=None) DataTree(name="a", data=None, parent=dt) with pytest.raises(KeyError, match="names would collide"): - dt.ds = xr.Dataset({"a": 0}) + dt.ds = xr.Dataset({"a": 0}) # type: ignore[assignment] - dt.ds = xr.Dataset() + dt.ds = xr.Dataset() # type: ignore[assignment] new_ds = dt.to_dataset().assign(a=xr.DataArray(0)) with pytest.raises(KeyError, match="names would collide"): - dt.ds = new_ds + dt.ds = new_ds # type: ignore[assignment] -class TestGet: - ... +class TestGet: ... class TestGetItem: def test_getitem_node(self): - folder1 = DataTree(name="folder1") - results = DataTree(name="results", parent=folder1) - highres = DataTree(name="highres", parent=results) + folder1: DataTree = DataTree(name="folder1") + results: DataTree = DataTree(name="results", parent=folder1) + highres: DataTree = DataTree(name="highres", parent=results) assert folder1["results"] is results assert folder1["results/highres"] is highres def test_getitem_self(self): - dt = DataTree() + dt: DataTree = DataTree() assert dt["."] is dt def test_getitem_single_data_variable(self): data = xr.Dataset({"temp": [0, 50]}) - results = DataTree(name="results", data=data) + results: DataTree = DataTree(name="results", data=data) xrt.assert_identical(results["temp"], data["temp"]) def test_getitem_single_data_variable_from_node(self): data = xr.Dataset({"temp": [0, 50]}) - folder1 = DataTree(name="folder1") - results = DataTree(name="results", parent=folder1) + folder1: DataTree = DataTree(name="folder1") + results: DataTree = DataTree(name="results", parent=folder1) DataTree(name="highres", parent=results, data=data) xrt.assert_identical(folder1["results/highres/temp"], data["temp"]) def test_getitem_nonexistent_node(self): - folder1 = DataTree(name="folder1") + folder1: DataTree = DataTree(name="folder1") DataTree(name="results", parent=folder1) with pytest.raises(KeyError): folder1["results/highres"] def test_getitem_nonexistent_variable(self): data = xr.Dataset({"temp": [0, 50]}) - results = DataTree(name="results", data=data) + results: DataTree = DataTree(name="results", data=data) with pytest.raises(KeyError): results["pressure"] @pytest.mark.xfail(reason="Should be deprecated in favour of .subset") def test_getitem_multiple_data_variables(self): data = xr.Dataset({"temp": [0, 50], "p": [5, 8, 7]}) - results = DataTree(name="results", data=data) - xrt.assert_identical(results[["temp", "p"]], data[["temp", "p"]]) + results: DataTree = DataTree(name="results", data=data) + xrt.assert_identical(results[["temp", "p"]], data[["temp", "p"]]) # type: ignore[index] - @pytest.mark.xfail(reason="Indexing needs to return whole tree (GH https://github.com/xarray-contrib/datatree/issues/77)") + @pytest.mark.xfail( + reason="Indexing needs to return whole tree (GH https://github.com/xarray-contrib/datatree/issues/77)" + ) def test_getitem_dict_like_selection_access_to_dataset(self): data = xr.Dataset({"temp": [0, 50]}) - results = DataTree(name="results", data=data) - xrt.assert_identical(results[{"temp": 1}], data[{"temp": 1}]) + results: DataTree = DataTree(name="results", data=data) + xrt.assert_identical(results[{"temp": 1}], data[{"temp": 1}]) # type: ignore[index] class TestUpdate: def test_update(self): - dt = DataTree() + dt: DataTree = DataTree() dt.update({"foo": xr.DataArray(0), "a": DataTree()}) expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "a": None}) print(dt) @@ -233,13 +235,13 @@ def test_update(self): def test_update_new_named_dataarray(self): da = xr.DataArray(name="temp", data=[0, 50]) - folder1 = DataTree(name="folder1") + folder1: DataTree = DataTree(name="folder1") folder1.update({"results": da}) expected = da.rename("results") xrt.assert_equal(folder1["results"], expected) def test_update_doesnt_alter_child_name(self): - dt = DataTree() + dt: DataTree = DataTree() dt.update({"foo": xr.DataArray(0), "a": DataTree(name="b")}) assert "a" in dt.children child = dt["a"] @@ -336,8 +338,8 @@ def test_copy_with_data(self, create_test_datatree): class TestSetItem: def test_setitem_new_child_node(self): - john = DataTree(name="john") - mary = DataTree(name="mary") + john: DataTree = DataTree(name="john") + mary: DataTree = DataTree(name="mary") john["mary"] = mary grafted_mary = john["mary"] @@ -345,14 +347,14 @@ def test_setitem_new_child_node(self): assert grafted_mary.name == "mary" def test_setitem_unnamed_child_node_becomes_named(self): - john2 = DataTree(name="john2") + john2: DataTree = DataTree(name="john2") john2["sonny"] = DataTree() assert john2["sonny"].name == "sonny" def test_setitem_new_grandchild_node(self): - john = DataTree(name="john") - mary = DataTree(name="mary", parent=john) - rose = DataTree(name="rose") + john: DataTree = DataTree(name="john") + mary: DataTree = DataTree(name="mary", parent=john) + rose: DataTree = DataTree(name="rose") john["mary/rose"] = rose grafted_rose = john["mary/rose"] @@ -360,98 +362,97 @@ def test_setitem_new_grandchild_node(self): assert grafted_rose.name == "rose" def test_grafted_subtree_retains_name(self): - subtree = DataTree(name="original_subtree_name") - root = DataTree(name="root") + subtree: DataTree = DataTree(name="original_subtree_name") + root: DataTree = DataTree(name="root") root["new_subtree_name"] = subtree # noqa assert subtree.name == "original_subtree_name" def test_setitem_new_empty_node(self): - john = DataTree(name="john") + john: DataTree = DataTree(name="john") john["mary"] = DataTree() mary = john["mary"] assert isinstance(mary, DataTree) xrt.assert_identical(mary.to_dataset(), xr.Dataset()) def test_setitem_overwrite_data_in_node_with_none(self): - john = DataTree(name="john") - mary = DataTree(name="mary", parent=john, data=xr.Dataset()) + john: DataTree = DataTree(name="john") + mary: DataTree = DataTree(name="mary", parent=john, data=xr.Dataset()) john["mary"] = DataTree() xrt.assert_identical(mary.to_dataset(), xr.Dataset()) - john.ds = xr.Dataset() + john.ds = xr.Dataset() # type: ignore[assignment] with pytest.raises(ValueError, match="has no name"): john["."] = DataTree() @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_on_this_node(self): data = xr.Dataset({"temp": [0, 50]}) - results = DataTree(name="results") + results: DataTree = DataTree(name="results") results["."] = data xrt.assert_identical(results.to_dataset(), data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node(self): data = xr.Dataset({"temp": [0, 50]}) - folder1 = DataTree(name="folder1") + folder1: DataTree = DataTree(name="folder1") folder1["results"] = data xrt.assert_identical(folder1["results"].to_dataset(), data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node_requiring_intermediate_nodes(self): data = xr.Dataset({"temp": [0, 50]}) - folder1 = DataTree(name="folder1") + folder1: DataTree = DataTree(name="folder1") folder1["results/highres"] = data xrt.assert_identical(folder1["results/highres"].to_dataset(), data) def test_setitem_named_dataarray(self): da = xr.DataArray(name="temp", data=[0, 50]) - folder1 = DataTree(name="folder1") + folder1: DataTree = DataTree(name="folder1") folder1["results"] = da expected = da.rename("results") xrt.assert_equal(folder1["results"], expected) def test_setitem_unnamed_dataarray(self): data = xr.DataArray([0, 50]) - folder1 = DataTree(name="folder1") + folder1: DataTree = DataTree(name="folder1") folder1["results"] = data xrt.assert_equal(folder1["results"], data) def test_setitem_variable(self): var = xr.Variable(data=[0, 50], dims="x") - folder1 = DataTree(name="folder1") + folder1: DataTree = DataTree(name="folder1") folder1["results"] = var xrt.assert_equal(folder1["results"], xr.DataArray(var)) def test_setitem_coerce_to_dataarray(self): - folder1 = DataTree(name="folder1") + folder1: DataTree = DataTree(name="folder1") folder1["results"] = 0 xrt.assert_equal(folder1["results"], xr.DataArray(0)) def test_setitem_add_new_variable_to_empty_node(self): - results = DataTree(name="results") + results: DataTree = DataTree(name="results") results["pressure"] = xr.DataArray(data=[2, 3]) assert "pressure" in results.ds results["temp"] = xr.Variable(data=[10, 11], dims=["x"]) assert "temp" in results.ds # What if there is a path to traverse first? - results = DataTree(name="results") - results["highres/pressure"] = xr.DataArray(data=[2, 3]) - assert "pressure" in results["highres"].ds - results["highres/temp"] = xr.Variable(data=[10, 11], dims=["x"]) - assert "temp" in results["highres"].ds + results_with_path: DataTree = DataTree(name="results") + results_with_path["highres/pressure"] = xr.DataArray(data=[2, 3]) + assert "pressure" in results_with_path["highres"].ds + results_with_path["highres/temp"] = xr.Variable(data=[10, 11], dims=["x"]) + assert "temp" in results_with_path["highres"].ds def test_setitem_dataarray_replace_existing_node(self): t = xr.Dataset({"temp": [0, 50]}) - results = DataTree(name="results", data=t) + results: DataTree = DataTree(name="results", data=t) p = xr.DataArray(data=[2, 3]) results["pressure"] = p expected = t.assign(pressure=p) xrt.assert_identical(results.to_dataset(), expected) -class TestDictionaryInterface: - ... +class TestDictionaryInterface: ... class TestTreeFromDict: @@ -501,8 +502,8 @@ def test_full(self, simple_datatree): ] def test_datatree_values(self): - dat1 = DataTree(data=xr.Dataset({"a": 1})) - expected = DataTree() + dat1: DataTree = DataTree(data=xr.Dataset({"a": 1})) + expected: DataTree = DataTree() expected["a"] = dat1 actual = DataTree.from_dict({"a": dat1}) @@ -527,7 +528,7 @@ def test_roundtrip_unnamed_root(self, simple_datatree): class TestDatasetView: def test_view_contents(self): ds = create_test_data() - dt = DataTree(data=ds) + dt: DataTree = DataTree(data=ds) assert ds.identical( dt.ds ) # this only works because Dataset.identical doesn't check types @@ -535,7 +536,7 @@ def test_view_contents(self): def test_immutability(self): # See issue https://github.com/xarray-contrib/datatree/issues/38 - dt = DataTree(name="root", data=None) + dt: DataTree = DataTree(name="root", data=None) DataTree(name="a", data=None, parent=dt) with pytest.raises( @@ -553,7 +554,7 @@ def test_immutability(self): def test_methods(self): ds = create_test_data() - dt = DataTree(data=ds) + dt: DataTree = DataTree(data=ds) assert ds.mean().identical(dt.ds.mean()) assert type(dt.ds.mean()) == xr.Dataset @@ -572,7 +573,7 @@ def test_init_via_type(self): dims=["x", "y", "time"], coords={"area": (["x", "y"], np.random.rand(3, 4))}, ).to_dataset(name="data") - dt = DataTree(data=a) + dt: DataTree = DataTree(data=a) def weighted_mean(ds): return ds.weighted(ds.area).mean(["x", "y"]) @@ -643,7 +644,7 @@ def test_drop_nodes(self): assert childless.children == {} def test_assign(self): - dt = DataTree() + dt: DataTree = DataTree() expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "/a": None}) # kwargs form @@ -727,5 +728,5 @@ def test_filter(self): }, name="Abe", ) - elders = simpsons.filter(lambda node: node["age"] > 18) + elders = simpsons.filter(lambda node: node["age"].item() > 18) dtt.assert_identical(elders, expected) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index d927550e424..df9c40ca6f4 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -13,10 +13,10 @@ import xarray as xr from xarray import DataArray, Dataset, Variable from xarray.core.groupby import _consolidate_slices +from xarray.core.types import InterpOptions from xarray.tests import ( InaccessibleArray, assert_allclose, - assert_array_equal, assert_equal, assert_identical, create_test_data, @@ -30,7 +30,7 @@ @pytest.fixture -def dataset(): +def dataset() -> xr.Dataset: ds = xr.Dataset( { "foo": (("x", "y", "z"), np.random.randn(3, 4, 2)), @@ -44,7 +44,7 @@ def dataset(): @pytest.fixture -def array(dataset): +def array(dataset) -> xr.DataArray: return dataset["foo"] @@ -245,6 +245,51 @@ def test_da_groupby_empty() -> None: empty_array.groupby("dim") +@requires_dask +def test_dask_da_groupby_quantile() -> None: + # Only works when the grouped reduction can run blockwise + # Scalar quantile + expected = xr.DataArray( + data=[2, 5], coords={"x": [1, 2], "quantile": 0.5}, dims="x" + ) + array = xr.DataArray( + data=[1, 2, 3, 4, 5, 6], coords={"x": [1, 1, 1, 2, 2, 2]}, dims="x" + ) + with pytest.raises(ValueError): + array.chunk(x=1).groupby("x").quantile(0.5) + + # will work blockwise with flox + actual = array.chunk(x=3).groupby("x").quantile(0.5) + assert_identical(expected, actual) + + # will work blockwise with flox + actual = array.chunk(x=-1).groupby("x").quantile(0.5) + assert_identical(expected, actual) + + +@requires_dask +def test_dask_da_groupby_median() -> None: + expected = xr.DataArray(data=[2, 5], coords={"x": [1, 2]}, dims="x") + array = xr.DataArray( + data=[1, 2, 3, 4, 5, 6], coords={"x": [1, 1, 1, 2, 2, 2]}, dims="x" + ) + with xr.set_options(use_flox=False): + actual = array.chunk(x=1).groupby("x").median() + assert_identical(expected, actual) + + with xr.set_options(use_flox=True): + actual = array.chunk(x=1).groupby("x").median() + assert_identical(expected, actual) + + # will work blockwise with flox + actual = array.chunk(x=3).groupby("x").median() + assert_identical(expected, actual) + + # will work blockwise with flox + actual = array.chunk(x=-1).groupby("x").median() + assert_identical(expected, actual) + + def test_da_groupby_quantile() -> None: array = xr.DataArray( data=[1, 2, 3, 4, 5, 6], coords={"x": [1, 1, 1, 2, 2, 2]}, dims="x" @@ -478,7 +523,7 @@ def test_ds_groupby_quantile() -> None: @pytest.mark.parametrize("as_dataset", [False, True]) -def test_groupby_quantile_interpolation_deprecated(as_dataset) -> None: +def test_groupby_quantile_interpolation_deprecated(as_dataset: bool) -> None: array = xr.DataArray(data=[1, 2, 3, 4], coords={"x": [1, 1, 2, 2]}, dims="x") arr: xr.DataArray | xr.Dataset @@ -849,7 +894,7 @@ def test_groupby_dataset_reduce() -> None: @pytest.mark.parametrize("squeeze", [True, False]) -def test_groupby_dataset_math(squeeze) -> None: +def test_groupby_dataset_math(squeeze: bool) -> None: def reorder_dims(x): return x.transpose("dim1", "dim2", "dim3", "time") @@ -1070,7 +1115,7 @@ def test_groupby_dataset_order() -> None: # .assertEqual(all_vars, all_vars_ref) -def test_groupby_dataset_fillna(): +def test_groupby_dataset_fillna() -> None: ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]}) expected = Dataset({"a": ("x", range(4))}, {"x": [0, 1, 2, 3]}) for target in [ds, expected]: @@ -1090,12 +1135,12 @@ def test_groupby_dataset_fillna(): assert actual.a.attrs == ds.a.attrs -def test_groupby_dataset_where(): +def test_groupby_dataset_where() -> None: # groupby ds = Dataset({"a": ("x", range(5))}, {"c": ("x", [0, 0, 1, 1, 1])}) cond = Dataset({"a": ("c", [True, False])}) expected = ds.copy(deep=True) - expected["a"].values = [0, 1] + [np.nan] * 3 + expected["a"].values = np.array([0, 1] + [np.nan] * 3) actual = ds.groupby("c").where(cond) assert_identical(expected, actual) @@ -1108,7 +1153,7 @@ def test_groupby_dataset_where(): assert actual.a.attrs == ds.a.attrs -def test_groupby_dataset_assign(): +def test_groupby_dataset_assign() -> None: ds = Dataset({"a": ("x", range(3))}, {"b": ("x", ["A"] * 2 + ["B"])}) actual = ds.groupby("b").assign(c=lambda ds: 2 * ds.a) expected = ds.merge({"c": ("x", [0, 2, 4])}) @@ -1123,7 +1168,7 @@ def test_groupby_dataset_assign(): assert_identical(actual, expected) -def test_groupby_dataset_map_dataarray_func(): +def test_groupby_dataset_map_dataarray_func() -> None: # regression GH6379 ds = Dataset({"foo": ("x", [1, 2, 3, 4])}, coords={"x": [0, 0, 1, 1]}) actual = ds.groupby("x").map(lambda grp: grp.foo.mean()) @@ -1131,7 +1176,7 @@ def test_groupby_dataset_map_dataarray_func(): assert_identical(actual, expected) -def test_groupby_dataarray_map_dataset_func(): +def test_groupby_dataarray_map_dataset_func() -> None: # regression GH6379 da = DataArray([1, 2, 3, 4], coords={"x": [0, 0, 1, 1]}, dims="x", name="foo") actual = da.groupby("x").map(lambda grp: grp.mean().to_dataset()) @@ -1141,7 +1186,7 @@ def test_groupby_dataarray_map_dataset_func(): @requires_flox @pytest.mark.parametrize("kwargs", [{"method": "map-reduce"}, {"engine": "numpy"}]) -def test_groupby_flox_kwargs(kwargs): +def test_groupby_flox_kwargs(kwargs) -> None: ds = Dataset({"a": ("x", range(5))}, {"c": ("x", [0, 0, 1, 1, 1])}) with xr.set_options(use_flox=False): expected = ds.groupby("c").mean() @@ -1152,7 +1197,7 @@ def test_groupby_flox_kwargs(kwargs): class TestDataArrayGroupBy: @pytest.fixture(autouse=True) - def setup(self): + def setup(self) -> None: self.attrs = {"attr1": "value1", "attr2": 2929} self.x = np.random.random((10, 20)) self.v = Variable(["x", "y"], self.x) @@ -1169,7 +1214,7 @@ def setup(self): self.da.coords["abc"] = ("y", np.array(["a"] * 9 + ["c"] + ["b"] * 10)) self.da.coords["y"] = 20 + 100 * self.da["y"] - def test_stack_groupby_unsorted_coord(self): + def test_stack_groupby_unsorted_coord(self) -> None: data = [[0, 1], [2, 3]] data_flat = [0, 1, 2, 3] dims = ["x", "y"] @@ -1188,7 +1233,7 @@ def test_stack_groupby_unsorted_coord(self): expected2 = xr.DataArray(data_flat, dims=["z"], coords={"z": midx2}) assert_equal(actual2, expected2) - def test_groupby_iter(self): + def test_groupby_iter(self) -> None: for (act_x, act_dv), (exp_x, exp_ds) in zip( self.dv.groupby("y", squeeze=False), self.ds.groupby("y", squeeze=False) ): @@ -1200,12 +1245,19 @@ def test_groupby_iter(self): ): assert_identical(exp_dv, act_dv) - def test_groupby_properties(self): + def test_groupby_properties(self) -> None: grouped = self.da.groupby("abc") expected_groups = {"a": range(0, 9), "c": [9], "b": range(10, 20)} assert expected_groups.keys() == grouped.groups.keys() for key in expected_groups: - assert_array_equal(expected_groups[key], grouped.groups[key]) + expected_group = expected_groups[key] + actual_group = grouped.groups[key] + + # TODO: array_api doesn't allow slice: + assert not isinstance(expected_group, slice) + assert not isinstance(actual_group, slice) + + np.testing.assert_array_equal(expected_group, actual_group) assert 3 == len(grouped) @pytest.mark.parametrize( @@ -1229,7 +1281,7 @@ def identity(x): if (by.name if use_da else by) != "abc": assert len(recwarn) == (1 if squeeze in [None, True] else 0) - def test_groupby_sum(self): + def test_groupby_sum(self) -> None: array = self.da grouped = array.groupby("abc") @@ -1283,7 +1335,7 @@ def test_groupby_sum(self): assert_allclose(expected_sum_axis1, grouped.sum("y")) @pytest.mark.parametrize("method", ["sum", "mean", "median"]) - def test_groupby_reductions(self, method): + def test_groupby_reductions(self, method) -> None: array = self.da grouped = array.groupby("abc") @@ -1313,7 +1365,7 @@ def test_groupby_reductions(self, method): assert_allclose(expected, actual_legacy) assert_allclose(expected, actual_npg) - def test_groupby_count(self): + def test_groupby_count(self) -> None: array = DataArray( [0, 0, np.nan, np.nan, 0, 0], coords={"cat": ("x", ["a", "b", "b", "c", "c", "c"])}, @@ -1325,7 +1377,9 @@ def test_groupby_count(self): @pytest.mark.parametrize("shortcut", [True, False]) @pytest.mark.parametrize("keep_attrs", [None, True, False]) - def test_groupby_reduce_keep_attrs(self, shortcut, keep_attrs): + def test_groupby_reduce_keep_attrs( + self, shortcut: bool, keep_attrs: bool | None + ) -> None: array = self.da array.attrs["foo"] = "bar" @@ -1337,7 +1391,7 @@ def test_groupby_reduce_keep_attrs(self, shortcut, keep_attrs): assert_identical(expected, actual) @pytest.mark.parametrize("keep_attrs", [None, True, False]) - def test_groupby_keep_attrs(self, keep_attrs): + def test_groupby_keep_attrs(self, keep_attrs: bool | None) -> None: array = self.da array.attrs["foo"] = "bar" @@ -1351,7 +1405,7 @@ def test_groupby_keep_attrs(self, keep_attrs): actual.data = expected.data assert_identical(expected, actual) - def test_groupby_map_center(self): + def test_groupby_map_center(self) -> None: def center(x): return x - np.mean(x) @@ -1366,14 +1420,14 @@ def center(x): expected_centered = expected_ds["foo"] assert_allclose(expected_centered, grouped.map(center)) - def test_groupby_map_ndarray(self): + def test_groupby_map_ndarray(self) -> None: # regression test for #326 array = self.da grouped = array.groupby("abc") - actual = grouped.map(np.asarray) + actual = grouped.map(np.asarray) # type: ignore[arg-type] # TODO: Not sure using np.asarray like this makes sense with array api assert_equal(array, actual) - def test_groupby_map_changes_metadata(self): + def test_groupby_map_changes_metadata(self) -> None: def change_metadata(x): x.coords["x"] = x.coords["x"] * 2 x.attrs["fruit"] = "lemon" @@ -1387,7 +1441,7 @@ def change_metadata(x): assert_equal(expected, actual) @pytest.mark.parametrize("squeeze", [True, False]) - def test_groupby_math_squeeze(self, squeeze): + def test_groupby_math_squeeze(self, squeeze: bool) -> None: array = self.da grouped = array.groupby("x", squeeze=squeeze) @@ -1406,7 +1460,7 @@ def test_groupby_math_squeeze(self, squeeze): actual = ds + grouped assert_identical(expected, actual) - def test_groupby_math(self): + def test_groupby_math(self) -> None: array = self.da grouped = array.groupby("abc") expected_agg = (grouped.mean(...) - np.arange(3)).rename(None) @@ -1415,13 +1469,13 @@ def test_groupby_math(self): assert_allclose(expected_agg, actual_agg) with pytest.raises(TypeError, match=r"only support binary ops"): - grouped + 1 + grouped + 1 # type: ignore[type-var] with pytest.raises(TypeError, match=r"only support binary ops"): - grouped + grouped + grouped + grouped # type: ignore[type-var] with pytest.raises(TypeError, match=r"in-place operations"): - array += grouped + array += grouped # type: ignore[arg-type] - def test_groupby_math_not_aligned(self): + def test_groupby_math_not_aligned(self) -> None: array = DataArray( range(4), {"b": ("x", [0, 0, 1, 1]), "x": [0, 1, 2, 3]}, dims="x" ) @@ -1442,12 +1496,12 @@ def test_groupby_math_not_aligned(self): expected.coords["c"] = (["x"], [123] * 2 + [np.nan] * 2) assert_identical(expected, actual) - other = Dataset({"a": ("b", [10])}, {"b": [0]}) - actual = array.groupby("b") + other - expected = Dataset({"a": ("x", [10, 11, np.nan, np.nan])}, array.coords) - assert_identical(expected, actual) + other_ds = Dataset({"a": ("b", [10])}, {"b": [0]}) + actual_ds = array.groupby("b") + other_ds + expected_ds = Dataset({"a": ("x", [10, 11, np.nan, np.nan])}, array.coords) + assert_identical(expected_ds, actual_ds) - def test_groupby_restore_dim_order(self): + def test_groupby_restore_dim_order(self) -> None: array = DataArray( np.random.randn(5, 3), coords={"a": ("x", range(5)), "b": ("y", range(3))}, @@ -1462,7 +1516,7 @@ def test_groupby_restore_dim_order(self): result = array.groupby(by, squeeze=False).map(lambda x: x.squeeze()) assert result.dims == expected_dims - def test_groupby_restore_coord_dims(self): + def test_groupby_restore_coord_dims(self) -> None: array = DataArray( np.random.randn(5, 3), coords={ @@ -1484,7 +1538,7 @@ def test_groupby_restore_coord_dims(self): )["c"] assert result.dims == expected_dims - def test_groupby_first_and_last(self): + def test_groupby_first_and_last(self) -> None: array = DataArray([1, 2, 3, 4, 5], dims="x") by = DataArray(["a"] * 2 + ["b"] * 3, dims="x", name="ab") @@ -1505,7 +1559,7 @@ def test_groupby_first_and_last(self): expected = array # should be a no-op assert_identical(expected, actual) - def make_groupby_multidim_example_array(self): + def make_groupby_multidim_example_array(self) -> DataArray: return DataArray( [[[0, 1], [2, 3]], [[5, 10], [15, 20]]], coords={ @@ -1515,7 +1569,7 @@ def make_groupby_multidim_example_array(self): dims=["time", "ny", "nx"], ) - def test_groupby_multidim(self): + def test_groupby_multidim(self) -> None: array = self.make_groupby_multidim_example_array() for dim, expected_sum in [ ("lon", DataArray([5, 28, 23], coords=[("lon", [30.0, 40.0, 50.0])])), @@ -1524,7 +1578,7 @@ def test_groupby_multidim(self): actual_sum = array.groupby(dim).sum(...) assert_identical(expected_sum, actual_sum) - def test_groupby_multidim_map(self): + def test_groupby_multidim_map(self) -> None: array = self.make_groupby_multidim_example_array() actual = array.groupby("lon").map(lambda x: x - x.mean()) expected = DataArray( @@ -1585,7 +1639,7 @@ def test_groupby_bins( # make sure original array dims are unchanged assert len(array.dim_0) == 4 - def test_groupby_bins_ellipsis(self): + def test_groupby_bins_ellipsis(self) -> None: da = xr.DataArray(np.ones((2, 3, 4))) bins = [-1, 0, 1, 2] with xr.set_options(use_flox=False): @@ -1624,7 +1678,7 @@ def test_groupby_bins_gives_correct_subset(self, use_flox: bool) -> None: actual = gb.count() assert_identical(actual, expected) - def test_groupby_bins_empty(self): + def test_groupby_bins_empty(self) -> None: array = DataArray(np.arange(4), [("x", range(4))]) # one of these bins will be empty bins = [0, 4, 5] @@ -1636,7 +1690,7 @@ def test_groupby_bins_empty(self): # (was a problem in earlier versions) assert len(array.x) == 4 - def test_groupby_bins_multidim(self): + def test_groupby_bins_multidim(self) -> None: array = self.make_groupby_multidim_example_array() bins = [0, 15, 20] bin_coords = pd.cut(array["lat"].values.flat, bins).categories @@ -1670,7 +1724,7 @@ def test_groupby_bins_multidim(self): ) assert_identical(actual, expected) - def test_groupby_bins_sort(self): + def test_groupby_bins_sort(self) -> None: data = xr.DataArray( np.arange(100), dims="x", coords={"x": np.linspace(-100, 100, num=100)} ) @@ -1683,14 +1737,14 @@ def test_groupby_bins_sort(self): expected = data.groupby_bins("x", bins=11).count() assert_identical(actual, expected) - def test_groupby_assign_coords(self): + def test_groupby_assign_coords(self) -> None: array = DataArray([1, 2, 3, 4], {"c": ("x", [0, 0, 1, 1])}, dims="x") actual = array.groupby("c").assign_coords(d=lambda a: a.mean()) expected = array.copy() expected.coords["d"] = ("x", [1.5, 1.5, 3.5, 3.5]) assert_identical(actual, expected) - def test_groupby_fillna(self): + def test_groupby_fillna(self) -> None: a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x") fill_value = DataArray([0, 1], dims="y") actual = a.fillna(fill_value) @@ -1776,7 +1830,7 @@ def test_resample_doctest(self, use_cftime: bool) -> None: ) assert_identical(actual, expected) - def test_da_resample_func_args(self): + def test_da_resample_func_args(self) -> None: def func(arg1, arg2, arg3=0.0): return arg1.mean("time") + arg2 + arg3 @@ -1786,7 +1840,7 @@ def func(arg1, arg2, arg3=0.0): actual = da.resample(time="D").map(func, args=(1.0,), arg3=1.0) assert_identical(actual, expected) - def test_resample_first(self): + def test_resample_first(self) -> None: times = pd.date_range("2000-01-01", freq="6h", periods=10) array = DataArray(np.arange(10), [("time", times)]) @@ -1823,14 +1877,14 @@ def test_resample_first(self): expected = DataArray(expected_times, [("time", times[::4])], name="time") assert_identical(expected, actual) - def test_resample_bad_resample_dim(self): + def test_resample_bad_resample_dim(self) -> None: times = pd.date_range("2000-01-01", freq="6h", periods=10) array = DataArray(np.arange(10), [("__resample_dim__", times)]) with pytest.raises(ValueError, match=r"Proxy resampling dimension"): - array.resample(**{"__resample_dim__": "1D"}).first() + array.resample(**{"__resample_dim__": "1D"}).first() # type: ignore[arg-type] @requires_scipy - def test_resample_drop_nondim_coords(self): + def test_resample_drop_nondim_coords(self) -> None: xs = np.arange(6) ys = np.arange(3) times = pd.date_range("2000-01-01", freq="6h", periods=5) @@ -1861,7 +1915,7 @@ def test_resample_drop_nondim_coords(self): ) assert "tc" not in actual.coords - def test_resample_keep_attrs(self): + def test_resample_keep_attrs(self) -> None: times = pd.date_range("2000-01-01", freq="6h", periods=10) array = DataArray(np.ones(10), [("time", times)]) array.attrs["meta"] = "data" @@ -1870,7 +1924,7 @@ def test_resample_keep_attrs(self): expected = DataArray([1, 1, 1], [("time", times[::4])], attrs=array.attrs) assert_identical(result, expected) - def test_resample_skipna(self): + def test_resample_skipna(self) -> None: times = pd.date_range("2000-01-01", freq="6h", periods=10) array = DataArray(np.ones(10), [("time", times)]) array[1] = np.nan @@ -1879,7 +1933,7 @@ def test_resample_skipna(self): expected = DataArray([np.nan, 1, 1], [("time", times[::4])]) assert_identical(result, expected) - def test_upsample(self): + def test_upsample(self) -> None: times = pd.date_range("2000-01-01", freq="6h", periods=5) array = DataArray(np.arange(5), [("time", times)]) @@ -1910,7 +1964,7 @@ def test_upsample(self): expected = DataArray(array.reindex(time=new_times, method="nearest")) assert_identical(expected, actual) - def test_upsample_nd(self): + def test_upsample_nd(self) -> None: # Same as before, but now we try on multi-dimensional DataArrays. xs = np.arange(6) ys = np.arange(3) @@ -1968,29 +2022,29 @@ def test_upsample_nd(self): ) assert_identical(expected, actual) - def test_upsample_tolerance(self): + def test_upsample_tolerance(self) -> None: # Test tolerance keyword for upsample methods bfill, pad, nearest times = pd.date_range("2000-01-01", freq="1D", periods=2) times_upsampled = pd.date_range("2000-01-01", freq="6h", periods=5) array = DataArray(np.arange(2), [("time", times)]) # Forward fill - actual = array.resample(time="6h").ffill(tolerance="12h") + actual = array.resample(time="6h").ffill(tolerance="12h") # type: ignore[arg-type] # TODO: tolerance also allows strings, same issue in .reindex. expected = DataArray([0.0, 0.0, 0.0, np.nan, 1.0], [("time", times_upsampled)]) assert_identical(expected, actual) # Backward fill - actual = array.resample(time="6h").bfill(tolerance="12h") + actual = array.resample(time="6h").bfill(tolerance="12h") # type: ignore[arg-type] # TODO: tolerance also allows strings, same issue in .reindex. expected = DataArray([0.0, np.nan, 1.0, 1.0, 1.0], [("time", times_upsampled)]) assert_identical(expected, actual) # Nearest - actual = array.resample(time="6h").nearest(tolerance="6h") + actual = array.resample(time="6h").nearest(tolerance="6h") # type: ignore[arg-type] # TODO: tolerance also allows strings, same issue in .reindex. expected = DataArray([0, 0, np.nan, 1, 1], [("time", times_upsampled)]) assert_identical(expected, actual) @requires_scipy - def test_upsample_interpolate(self): + def test_upsample_interpolate(self) -> None: from scipy.interpolate import interp1d xs = np.arange(6) @@ -2005,7 +2059,15 @@ def test_upsample_interpolate(self): # Split the times into equal sub-intervals to simulate the 6 hour # to 1 hour up-sampling new_times_idx = np.linspace(0, len(times) - 1, len(times) * 5) - for kind in ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]: + kinds: list[InterpOptions] = [ + "linear", + "nearest", + "zero", + "slinear", + "quadratic", + "cubic", + ] + for kind in kinds: actual = array.resample(time="1h").interpolate(kind) f = interp1d( np.arange(len(times)), @@ -2028,7 +2090,7 @@ def test_upsample_interpolate(self): @requires_scipy @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") - def test_upsample_interpolate_bug_2197(self): + def test_upsample_interpolate_bug_2197(self) -> None: dates = pd.date_range("2007-02-01", "2007-03-01", freq="D") da = xr.DataArray(np.arange(len(dates)), [("time", dates)]) result = da.resample(time="ME").interpolate("linear") @@ -2039,7 +2101,7 @@ def test_upsample_interpolate_bug_2197(self): assert_equal(result, expected) @requires_scipy - def test_upsample_interpolate_regression_1605(self): + def test_upsample_interpolate_regression_1605(self) -> None: dates = pd.date_range("2016-01-01", "2016-03-31", freq="1D") expected = xr.DataArray( np.random.random((len(dates), 2, 3)), @@ -2052,7 +2114,7 @@ def test_upsample_interpolate_regression_1605(self): @requires_dask @requires_scipy @pytest.mark.parametrize("chunked_time", [True, False]) - def test_upsample_interpolate_dask(self, chunked_time): + def test_upsample_interpolate_dask(self, chunked_time: bool) -> None: from scipy.interpolate import interp1d xs = np.arange(6) @@ -2070,7 +2132,15 @@ def test_upsample_interpolate_dask(self, chunked_time): # Split the times into equal sub-intervals to simulate the 6 hour # to 1 hour up-sampling new_times_idx = np.linspace(0, len(times) - 1, len(times) * 5) - for kind in ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]: + kinds: list[InterpOptions] = [ + "linear", + "nearest", + "zero", + "slinear", + "quadratic", + "cubic", + ] + for kind in kinds: actual = array.chunk(chunks).resample(time="1h").interpolate(kind) actual = actual.compute() f = interp1d( @@ -2159,7 +2229,7 @@ def test_resample_invalid_loffset(self) -> None: class TestDatasetResample: - def test_resample_and_first(self): + def test_resample_and_first(self) -> None: times = pd.date_range("2000-01-01", freq="6h", periods=10) ds = Dataset( { @@ -2185,7 +2255,7 @@ def test_resample_and_first(self): result = actual.reduce(method) assert_equal(expected, result) - def test_resample_min_count(self): + def test_resample_min_count(self) -> None: times = pd.date_range("2000-01-01", freq="6h", periods=10) ds = Dataset( { @@ -2207,7 +2277,7 @@ def test_resample_min_count(self): ) assert_allclose(expected, actual) - def test_resample_by_mean_with_keep_attrs(self): + def test_resample_by_mean_with_keep_attrs(self) -> None: times = pd.date_range("2000-01-01", freq="6h", periods=10) ds = Dataset( { @@ -2227,7 +2297,7 @@ def test_resample_by_mean_with_keep_attrs(self): expected = ds.attrs assert expected == actual - def test_resample_loffset(self): + def test_resample_loffset(self) -> None: times = pd.date_range("2000-01-01", freq="6h", periods=10) ds = Dataset( { @@ -2238,7 +2308,7 @@ def test_resample_loffset(self): ) ds.attrs["dsmeta"] = "dsdata" - def test_resample_by_mean_discarding_attrs(self): + def test_resample_by_mean_discarding_attrs(self) -> None: times = pd.date_range("2000-01-01", freq="6h", periods=10) ds = Dataset( { @@ -2254,7 +2324,7 @@ def test_resample_by_mean_discarding_attrs(self): assert resampled_ds["bar"].attrs == {} assert resampled_ds.attrs == {} - def test_resample_by_last_discarding_attrs(self): + def test_resample_by_last_discarding_attrs(self) -> None: times = pd.date_range("2000-01-01", freq="6h", periods=10) ds = Dataset( { @@ -2271,7 +2341,7 @@ def test_resample_by_last_discarding_attrs(self): assert resampled_ds.attrs == {} @requires_scipy - def test_resample_drop_nondim_coords(self): + def test_resample_drop_nondim_coords(self) -> None: xs = np.arange(6) ys = np.arange(3) times = pd.date_range("2000-01-01", freq="6h", periods=5) @@ -2297,7 +2367,7 @@ def test_resample_drop_nondim_coords(self): actual = ds.resample(time="1h").interpolate("linear") assert "tc" not in actual.coords - def test_resample_old_api(self): + def test_resample_old_api(self) -> None: times = pd.date_range("2000-01-01", freq="6h", periods=10) ds = Dataset( { @@ -2308,15 +2378,15 @@ def test_resample_old_api(self): ) with pytest.raises(TypeError, match=r"resample\(\) no longer supports"): - ds.resample("1D", "time") + ds.resample("1D", "time") # type: ignore[arg-type] with pytest.raises(TypeError, match=r"resample\(\) no longer supports"): - ds.resample("1D", dim="time", how="mean") + ds.resample("1D", dim="time", how="mean") # type: ignore[arg-type] with pytest.raises(TypeError, match=r"resample\(\) no longer supports"): - ds.resample("1D", dim="time") + ds.resample("1D", dim="time") # type: ignore[arg-type] - def test_resample_ds_da_are_the_same(self): + def test_resample_ds_da_are_the_same(self) -> None: time = pd.date_range("2000-01-01", freq="6h", periods=365 * 4) ds = xr.Dataset( { @@ -2329,7 +2399,7 @@ def test_resample_ds_da_are_the_same(self): ds.resample(time="ME").mean()["foo"], ds.foo.resample(time="ME").mean() ) - def test_ds_resample_apply_func_args(self): + def test_ds_resample_apply_func_args(self) -> None: def func(arg1, arg2, arg3=0.0): return arg1.mean("time") + arg2 + arg3 @@ -2480,7 +2550,7 @@ def test_min_count_error(use_flox: bool) -> None: @requires_dask -def test_groupby_math_auto_chunk(): +def test_groupby_math_auto_chunk() -> None: da = xr.DataArray( [[1, 2, 3], [1, 2, 3], [1, 2, 3]], dims=("y", "x"), @@ -2494,7 +2564,7 @@ def test_groupby_math_auto_chunk(): @pytest.mark.parametrize("use_flox", [True, False]) -def test_groupby_dim_no_dim_equal(use_flox): +def test_groupby_dim_no_dim_equal(use_flox: bool) -> None: # https://github.com/pydata/xarray/issues/8263 da = DataArray( data=[1, 2, 3, 4], dims="lat", coords={"lat": np.linspace(0, 1.01, 4)} @@ -2506,7 +2576,7 @@ def test_groupby_dim_no_dim_equal(use_flox): @requires_flox -def test_default_flox_method(): +def test_default_flox_method() -> None: import flox.xarray da = xr.DataArray([1, 2, 3], dims="x", coords={"label": ("x", [2, 2, 1])}) diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 3ee7f045360..5ebdfd5da6e 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -352,7 +352,7 @@ def test_constructor(self) -> None: # default level names pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data]) index = PandasMultiIndex(pd_idx, "x") - assert index.index.names == ("x_level_0", "x_level_1") + assert list(index.index.names) == ["x_level_0", "x_level_1"] def test_from_variables(self) -> None: v_level1 = xr.Variable( @@ -370,7 +370,7 @@ def test_from_variables(self) -> None: assert index.dim == "x" assert index.index.equals(expected_idx) assert index.index.name == "x" - assert index.index.names == ["level1", "level2"] + assert list(index.index.names) == ["level1", "level2"] var = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]]) with pytest.raises( @@ -413,7 +413,8 @@ def test_stack(self) -> None: index = PandasMultiIndex.stack(prod_vars, "z") assert index.dim == "z" - assert index.index.names == ["x", "y"] + # TODO: change to tuple when pandas 3 is minimum + assert list(index.index.names) == ["x", "y"] np.testing.assert_array_equal( index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]] ) @@ -531,12 +532,12 @@ def test_rename(self) -> None: assert new_index is index new_index = index.rename({"two": "three"}, {}) - assert new_index.index.names == ["one", "three"] + assert list(new_index.index.names) == ["one", "three"] assert new_index.dim == "x" assert new_index.level_coords_dtype == {"one": " None: def check_indexing(v_eager, v_lazy, indexers): for indexer in indexers: - if isinstance(indexer, indexing.VectorizedIndexer): - actual = v_lazy.vindex[indexer] - expected = v_eager.vindex[indexer] - elif isinstance(indexer, indexing.OuterIndexer): - actual = v_lazy.oindex[indexer] - expected = v_eager.oindex[indexer] - else: - actual = v_lazy[indexer] - expected = v_eager[indexer] + actual = v_lazy[indexer] + expected = v_eager[indexer] assert expected.shape == actual.shape assert isinstance( actual._data, @@ -406,6 +421,41 @@ def check_indexing(v_eager, v_lazy, indexers): ] check_indexing(v_eager, v_lazy, indexers) + def test_lazily_indexed_array_vindex_setitem(self) -> None: + + lazy = indexing.LazilyIndexedArray(np.random.rand(10, 20, 30)) + + # vectorized indexing + indexer = indexing.VectorizedIndexer( + (np.array([0, 1]), np.array([0, 1]), slice(None, None, None)) + ) + with pytest.raises( + NotImplementedError, + match=r"Lazy item assignment with the vectorized indexer is not yet", + ): + lazy.vindex[indexer] = 0 + + @pytest.mark.parametrize( + "indexer_class, key, value", + [ + (indexing.OuterIndexer, (0, 1, slice(None, None, None)), 10), + (indexing.BasicIndexer, (0, 1, slice(None, None, None)), 10), + ], + ) + def test_lazily_indexed_array_setitem(self, indexer_class, key, value) -> None: + original = np.random.rand(10, 20, 30) + x = indexing.NumpyIndexingAdapter(original) + lazy = indexing.LazilyIndexedArray(x) + + if indexer_class is indexing.BasicIndexer: + indexer = indexer_class(key) + lazy[indexer] = value + elif indexer_class is indexing.OuterIndexer: + indexer = indexer_class(key) + lazy.oindex[indexer] = value + + assert_array_equal(original[key], value) + class TestCopyOnWriteArray: def test_setitem(self) -> None: @@ -830,7 +880,7 @@ def test_create_mask_dask() -> None: def test_create_mask_error() -> None: with pytest.raises(TypeError, match=r"unexpected key type"): - indexing.create_mask((1, 2), (3, 4)) + indexing.create_mask((1, 2), (3, 4)) # type: ignore[arg-type] @pytest.mark.parametrize( diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index a7644ac9d2b..7151c669fbc 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -833,7 +833,9 @@ def test_interpolate_chunk_1d( dest[dim] = cast( xr.DataArray, - np.linspace(before, after, len(da.coords[dim]) * 13), + np.linspace( + before.item(), after.item(), len(da.coords[dim]) * 13 + ), ) if chunked: dest[dim] = xr.DataArray(data=dest[dim], dims=[dim]) diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index f13406d0acc..c1d1058fd6e 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -122,10 +122,13 @@ def test_interpolate_pd_compat(method, fill_value) -> None: # for the numpy linear methods. # see https://github.com/pandas-dev/pandas/issues/55144 # This aligns the pandas output with the xarray output - expected.values[pd.isnull(actual.values)] = np.nan - expected.values[actual.values == fill_value] = fill_value + fixed = expected.values.copy() + fixed[pd.isnull(actual.values)] = np.nan + fixed[actual.values == fill_value] = fill_value + else: + fixed = expected.values - np.testing.assert_allclose(actual.values, expected.values) + np.testing.assert_allclose(actual.values, fixed) @requires_scipy diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 6f983a121fe..7c28b1cd140 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -5,7 +5,7 @@ import math from collections.abc import Hashable from copy import copy -from datetime import datetime +from datetime import date, datetime, timedelta from typing import Any, Callable, Literal import numpy as np @@ -620,6 +620,18 @@ def test_datetime_dimension(self) -> None: ax = plt.gca() assert ax.has_data() + def test_date_dimension(self) -> None: + nrow = 3 + ncol = 4 + start = date(2000, 1, 1) + time = [start + timedelta(days=i) for i in range(nrow)] + a = DataArray( + easy_array((nrow, ncol)), coords=[("time", time), ("y", range(ncol))] + ) + a.plot() + ax = plt.gca() + assert ax.has_data() + @pytest.mark.slow @pytest.mark.filterwarnings("ignore:tight_layout cannot") def test_convenient_facetgrid(self) -> None: @@ -2028,15 +2040,17 @@ def test_normalize_rgb_one_arg_error(self) -> None: for vmin2, vmax2 in ((-1.2, -1), (2, 2.1)): da.plot.imshow(vmin=vmin2, vmax=vmax2) - def test_imshow_rgb_values_in_valid_range(self) -> None: - da = DataArray(np.arange(75, dtype="uint8").reshape((5, 5, 3))) + @pytest.mark.parametrize("dtype", [np.uint8, np.int8, np.int16]) + def test_imshow_rgb_values_in_valid_range(self, dtype) -> None: + da = DataArray(np.arange(75, dtype=dtype).reshape((5, 5, 3))) _, ax = plt.subplots() out = da.plot.imshow(ax=ax).get_array() assert out is not None - dtype = out.dtype - assert dtype is not None - assert dtype == np.uint8 + actual_dtype = out.dtype + assert actual_dtype is not None + assert actual_dtype == np.uint8 assert (out[..., :3] == da.values).all() # Compare without added alpha + assert (out[..., -1] == 255).all() # Compare alpha @pytest.mark.filterwarnings("ignore:Several dimensions of this array") def test_regression_rgb_imshow_dim_size_one(self) -> None: diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 73f5abe66e5..d9289aa6674 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -64,6 +64,21 @@ def var(): return Variable(dims=list("xyz"), data=np.random.rand(3, 4, 5)) +@pytest.mark.parametrize( + "data", + [ + np.array(["a", "bc", "def"], dtype=object), + np.array(["2019-01-01", "2019-01-02", "2019-01-03"], dtype="datetime64[ns]"), + ], +) +def test_as_compatible_data_writeable(data): + pd.set_option("mode.copy_on_write", True) + # GH8843, ensure writeable arrays for data_vars even with + # pandas copy-on-write mode + assert as_compatible_data(data).flags.writeable + pd.reset_option("mode.copy_on_write") + + class VariableSubclassobjects(NamedArraySubclassobjects, ABC): @pytest.fixture def target(self, data): @@ -1201,7 +1216,8 @@ def test_as_variable(self): with pytest.raises(TypeError, match=r"without an explicit list of dimensions"): as_variable(data) - actual = as_variable(data, name="x") + with pytest.warns(FutureWarning, match="IndexVariable"): + actual = as_variable(data, name="x") assert_identical(expected.to_index_variable(), actual) actual = as_variable(0) @@ -1219,9 +1235,11 @@ def test_as_variable(self): # test datetime, timedelta conversion dt = np.array([datetime(1999, 1, 1) + timedelta(days=x) for x in range(10)]) - assert as_variable(dt, "time").dtype.kind == "M" + with pytest.warns(FutureWarning, match="IndexVariable"): + assert as_variable(dt, "time").dtype.kind == "M" td = np.array([timedelta(days=x) for x in range(10)]) - assert as_variable(td, "time").dtype.kind == "m" + with pytest.warns(FutureWarning, match="IndexVariable"): + assert as_variable(td, "time").dtype.kind == "m" with pytest.raises(TypeError): as_variable(("x", DataArray([]))) diff --git a/xarray/util/generate_aggregations.py b/xarray/util/generate_aggregations.py index 3462af28663..b59dc36c108 100644 --- a/xarray/util/generate_aggregations.py +++ b/xarray/util/generate_aggregations.py @@ -19,6 +19,7 @@ MODULE_PREAMBLE = '''\ """Mixin classes with reduction operations.""" + # This file was generated using xarray.util.generate_aggregations. Do not edit manually. from __future__ import annotations @@ -245,13 +246,9 @@ def {method}( _FLOX_NOTES_TEMPLATE = """Use the ``flox`` package to significantly speed up {kind} computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. -The default choice is ``method="cohorts"`` which generalizes the best, -{recco} might work better for your problem. See the `flox documentation `_ for more.""" -_FLOX_GROUPBY_NOTES = _FLOX_NOTES_TEMPLATE.format(kind="groupby", recco="other methods") -_FLOX_RESAMPLE_NOTES = _FLOX_NOTES_TEMPLATE.format( - kind="resampling", recco='``method="blockwise"``' -) +_FLOX_GROUPBY_NOTES = _FLOX_NOTES_TEMPLATE.format(kind="groupby") +_FLOX_RESAMPLE_NOTES = _FLOX_NOTES_TEMPLATE.format(kind="resampling") ExtraKwarg = collections.namedtuple("ExtraKwarg", "docs kwarg call example") skipna = ExtraKwarg( @@ -300,11 +297,13 @@ def __init__( extra_kwargs=tuple(), numeric_only=False, see_also_modules=("numpy", "dask.array"), + min_flox_version=None, ): self.name = name self.extra_kwargs = extra_kwargs self.numeric_only = numeric_only self.see_also_modules = see_also_modules + self.min_flox_version = min_flox_version if bool_reduce: self.array_method = f"array_{name}" self.np_example_array = """ @@ -443,8 +442,8 @@ def generate_code(self, method, has_keep_attrs): if self.datastructure.numeric_only: extra_kwargs.append(f"numeric_only={method.numeric_only},") - # numpy_groupies & flox do not support median - # https://github.com/ml31415/numpy-groupies/issues/43 + # median isn't enabled yet, because it would break if a single group was present in multiple + # chunks. The non-flox code path will just rechunk every group to a single chunk and execute the median method_is_not_flox_supported = method.name in ("median", "cumsum", "cumprod") if method_is_not_flox_supported: indent = 12 @@ -465,11 +464,16 @@ def generate_code(self, method, has_keep_attrs): **kwargs, )""" - else: - return f"""\ + min_version_check = f""" + and module_available("flox", minversion="{method.min_flox_version}")""" + + return ( + """\ if ( flox_available - and OPTIONS["use_flox"] + and OPTIONS["use_flox"]""" + + (min_version_check if method.min_flox_version is not None else "") + + f""" and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( @@ -486,6 +490,7 @@ def generate_code(self, method, has_keep_attrs): keep_attrs=keep_attrs, **kwargs, )""" + ) class GenericAggregationGenerator(AggregationGenerator): @@ -522,7 +527,9 @@ def generate_code(self, method, has_keep_attrs): Method("sum", extra_kwargs=(skipna, min_count), numeric_only=True), Method("std", extra_kwargs=(skipna, ddof), numeric_only=True), Method("var", extra_kwargs=(skipna, ddof), numeric_only=True), - Method("median", extra_kwargs=(skipna,), numeric_only=True), + Method( + "median", extra_kwargs=(skipna,), numeric_only=True, min_flox_version="0.9.2" + ), # Cumulatives: Method("cumsum", extra_kwargs=(skipna,), numeric_only=True), Method("cumprod", extra_kwargs=(skipna,), numeric_only=True),