Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT allow metadata to be transformed in a Pipeline #28901

Open
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

adrinjalali
Copy link
Member

Initial proposal: #28440 (comment)
xref: #28440 (comment)

This adds transform_input as a constructor argument to Pipeline, as:

    transform_input : list of str, default=None
        This enables transforming some input arguments to ``fit`` (other than ``X``)
        to be transformed by the steps of the pipeline up to the step which requires
        them. Requirement is defined via :ref:`metadata routing <metadata_routing>`.
        This can be used to pass a validation set through the pipeline for instance.

        See the example TBD for more details.

        You can only set this if metadata routing is enabled, which you
        can enable using ``sklearn.set_config(enable_metadata_routing=True)``.

It simply allows to transform metadata with fitted estimators up to the step which needs the metadata.

How does this look?

cc @lorentzenchr @ogrisel @amueller @betatim

Copy link

github-actions bot commented Apr 26, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 1622203. Link to the linter CI: here

@adrinjalali
Copy link
Member Author

So for simple cases where metadata is only used in fit and transform expects no metadata, we're fine. But things get a bit trickier when we start having a transform method of a step accepting the same metadata as the fit of that step.

Specifically, in this test:

@pytest.mark.usefixtures("enable_slep006")
@pytest.mark.parametrize("method", ["fit", "fit_transform"])
def test_transform_input_pipeline(method):
    """Test that with transform_input, data is correctly transformed for each step."""

    def get_transformer(registry, sample_weight, metadata):
        """Get a transformer with requests set."""
        return (
            ConsumingTransformer(registry=registry)
            .set_fit_request(sample_weight=sample_weight, metadata=metadata)
            .set_transform_request(sample_weight=sample_weight, metadata=metadata)
        )

    def get_pipeline():
        """Get a pipeline and corresponding registries.

        The pipeline has 4 steps, with different request values set to test different
        cases. One is aliased.
        """
        registry_1, registry_2, registry_3, registry_4 = (
            _Registry(),
            _Registry(),
            _Registry(),
            _Registry(),
        )
        pipe = make_pipeline(
            get_transformer(registry_1, sample_weight=True, metadata=True),
            get_transformer(registry_2, sample_weight=False, metadata=False),
            get_transformer(registry_3, sample_weight=True, metadata=True),
            get_transformer(registry_4, sample_weight="other_weights", metadata=True),
            transform_input=["sample_weight"],
        )
        return pipe, registry_1, registry_2, registry_3, registry_4

    def check_metadata(registry, methods, **metadata):
        """Check that the right metadata was recorded for the given methods."""
        assert registry
        for estimator in registry:
            for method in methods:
                check_recorded_metadata(
                    estimator,
                    method=method,
                    **metadata,
                )

    X = np.array([[1, 2], [3, 4]])
    y = np.array([0, 1])
    sample_weight = np.array([[1, 2]])
    other_weights = np.array([[30, 40]])
    metadata = np.array([[100, 200]])

    pipe, registry_1, registry_2, registry_3, registry_4 = get_pipeline()
    pipe.fit(
        X,
        y,
        sample_weight=sample_weight,
        other_weights=other_weights,
        metadata=metadata,
    )

    check_metadata(
        registry_1, ["fit", "transform"], sample_weight=sample_weight, metadata=metadata
    )
    check_metadata(registry_2, ["fit", "transform"])
    check_metadata(
        registry_3,
        ["fit", "transform"],
        sample_weight=sample_weight + 2,
        metadata=metadata,
    )
    check_metadata(
        registry_4,
        method.split("_"),  # ["fit", "transform"] if "fit_transform", ["fit"] otherwise
        sample_weight=other_weights + 3,
        metadata=metadata,
    )

Step 3 receives transformed data in its transform method during fit of the pipeline cause all metadata are transformed if they're in transform_input, but a second time when step3.transform is called, the metadata is not transformed (cause I haven't implemented it in pipeline.transform yet).

The question is, what should be the expected behavior?

Do we want transform_input to only transform when calling fit of sub estimators? That's a bit tricky cause all TransformerMixin estimators implement a fit_transform which accepts all metadata together, which means the metadata (if the same name) is either transformed or not transformed. (Wish we didn't have fit_transform in the first place, it's giving us so much headache)

@adrinjalali
Copy link
Member Author

Actually, in TransformerMixin we have:

        if _routing_enabled():
            transform_params = self.get_metadata_routing().consumes(
                method="transform", params=fit_params.keys()
            )
            if transform_params:
                warnings.warn(
                    (
                        f"This object ({self.__class__.__name__}) has a `transform`"
                        " method which consumes metadata, but `fit_transform` does not"
                        " forward metadata to `transform`. Please implement a custom"
                        " `fit_transform` method to forward metadata to `transform` as"
                        " well. Alternatively, you can explicitly do"
                        " `set_transform_request`and set all values to `False` to"
                        " disable metadata routed to `transform`, if that's an option."
                    ),
                    UserWarning,
                )

and we never send anything to .transform. So in Pipeline we can also assume things are only transformed for fit, as long as scikit-learn is concerned.

However, for third party transformers where they can have their own fit_transform and route parameters, then things can become tricky, as the example in the previous comment shows.

@adrinjalali
Copy link
Member Author

adrinjalali commented May 14, 2024

Another question is, do we want to have this syntactic sugar?

pipe = make_pipeline(
    StandardScaler(),
    HistGradientBoostingClassifier(..., early_stopping=True)
).fit(X, y, X_val, y_val)

The above code would:

  • early_stopping=True would change the default request values so that the user doesn't have to type .set_fit_request(X_val=True, y_val=True)
  • early_stopping=True sets something in the instance of the estimator which tells pipeline that X_val is of the same nature as X, and therefore should be transformed

It wouldn't change what we have now implemented in Pipeline in this PR, but would make it easier for the user. Not sure if it's too magical for us though.

For that to happen, HGBC need to have:

class HistGradientBoostingClassifier(...):
    ...

    def get_metadata_routing(self):
        routing = super().get_metadata_routing()
        if self.early_stopping:
            routing.fit.add(X_val=True, y_val=True)

    def __sklearn_get_transforming_data__(self):
        return ["X_val"]

And Pipeline would look for info in __sklearn_get_transforming_data__ if it exists.

cc @glemaitre

It goes towards the direction of having more default routing info as @ogrisel really likes. (ref #26179 )

Note that this could come later separately as an enhancement to this PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant