API Implements get_feature_names_out for transformers that support get_feature_names (scikit-learn#18444)

Co-authored-by: Andreas Mueller <andreas.mueller@columbia.edu>
Co-authored-by: Andreas Mueller <andreasmuellerml@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Olivier Grisel <olivier.grisel@gmail.com>
Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com>
6 people committed Sep 7, 2021
1 parent d70cd15 commit 4d1e176
Showing 33 changed files with 1,446 additions and 160 deletions.
12 changes: 12 additions & 0 deletions doc/glossary.rst
@@ -894,6 +894,7 @@ Class APIs and Estimator Types
* :term:`fit`
* :term:`transform`
* :term:`get_feature_names`
* :term:`get_feature_names_out`

meta-estimator
meta-estimators
@@ -1262,6 +1263,17 @@ Methods
to the names of input columns from which output column names can
be generated. By default input features are named x0, x1, ....

``get_feature_names_out``
Primarily for :term:`feature extractors`, but also used for other
transformers to provide string names for each column in the output of
the estimator's :term:`transform` method. It outputs an array of
strings and may take an array-like of strings as input, corresponding
to the names of input columns from which output column names can
be generated. If `input_features` is not passed in, then the
`feature_names_in_` attribute will be used. If the
`feature_names_in_` attribute is not defined, then names are
generated: `[x0, x1, ..., x(n_features_in_)]`.

``get_n_splits``
On a :term:`CV splitter` (not an estimator), returns the number of
elements one would get if iterating through the return value of
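The fallback chain described in the new glossary entry is easy to see in practice. A minimal sketch (assuming scikit-learn >= 1.0, where `StandardScaler` picks up `get_feature_names_out` through this commit's mixin):

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

X = np.array([[0.0, 1.0], [2.0, 3.0]])
scaler = StandardScaler().fit(X)
print(scaler.get_feature_names_out())            # ['x0' 'x1'] -- generated names
print(scaler.get_feature_names_out(["a", "b"]))  # ['a' 'b']   -- caller-supplied names

# Fitting on a DataFrame records feature_names_in_, which becomes the default
df = pd.DataFrame(X, columns=["height", "weight"])
print(StandardScaler().fit(df).get_feature_names_out())  # ['height' 'weight']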
38 changes: 29 additions & 9 deletions doc/modules/compose.rst
@@ -139,6 +139,27 @@ or by name::
>>> pipe['reduce_dim']
PCA()

To enable model inspection, :class:`~sklearn.pipeline.Pipeline` has a
``get_feature_names_out()`` method, just like all transformers. You can use
pipeline slicing to get the feature names going into each step::

>>> from sklearn.datasets import load_iris
>>> from sklearn.feature_selection import SelectKBest
>>> iris = load_iris()
>>> pipe = Pipeline(steps=[
... ('select', SelectKBest(k=2)),
... ('clf', LogisticRegression())])
>>> pipe.fit(iris.data, iris.target)
Pipeline(steps=[('select', SelectKBest(...)), ('clf', LogisticRegression(...))])
>>> pipe[:-1].get_feature_names_out()
array(['x2', 'x3'], ...)

You can also provide custom feature names for the input data using
``get_feature_names_out``::

>>> pipe[:-1].get_feature_names_out(iris.feature_names)
array(['petal length (cm)', 'petal width (cm)'], ...)

.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_feature_selection_plot_feature_selection_pipeline.py`
@@ -426,21 +447,20 @@ By default, the remaining rating columns are ignored (``remainder='drop'``)::
>>> from sklearn.feature_extraction.text import CountVectorizer
>>> from sklearn.preprocessing import OneHotEncoder
>>> column_trans = ColumnTransformer(
... [('city_category', OneHotEncoder(dtype='int'),['city']),
... [('categories', OneHotEncoder(dtype='int'), ['city']),
... ('title_bow', CountVectorizer(), 'title')],
... remainder='drop')
... remainder='drop', prefix_feature_names_out=False)

>>> column_trans.fit(X)
ColumnTransformer(transformers=[('city_category', OneHotEncoder(dtype='int'),
ColumnTransformer(prefix_feature_names_out=False,
transformers=[('categories', OneHotEncoder(dtype='int'),
['city']),
('title_bow', CountVectorizer(), 'title')])

>>> column_trans.get_feature_names()
['city_category__x0_London', 'city_category__x0_Paris', 'city_category__x0_Sallisaw',
'title_bow__bow', 'title_bow__feast', 'title_bow__grapes', 'title_bow__his',
'title_bow__how', 'title_bow__last', 'title_bow__learned', 'title_bow__moveable',
'title_bow__of', 'title_bow__the', 'title_bow__trick', 'title_bow__watson',
'title_bow__wrath']
>>> column_trans.get_feature_names_out()
array(['city_London', 'city_Paris', 'city_Sallisaw', 'bow', 'feast',
'grapes', 'his', 'how', 'last', 'learned', 'moveable', 'of', 'the',
'trick', 'watson', 'wrath'], ...)

>>> column_trans.transform(X).toarray()
array([[1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
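For contrast with the ``prefix_feature_names_out=False`` example above, a short standalone sketch of the default behaviour (an assumption based on the flag this commit introduces: left at its default of ``True``, each name is prefixed with the transformer's name):

import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import OneHotEncoder

X = pd.DataFrame({
    'city': ['London', 'London', 'Paris', 'Sallisaw'],
    'title': ["His Last Bow", "How Watson Learned the Trick",
              "A Moveable Feast", "The Grapes of Wrath"]})
column_trans = ColumnTransformer(
    [('categories', OneHotEncoder(dtype='int'), ['city']),
     ('title_bow', CountVectorizer(), 'title')],
    remainder='drop')  # prefix_feature_names_out defaults to True
column_trans.fit(X)
print(column_trans.get_feature_names_out()[:4])
# expected along the lines of:
# ['categories__city_London' 'categories__city_Paris'
#  'categories__city_Sallisaw' 'title_bow__bow']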
44 changes: 20 additions & 24 deletions doc/modules/feature_extraction.rst
@@ -1,4 +1,4 @@
.. _feature_extraction:
.. _feature_extraction:

==================
Feature extraction
@@ -53,8 +53,8 @@ is a traditional numerical feature::
[ 0., 1., 0., 12.],
[ 0., 0., 1., 18.]])

>>> vec.get_feature_names()
['city=Dubai', 'city=London', 'city=San Francisco', 'temperature']
>>> vec.get_feature_names_out()
array(['city=Dubai', 'city=London', 'city=San Francisco', 'temperature'], ...)

:class:`DictVectorizer` accepts multiple string values for one
feature, like, e.g., multiple categories for a movie.
@@ -69,10 +69,9 @@ and its year of release.
array([[0.000e+00, 1.000e+00, 0.000e+00, 1.000e+00, 2.003e+03],
[1.000e+00, 0.000e+00, 1.000e+00, 0.000e+00, 2.011e+03],
[0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.974e+03]])
>>> vec.get_feature_names() == ['category=animation', 'category=drama',
... 'category=family', 'category=thriller',
... 'year']
True
>>> vec.get_feature_names_out()
array(['category=animation', 'category=drama', 'category=family',
'category=thriller', 'year'], ...)
>>> vec.transform({'category': ['thriller'],
... 'unseen_feature': '3'}).toarray()
array([[0., 0., 0., 1., 0.]])
@@ -111,8 +110,9 @@ suitable for feeding into a classifier (maybe after being piped into a
with 6 stored elements in Compressed Sparse ... format>
>>> pos_vectorized.toarray()
array([[1., 1., 1., 1., 1., 1.]])
>>> vec.get_feature_names()
['pos+1=PP', 'pos-1=NN', 'pos-2=DT', 'word+1=on', 'word-1=cat', 'word-2=the']
>>> vec.get_feature_names_out()
array(['pos+1=PP', 'pos-1=NN', 'pos-2=DT', 'word+1=on', 'word-1=cat',
'word-2=the'], ...)

As you can imagine, if one extracts such a context around each individual
word of a corpus of documents the resulting matrix will be very wide
@@ -340,10 +340,9 @@ Each term found by the analyzer during the fit is assigned a unique
integer index corresponding to a column in the resulting matrix. This
interpretation of the columns can be retrieved as follows::

>>> vectorizer.get_feature_names() == (
... ['and', 'document', 'first', 'is', 'one',
... 'second', 'the', 'third', 'this'])
True
>>> vectorizer.get_feature_names_out()
array(['and', 'document', 'first', 'is', 'one', 'second', 'the',
'third', 'this'], ...)

>>> X.toarray()
array([[0, 1, 1, 1, 0, 0, 1, 0, 1],
@@ -406,8 +405,8 @@ however, similar words are useful for prediction, such as in classifying
writing style or personality.

There are several known issues in our provided 'english' stop word list. It
does not aim to be a general, 'one-size-fits-all' solution as some tasks
may require a more custom solution. See [NQY18]_ for more details.
does not aim to be a general, 'one-size-fits-all' solution as some tasks
may require a more custom solution. See [NQY18]_ for more details.

Please take care in choosing a stop word list.
Popular stop word lists may include words that are highly informative to
@@ -742,9 +741,8 @@ decide better::

>>> ngram_vectorizer = CountVectorizer(analyzer='char_wb', ngram_range=(2, 2))
>>> counts = ngram_vectorizer.fit_transform(['words', 'wprds'])
>>> ngram_vectorizer.get_feature_names() == (
... [' w', 'ds', 'or', 'pr', 'rd', 's ', 'wo', 'wp'])
True
>>> ngram_vectorizer.get_feature_names_out()
array([' w', 'ds', 'or', 'pr', 'rd', 's ', 'wo', 'wp'], ...)
>>> counts.toarray().astype(int)
array([[1, 1, 1, 0, 1, 1, 1, 0],
[1, 1, 0, 1, 1, 1, 0, 1]])
@@ -758,17 +756,15 @@ span across words::
>>> ngram_vectorizer.fit_transform(['jumpy fox'])
<1x4 sparse matrix of type '<... 'numpy.int64'>'
with 4 stored elements in Compressed Sparse ... format>
>>> ngram_vectorizer.get_feature_names() == (
... [' fox ', ' jump', 'jumpy', 'umpy '])
True
>>> ngram_vectorizer.get_feature_names_out()
array([' fox ', ' jump', 'jumpy', 'umpy '], ...)

>>> ngram_vectorizer = CountVectorizer(analyzer='char', ngram_range=(5, 5))
>>> ngram_vectorizer.fit_transform(['jumpy fox'])
<1x5 sparse matrix of type '<... 'numpy.int64'>'
with 5 stored elements in Compressed Sparse ... format>
>>> ngram_vectorizer.get_feature_names() == (
... ['jumpy', 'mpy f', 'py fo', 'umpy ', 'y fox'])
True
>>> ngram_vectorizer.get_feature_names_out()
array(['jumpy', 'mpy f', 'py fo', 'umpy ', 'y fox'], ...)

The word boundaries-aware variant ``char_wb`` is especially interesting
for languages that use white-spaces for word separation as it generates
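One more note on the vectorizer examples above: the order of ``get_feature_names_out()`` mirrors the fitted ``vocabulary_`` mapping (term -> column index), which is why the names line up with the columns of the count matrix. A quick sketch:

from sklearn.feature_extraction.text import CountVectorizer

vectorizer = CountVectorizer()
vectorizer.fit(["this is the first document", "this is the second one"])
names = vectorizer.get_feature_names_out()
# names[idx] is exactly the term that vocabulary_ maps to column idx
assert all(names[idx] == term for term, idx in vectorizer.vocabulary_.items())
print(names)  # ['document' 'first' 'is' 'one' 'second' 'the' 'this']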
7 changes: 7 additions & 0 deletions doc/whats_new/v1.0.rst
@@ -134,6 +134,9 @@ Changelog
- |API| `np.matrix` usage is deprecated in 1.0 and will raise a `TypeError` in
1.2. :pr:`20165` by `Thomas Fan`_.

- |API| :term:`get_feature_names_out` has been added to the transformer API
to get the names of the output features. :pr:`18444` by `Thomas Fan`_.

- |API| All estimators store `feature_names_in_` when fitted on pandas Dataframes.
  These feature names are compared to names seen in non-`fit` methods,
  i.e. `transform`, and will raise a `FutureWarning` if they are not consistent.
@@ -225,6 +228,10 @@ Changelog
:mod:`sklearn.compose`
......................

- |API| Adds `prefix_feature_names_out` to :class:`compose.ColumnTransformer`.
  This flag controls whether feature names are prefixed with the transformer's
  name in :term:`get_feature_names_out`. :pr:`18444` by `Thomas Fan`_.

- |Enhancement| :class:`compose.ColumnTransformer` now records the output
of each transformer in `output_indices_`. :pr:`18393` by
:user:`Luca Bittarello <lbittarello>`.
6 changes: 3 additions & 3 deletions examples/applications/plot_topics_extraction_with_nmf_lda.py
@@ -103,7 +103,7 @@ def plot_top_words(model, feature_names, n_top_words, title):
print("done in %0.3fs." % (time() - t0))


tfidf_feature_names = tfidf_vectorizer.get_feature_names()
tfidf_feature_names = tfidf_vectorizer.get_feature_names_out()
plot_top_words(nmf, tfidf_feature_names, n_top_words,
'Topics in NMF model (Frobenius norm)')

@@ -117,7 +117,7 @@ def plot_top_words(model, feature_names, n_top_words, title):
l1_ratio=.5).fit(tfidf)
print("done in %0.3fs." % (time() - t0))

tfidf_feature_names = tfidf_vectorizer.get_feature_names()
tfidf_feature_names = tfidf_vectorizer.get_feature_names_out()
plot_top_words(nmf, tfidf_feature_names, n_top_words,
'Topics in NMF model (generalized Kullback-Leibler divergence)')

@@ -132,5 +132,5 @@ def plot_top_words(model, feature_names, n_top_words, title):
lda.fit(tf)
print("done in %0.3fs." % (time() - t0))

tf_feature_names = tf_vectorizer.get_feature_names()
tf_feature_names = tf_vectorizer.get_feature_names_out()
plot_top_words(lda, tf_feature_names, n_top_words, 'Topics in LDA model')
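To make the topic-extraction change concrete, here is a condensed, hypothetical version of what ``plot_top_words`` consumes: the ndarray returned by ``get_feature_names_out`` supports fancy indexing with the argsorted component weights.

from sklearn.decomposition import NMF
from sklearn.feature_extraction.text import TfidfVectorizer

docs = ["the cat sat", "the dog sat", "cats and dogs bark"]
tfidf_vectorizer = TfidfVectorizer()
tfidf = tfidf_vectorizer.fit_transform(docs)
nmf = NMF(n_components=2, init='nndsvd', max_iter=500).fit(tfidf)

tfidf_feature_names = tfidf_vectorizer.get_feature_names_out()
for topic_idx, topic in enumerate(nmf.components_):
    top = topic.argsort()[:-4:-1]  # indices of the three largest weights
    print(f"Topic #{topic_idx}:", tfidf_feature_names[top])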
2 changes: 1 addition & 1 deletion examples/bicluster/plot_bicluster_newsgroups.py
@@ -89,7 +89,7 @@ def build_tokenizer(self):
time() - start_time,
v_measure_score(y_kmeans, y_true)))

feature_names = vectorizer.get_feature_names()
feature_names = vectorizer.get_feature_names_out()
document_names = list(newsgroups.target_names[i] for i in newsgroups.target)


examples/inspection/plot_linear_model_coefficient_interpretation.py
@@ -133,7 +133,9 @@
numerical_columns = ["EDUCATION", "EXPERIENCE", "AGE"]

preprocessor = make_column_transformer(
(OneHotEncoder(drop="if_binary"), categorical_columns), remainder="passthrough"
(OneHotEncoder(drop="if_binary"), categorical_columns),
remainder="passthrough",
prefix_feature_names_out=False,
)

# %%
@@ -199,13 +201,7 @@
#
# First of all, we can take a look to the values of the coefficients of the
# regressor we have fitted.

feature_names = (
model.named_steps["columntransformer"]
.named_transformers_["onehotencoder"]
.get_feature_names(input_features=categorical_columns)
)
feature_names = np.concatenate([feature_names, numerical_columns])
feature_names = model[:-1].get_feature_names_out()

coefs = pd.DataFrame(
model.named_steps["transformedtargetregressor"].regressor_.coef_,
2 changes: 1 addition & 1 deletion examples/inspection/plot_permutation_importance.py
@@ -120,7 +120,7 @@
# capacity).
ohe = (rf.named_steps['preprocess']
.named_transformers_['cat'])
feature_names = ohe.get_feature_names(input_features=categorical_columns)
feature_names = ohe.get_feature_names_out(categorical_columns)
feature_names = np.r_[feature_names, numerical_columns]

tree_feature_importances = (
10 changes: 3 additions & 7 deletions examples/text/plot_document_classification_20newsgroups.py
@@ -174,7 +174,7 @@ def size_mb(docs):
if opts.use_hashing:
feature_names = None
else:
feature_names = vectorizer.get_feature_names()
feature_names = vectorizer.get_feature_names_out()

if opts.select_chi2:
print("Extracting %d best features by a chi-squared test" %
@@ -183,16 +183,12 @@ def size_mb(docs):
ch2 = SelectKBest(chi2, k=opts.select_chi2)
X_train = ch2.fit_transform(X_train, y_train)
X_test = ch2.transform(X_test)
if feature_names:
if feature_names is not None:
# keep selected feature names
feature_names = [feature_names[i] for i
in ch2.get_support(indices=True)]
feature_names = feature_names[ch2.get_support()]
print("done in %fs" % (time() - t0))
print()

if feature_names:
feature_names = np.asarray(feature_names)


def trim(s):
"""Trim string to fit on terminal (assuming 80-column display)"""
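The `feature_names[ch2.get_support()]` simplification above works because `get_feature_names_out` returns a NumPy array, which accepts the boolean support mask directly (the old list needed the index-based comprehension). A small sketch of the pattern:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_selection import SelectKBest, chi2

docs = ["good movie", "bad movie", "good plot", "bad plot"]
y = [1, 0, 1, 0]
vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(docs)

ch2 = SelectKBest(chi2, k=2).fit(X, y)
feature_names = vectorizer.get_feature_names_out()
print(feature_names[ch2.get_support()])  # boolean-mask indexing, e.g. ['bad' 'good']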
2 changes: 1 addition & 1 deletion examples/text/plot_document_clustering.py
@@ -217,7 +217,7 @@ def is_interactive():
else:
order_centroids = km.cluster_centers_.argsort()[:, ::-1]

terms = vectorizer.get_feature_names()
terms = vectorizer.get_feature_names_out()
for i in range(true_k):
print("Cluster %d:" % i, end='')
for ind in order_centroids[i, :10]:
2 changes: 1 addition & 1 deletion examples/text/plot_hashing_vs_dict_vectorizer.py
@@ -89,7 +89,7 @@ def token_freqs(doc):
vectorizer.fit_transform(token_freqs(d) for d in raw_data)
duration = time() - t0
print("done in %fs at %0.3fMB/s" % (duration, data_size_mb / duration))
print("Found %d unique terms" % len(vectorizer.get_feature_names()))
print("Found %d unique terms" % len(vectorizer.get_feature_names_out()))
print()

print("FeatureHasher on frequency dicts")
30 changes: 30 additions & 0 deletions sklearn/base.py
@@ -23,6 +23,7 @@
from .utils.validation import check_array
from .utils.validation import _check_y
from .utils.validation import _num_features
from .utils.validation import _check_feature_names_in
from .utils._estimator_html_repr import estimator_html_repr
from .utils.validation import _get_feature_names

@@ -846,6 +847,35 @@ def fit_transform(self, X, y=None, **fit_params):
return self.fit(X, y, **fit_params).transform(X)


class _OneToOneFeatureMixin:
    """Provides `get_feature_names_out` for simple transformers.

    Assumes there's a 1-to-1 correspondence between input features
    and output features.
    """

    def get_feature_names_out(self, input_features=None):
        """Get output feature names for transformation.

        Parameters
        ----------
        input_features : array-like of str or None, default=None
            Input features.

            - If `input_features` is `None`, then `feature_names_in_` is
              used as feature names in. If `feature_names_in_` is not defined,
              then names are generated: `[x0, x1, ..., x(n_features_in_)]`.
            - If `input_features` is an array-like, then `input_features` must
              match `feature_names_in_` if `feature_names_in_` is defined.

        Returns
        -------
        feature_names_out : ndarray of str objects
            Same as input features.
        """
        return _check_feature_names_in(self, input_features)


class DensityMixin:
"""Mixin class for all density estimators in scikit-learn."""

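Finally, a sketch of how a simple transformer might adopt the new mixin. Note that `_OneToOneFeatureMixin` is a private helper and the class below is hypothetical:

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, _OneToOneFeatureMixin
from sklearn.utils.validation import check_is_fitted


class ClipTransformer(_OneToOneFeatureMixin, TransformerMixin, BaseEstimator):
    """Hypothetical transformer: clips every column to [lo, hi] (one output per input)."""

    def __init__(self, lo=0.0, hi=1.0):
        self.lo = lo
        self.hi = hi

    def fit(self, X, y=None):
        # _validate_data sets n_features_in_ (and feature_names_in_ for
        # DataFrames), which the mixin's get_feature_names_out relies on.
        self._validate_data(X)
        return self

    def transform(self, X):
        check_is_fitted(self)
        return np.clip(self._validate_data(X, reset=False), self.lo, self.hi)


clipper = ClipTransformer().fit(np.arange(6.0).reshape(3, 2))
print(clipper.get_feature_names_out())  # ['x0' 'x1'] -- generated, no names seen in fit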
