Skip to content

Commit

Permalink
API Standardize X as inverse_transform input parameter (#28756)
Browse files Browse the repository at this point in the history
Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
  • Loading branch information
wd60622 and jeremiedbb committed Apr 29, 2024
1 parent 19c068f commit 0bdc754
Show file tree
Hide file tree
Showing 12 changed files with 172 additions and 73 deletions.
11 changes: 11 additions & 0 deletions doc/whats_new/v1.5.rst
Expand Up @@ -56,6 +56,17 @@ Changed models
signs across all `PCA` solvers, including the new
`svd_solver="covariance_eigh"` option introduced in this release.

Changes impacting many modules
------------------------------

- |API| The name of the input of the `inverse_transform` method of estimators has been
standardized to `X`. As a consequence, `Xt` is deprecated and will be removed in
version 1.7 in the following estimators: :class:`cluster.FeatureAgglomeration`,
:class:`decomposition.MiniBatchNMF`, :class:`decomposition.NMF`,
:class:`model_selection.GridSearchCV`, :class:`model_selection.RandomizedSearchCV`,
:class:`pipeline.Pipeline` and :class:`preprocessing.KBinsDiscretizer`.
:pr:`28756` by :user:`Will Dean <wd60622>`.

Support for Array API
---------------------

Expand Down
37 changes: 12 additions & 25 deletions sklearn/cluster/_feature_agglomeration.py
Expand Up @@ -6,13 +6,13 @@
# Author: V. Michel, A. Gramfort
# License: BSD 3 clause

import warnings

import numpy as np
from scipy.sparse import issparse

from ..base import TransformerMixin
from ..utils import metadata_routing
from ..utils.deprecation import _deprecate_Xt_in_inverse_transform
from ..utils.validation import check_is_fitted

###############################################################################
Expand All @@ -25,9 +25,9 @@ class AgglomerationTransform(TransformerMixin):
"""

# This prevents ``set_split_inverse_transform`` to be generated for the
# non-standard ``Xred`` arg on ``inverse_transform``.
# TODO(1.5): remove when Xred is removed for inverse_transform.
__metadata_request__inverse_transform = {"Xred": metadata_routing.UNUSED}
# non-standard ``Xt`` arg on ``inverse_transform``.
# TODO(1.7): remove when Xt is removed for inverse_transform.
__metadata_request__inverse_transform = {"Xt": metadata_routing.UNUSED}

def transform(self, X):
"""
Expand Down Expand Up @@ -63,43 +63,30 @@ def transform(self, X):
nX = np.array(nX).T
return nX

def inverse_transform(self, Xt=None, Xred=None):
def inverse_transform(self, X=None, *, Xt=None):
"""
Inverse the transformation and return a vector of size `n_features`.
Parameters
----------
Xt : array-like of shape (n_samples, n_clusters) or (n_clusters,)
X : array-like of shape (n_samples, n_clusters) or (n_clusters,)
The values to be assigned to each cluster of samples.
Xred : deprecated
Use `Xt` instead.
Xt : array-like of shape (n_samples, n_clusters) or (n_clusters,)
The values to be assigned to each cluster of samples.
.. deprecated:: 1.3
.. deprecated:: 1.5
`Xt` was deprecated in 1.5 and will be removed in 1.7. Use `X` instead.
Returns
-------
X : ndarray of shape (n_samples, n_features) or (n_features,)
A vector of size `n_samples` with the values of `Xred` assigned to
each of the cluster of samples.
"""
if Xt is None and Xred is None:
raise TypeError("Missing required positional argument: Xt")

if Xred is not None and Xt is not None:
raise ValueError("Please provide only `Xt`, and not `Xred`.")

if Xred is not None:
warnings.warn(
(
"Input argument `Xred` was renamed to `Xt` in v1.3 and will be"
" removed in v1.5."
),
FutureWarning,
)
Xt = Xred
X = _deprecate_Xt_in_inverse_transform(X, Xt)

check_is_fitted(self)

unil, inverse = np.unique(self.labels_, return_inverse=True)
return Xt[..., inverse]
return X[..., inverse]
16 changes: 8 additions & 8 deletions sklearn/cluster/tests/test_feature_agglomeration.py
Expand Up @@ -59,23 +59,23 @@ def test_feature_agglomeration_feature_names_out():
)


# TODO(1.5): remove this test
def test_inverse_transform_Xred_deprecation():
# TODO(1.7): remove this test
def test_inverse_transform_Xt_deprecation():
X = np.array([0, 0, 1]).reshape(1, 3) # (n_samples, n_features)

est = FeatureAgglomeration(n_clusters=1, pooling_func=np.mean)
est.fit(X)
Xt = est.transform(X)
X = est.transform(X)

with pytest.raises(TypeError, match="Missing required positional argument"):
est.inverse_transform()

with pytest.raises(ValueError, match="Please provide only"):
est.inverse_transform(Xt=Xt, Xred=Xt)
with pytest.raises(TypeError, match="Cannot use both X and Xt. Use X only."):
est.inverse_transform(X=X, Xt=X)

with warnings.catch_warnings(record=True):
warnings.simplefilter("error")
est.inverse_transform(Xt)
est.inverse_transform(X)

with pytest.warns(FutureWarning, match="Input argument `Xred` was renamed to `Xt`"):
est.inverse_transform(Xred=Xt)
with pytest.warns(FutureWarning, match="Xt was renamed X in version 1.5"):
est.inverse_transform(Xt=X)
29 changes: 9 additions & 20 deletions sklearn/decomposition/_nmf.py
Expand Up @@ -32,6 +32,7 @@
StrOptions,
validate_params,
)
from ..utils.deprecation import _deprecate_Xt_in_inverse_transform
from ..utils.extmath import randomized_svd, safe_sparse_dot, squared_norm
from ..utils.validation import (
check_is_fitted,
Expand Down Expand Up @@ -1310,44 +1311,32 @@ def fit(self, X, y=None, **params):
self.fit_transform(X, **params)
return self

def inverse_transform(self, Xt=None, W=None):
def inverse_transform(self, X=None, *, Xt=None):
"""Transform data back to its original space.
.. versionadded:: 0.18
Parameters
----------
Xt : {ndarray, sparse matrix} of shape (n_samples, n_components)
X : {ndarray, sparse matrix} of shape (n_samples, n_components)
Transformed data matrix.
W : deprecated
Use `Xt` instead.
Xt : {ndarray, sparse matrix} of shape (n_samples, n_components)
Transformed data matrix.
.. deprecated:: 1.3
.. deprecated:: 1.5
`Xt` was deprecated in 1.5 and will be removed in 1.7. Use `X` instead.
Returns
-------
X : ndarray of shape (n_samples, n_features)
Returns a data matrix of the original shape.
"""
if Xt is None and W is None:
raise TypeError("Missing required positional argument: Xt")

if W is not None and Xt is not None:
raise ValueError("Please provide only `Xt`, and not `W`.")

if W is not None:
warnings.warn(
(
"Input argument `W` was renamed to `Xt` in v1.3 and will be removed"
" in v1.5."
),
FutureWarning,
)
Xt = W
X = _deprecate_Xt_in_inverse_transform(X, Xt)

check_is_fitted(self)
return Xt @ self.components_
return X @ self.components_

@property
def _n_features_out(self):
Expand Down
21 changes: 11 additions & 10 deletions sklearn/decomposition/tests/test_nmf.py
Expand Up @@ -933,30 +933,31 @@ def test_minibatch_nmf_verbose():
sys.stdout = old_stdout


# TODO(1.5): remove this test
def test_NMF_inverse_transform_W_deprecation():
rng = np.random.mtrand.RandomState(42)
# TODO(1.7): remove this test
@pytest.mark.parametrize("Estimator", [NMF, MiniBatchNMF])
def test_NMF_inverse_transform_Xt_deprecation(Estimator):
rng = np.random.RandomState(42)
A = np.abs(rng.randn(6, 5))
est = NMF(
est = Estimator(
n_components=3,
init="random",
random_state=0,
tol=1e-6,
)
Xt = est.fit_transform(A)
X = est.fit_transform(A)

with pytest.raises(TypeError, match="Missing required positional argument"):
est.inverse_transform()

with pytest.raises(ValueError, match="Please provide only"):
est.inverse_transform(Xt=Xt, W=Xt)
with pytest.raises(TypeError, match="Cannot use both X and Xt. Use X only"):
est.inverse_transform(X=X, Xt=X)

with warnings.catch_warnings(record=True):
warnings.simplefilter("error")
est.inverse_transform(Xt)
est.inverse_transform(X)

with pytest.warns(FutureWarning, match="Input argument `W` was renamed to `Xt`"):
est.inverse_transform(W=Xt)
with pytest.warns(FutureWarning, match="Xt was renamed X in version 1.5"):
est.inverse_transform(Xt=X)


@pytest.mark.parametrize("Estimator", [NMF, MiniBatchNMF])
Expand Down
13 changes: 11 additions & 2 deletions sklearn/model_selection/_search.py
Expand Up @@ -36,6 +36,7 @@
from ..utils._estimator_html_repr import _VisualBlock
from ..utils._param_validation import HasMethods, Interval, StrOptions
from ..utils._tags import _safe_tags
from ..utils.deprecation import _deprecate_Xt_in_inverse_transform
from ..utils.metadata_routing import (
MetadataRouter,
MethodMapping,
Expand Down Expand Up @@ -637,26 +638,34 @@ def transform(self, X):
return self.best_estimator_.transform(X)

@available_if(_estimator_has("inverse_transform"))
def inverse_transform(self, Xt):
def inverse_transform(self, X=None, Xt=None):
"""Call inverse_transform on the estimator with the best found params.
Only available if the underlying estimator implements
``inverse_transform`` and ``refit=True``.
Parameters
----------
X : indexable, length n_samples
Must fulfill the input assumptions of the
underlying estimator.
Xt : indexable, length n_samples
Must fulfill the input assumptions of the
underlying estimator.
.. deprecated:: 1.5
`Xt` was deprecated in 1.5 and will be removed in 1.7. Use `X` instead.
Returns
-------
X : {ndarray, sparse matrix} of shape (n_samples, n_features)
Result of the `inverse_transform` function for `Xt` based on the
estimator with the best found parameters.
"""
X = _deprecate_Xt_in_inverse_transform(X, Xt)
check_is_fitted(self)
return self.best_estimator_.inverse_transform(Xt)
return self.best_estimator_.inverse_transform(X)

@property
def n_features_in_(self):
Expand Down
23 changes: 23 additions & 0 deletions sklearn/model_selection/tests/test_search.py
Expand Up @@ -3,6 +3,7 @@
import pickle
import re
import sys
import warnings
from collections.abc import Iterable, Sized
from functools import partial
from io import StringIO
Expand Down Expand Up @@ -2553,6 +2554,28 @@ def test_search_html_repr():
assert "<pre>LogisticRegression()</pre>" in repr_html


# TODO(1.7): remove this test
@pytest.mark.parametrize("SearchCV", [GridSearchCV, RandomizedSearchCV])
def test_inverse_transform_Xt_deprecation(SearchCV):
clf = MockClassifier()
search = SearchCV(clf, {"foo_param": [1, 2, 3]}, cv=3, verbose=3)

X2 = search.fit(X, y).transform(X)

with pytest.raises(TypeError, match="Missing required positional argument"):
search.inverse_transform()

with pytest.raises(TypeError, match="Cannot use both X and Xt. Use X only"):
search.inverse_transform(X=X2, Xt=X2)

with warnings.catch_warnings(record=True):
warnings.simplefilter("error")
search.inverse_transform(X2)

with pytest.warns(FutureWarning, match="Xt was renamed X in version 1.5"):
search.inverse_transform(Xt=X2)


# Metadata Routing Tests
# ======================

Expand Down
20 changes: 15 additions & 5 deletions sklearn/pipeline.py
Expand Up @@ -29,6 +29,7 @@
)
from .utils._tags import _safe_tags
from .utils._user_interface import _print_elapsed_time
from .utils.deprecation import _deprecate_Xt_in_inverse_transform
from .utils.metadata_routing import (
MetadataRouter,
MethodMapping,
Expand Down Expand Up @@ -909,19 +910,28 @@ def _can_inverse_transform(self):
return all(hasattr(t, "inverse_transform") for _, _, t in self._iter())

@available_if(_can_inverse_transform)
def inverse_transform(self, Xt, **params):
def inverse_transform(self, X=None, *, Xt=None, **params):
"""Apply `inverse_transform` for each step in a reverse order.
All estimators in the pipeline must support `inverse_transform`.
Parameters
----------
X : array-like of shape (n_samples, n_transformed_features)
Data samples, where ``n_samples`` is the number of samples and
``n_features`` is the number of features. Must fulfill
input requirements of last step of pipeline's
``inverse_transform`` method.
Xt : array-like of shape (n_samples, n_transformed_features)
Data samples, where ``n_samples`` is the number of samples and
``n_features`` is the number of features. Must fulfill
input requirements of last step of pipeline's
``inverse_transform`` method.
.. deprecated:: 1.5
`Xt` was deprecated in 1.5 and will be removed in 1.7. Use `X` instead.
**params : dict of str -> object
Parameters requested and accepted by steps. Each step must have
requested certain metadata for these parameters to be forwarded to
Expand All @@ -940,15 +950,15 @@ def inverse_transform(self, Xt, **params):
"""
_raise_for_params(params, self, "inverse_transform")

X = _deprecate_Xt_in_inverse_transform(X, Xt)

# we don't have to branch here, since params is only non-empty if
# enable_metadata_routing=True.
routed_params = process_routing(self, "inverse_transform", **params)
reverse_iter = reversed(list(self._iter()))
for _, name, transform in reverse_iter:
Xt = transform.inverse_transform(
Xt, **routed_params[name].inverse_transform
)
return Xt
X = transform.inverse_transform(X, **routed_params[name].inverse_transform)
return X

@available_if(_final_estimator_has("score"))
def score(self, X, y=None, sample_weight=None, **params):
Expand Down

0 comments on commit 0bdc754

Please sign in to comment.