Skip to content

Commit

Permalink
fix: handle indexing changes in new xarray versions (#159)
Browse files Browse the repository at this point in the history
  • Loading branch information
slevang committed Mar 30, 2024
1 parent a3cb204 commit 7d07d30
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 14 deletions.
40 changes: 31 additions & 9 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Expand Up @@ -14,7 +14,7 @@ python = ">=3.10,<3.12"
poetry = "^1.6"
numpy = "^1.24"
pandas = "^2"
xarray = "<=2023.12"
xarray = ">=2023"
scikit-learn = "^1.0.2"
pooch = "^1.6.0"
tqdm = "^4.64.0"
Expand Down
6 changes: 6 additions & 0 deletions xeofs/models/_base_cross_model.py
Expand Up @@ -248,6 +248,12 @@ def inverse_transform(
Reconstructed data of right field.
"""
# Handle scalar mode in xr.dot
if "mode" not in scores1.dims:
scores1 = scores1.expand_dims("mode")
if "mode" not in scores2.dims:
scores2 = scores2.expand_dims("mode")

data1, data2 = self._inverse_transform_algorithm(scores1, scores2)

# Unstack and rescale the data
Expand Down
5 changes: 5 additions & 0 deletions xeofs/models/_base_model.py
Expand Up @@ -305,6 +305,11 @@ def inverse_transform(
if normalized:
norms = self.data["norms"].sel(mode=scores.mode)
scores = scores * norms

# Handle scalar mode in xr.dot
if "mode" not in scores.dims:
scores = scores.expand_dims("mode")

data_reconstructed = self._inverse_transform_algorithm(scores)

# Reconstructing the data using a single mode introduces a
Expand Down
13 changes: 9 additions & 4 deletions xeofs/preprocessing/transformer.py
Expand Up @@ -96,14 +96,19 @@ def _serialize_data(self, key: str, data: Data) -> DataSet:
ds = data
else:
# Convert DataArray to Dataset
coords = {}
data_vars = {}
if data.name in data.coords:
# Convert a coord-like DataArray to Dataset and note multiindexes
if isinstance(data.to_index(), pd.MultiIndex):
multiindexes[data.name] = [n for n in data.to_index().names]
# Make sure the DataArray has some name so we can create a string mapping
elif data.name is None:
data.name = key
ds = xr.Dataset(data_vars={data.name: data})
coords[data.name] = data
else:
# Make sure the DataArray has some name so we can create a string mapping
if data.name is None:
data.name = key
data_vars[data.name] = data
ds = xr.Dataset(data_vars=data_vars, coords=coords)
name_map = data.name

# Drop multiindexes and record for later
Expand Down

0 comments on commit 7d07d30

Please sign in to comment.