fix: handle indexing changes in new xarray versions #159

Merged
merged 2 commits into from Mar 30, 2024
40 changes: 31 additions & 9 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -14,7 +14,7 @@ python = ">=3.10,<3.12"
 poetry = "^1.6"
 numpy = "^1.24"
 pandas = "^2"
-xarray = "<=2023.12"
+xarray = ">=2023"
 scikit-learn = "^1.0.2"
 pooch = "^1.6.0"
 tqdm = "^4.64.0"
5 changes: 4 additions & 1 deletion xeofs/models/eof.py
@@ -145,7 +145,10 @@ def _inverse_transform_algorithm(self, scores: DataArray) -> DataArray:
 Reconstructed data.

 """
-# Reconstruct the data
+# Handle scalar mode
+if "mode" not in scores.dims:
+    scores = scores.expand_dims("mode")

Collaborator

Perhaps it's better to add the "mode" dimension in the _BaseModel class, since we also remove the "mode" dim there afterward? Plus, other models can then inherit this behaviour :)

Collaborator Author

Of course, good idea. I had only followed the traceback to where I hit the error and added a quick fix there.

comps = self.data["components"].sel(mode=scores.mode)

reconstructed_data = xr.dot(comps.conj(), scores, dims="mode")
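
The review suggestion in this thread (handle the scalar "mode" once in a shared base class instead of in each model) could look roughly like the sketch below. It is only an illustration: `ensure_mode_dim` and `BaseModelSketch` are hypothetical names, not part of xeofs, and the reconstruction uses a plain multiply-and-sum over "mode" rather than the `xr.dot` call shown in the diff.

```python
import numpy as np
import xarray as xr


def ensure_mode_dim(scores: xr.DataArray) -> xr.DataArray:
    """Return scores with a "mode" dimension, restoring it if it was dropped."""
    if "mode" not in scores.dims:
        # A scalar selection such as scores.sel(mode=1) keeps "mode" only as a
        # scalar coordinate; expand_dims turns it back into a length-1 dimension.
        scores = scores.expand_dims("mode")
    return scores


class BaseModelSketch:
    """Hypothetical stand-in for a shared base class (not the xeofs _BaseModel)."""

    def __init__(self, components: xr.DataArray):
        self.components = components  # dims: ("mode", "feature")

    def inverse_transform(self, scores: xr.DataArray) -> xr.DataArray:
        # Normalise the input once here so subclasses never see a scalar mode.
        scores = ensure_mode_dim(scores)
        return self._inverse_transform_algorithm(scores)

    def _inverse_transform_algorithm(self, scores: xr.DataArray) -> xr.DataArray:
        comps = self.components.sel(mode=scores.mode)
        # Multiply and sum over "mode", i.e. a dot product along that dimension.
        return (comps.conj() * scores).sum("mode")


components = xr.DataArray(
    np.array([[1.0, 0.0, 2.0], [0.0, 3.0, 1.0]]),
    dims=("mode", "feature"),
    coords={"mode": [1, 2], "feature": ["a", "b", "c"]},
)
scores = xr.DataArray(
    np.arange(8.0).reshape(4, 2),
    dims=("sample", "mode"),
    coords={"sample": np.arange(4), "mode": [1, 2]},
)

model = BaseModelSketch(components)
single_mode = scores.sel(mode=1)  # "mode" is now a scalar coordinate, not a dim
reconstructed = model.inverse_transform(single_mode)
```

With this layout, subclasses only implement `_inverse_transform_algorithm` and can assume that "mode" is always a dimension of the scores they receive.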
6 changes: 6 additions & 0 deletions xeofs/models/mca.py
@@ -279,6 +279,12 @@ def _inverse_transform_algorithm(
 Reconstructed data of right field.

 """
+# Handle scalar mode
+if "mode" not in scores1.dims:
+    scores1 = scores1.expand_dims("mode")
+if "mode" not in scores2.dims:
+    scores2 = scores2.expand_dims("mode")

Collaborator

Same here as in the eof.py module: it is probably better to move this to inverse_transform in _base_cross_model.py.

# Singular vectors
comps1 = self.data["components1"].sel(mode=scores1.mode)
comps2 = self.data["components2"].sel(mode=scores2.mode)
13 changes: 9 additions & 4 deletions xeofs/preprocessing/transformer.py
@@ -96,14 +96,19 @@ def _serialize_data(self, key: str, data: Data) -> DataSet:
     ds = data
 else:
     # Convert DataArray to Dataset
+    coords = {}
+    data_vars = {}
     if data.name in data.coords:
         # Convert a coord-like DataArray to Dataset and note multiindexes
         if isinstance(data.to_index(), pd.MultiIndex):
             multiindexes[data.name] = [n for n in data.to_index().names]
-    # Make sure the DataArray has some name so we can create a string mapping
-    elif data.name is None:
-        data.name = key
-    ds = xr.Dataset(data_vars={data.name: data})
+        coords[data.name] = data
+    else:
+        # Make sure the DataArray has some name so we can create a string mapping
+        if data.name is None:
+            data.name = key
+        data_vars[data.name] = data
+    ds = xr.Dataset(data_vars=data_vars, coords=coords)
     name_map = data.name

# Drop multiindexes and record for later
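
The hunk above routes coordinate-like arrays through `coords` and everything else through `data_vars`. A minimal standalone sketch of just that branching follows. `to_dataset_sketch` is a hypothetical helper, not the xeofs method; it leaves out the multiindex bookkeeping and, unlike the diff, does not modify `data.name` in place.

```python
import numpy as np
import xarray as xr


def to_dataset_sketch(data: xr.DataArray, key: str) -> xr.Dataset:
    """Wrap a DataArray in a Dataset, keeping coordinate-like arrays as coords."""
    coords = {}
    data_vars = {}
    if data.name in data.coords:
        # The array is itself a coordinate (e.g. da["x"]), so store it as one.
        coords[data.name] = data
    else:
        # Unnamed arrays borrow the key so the variable has a usable name.
        name = key if data.name is None else data.name
        data_vars[name] = data
    return xr.Dataset(data_vars=data_vars, coords=coords)


temperature = xr.DataArray(
    np.zeros((2, 3)),
    dims=("time", "x"),
    coords={"time": [0, 1], "x": [10, 20, 30]},
)
ds_var = to_dataset_sketch(temperature, key="temperature")    # unnamed data variable
ds_coord = to_dataset_sketch(temperature["x"], key="unused")  # coordinate-like array
```

In this sketch the first call yields a Dataset with one data variable named after the key, while the second yields a Dataset that holds only the "x" coordinate and no data variables.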