
ENH - check datafit + penalty compatibility with solver #137

Draft · wants to merge 53 commits into main

Conversation

@PABannier (Collaborator) commented Dec 10, 2022

A quick proof-of-concept of a function that checks whether the combination (solver, datafit, penalty) is supported. Currently we have edge cases where one can pass the ProxNewton solver with an L0_5 penalty without any error being raised.

Pros of this design: the validation rules are centralized, and validating a (solver, datafit, penalty) 3-tuple is a one-liner in _glm_fit.
Cons: we have to update the rules as we enhance the capabilities of the solvers.

All in all, I think it is very valuable to have more informative errors when fitting estimators (e.g., Ali Rahimi initially passed a (Quadratic, L2_3, ProxNewton) combination, which cannot be optimized at the time of writing).

Closes #101
Closes #90
Closes #109

@PABannier (Collaborator, Author) commented Dec 10, 2022

With this PR, the error messages are more informative:

In [1]: from skglm.estimators import GeneralizedLinearEstimator
   ...: from skglm.penalties import L0_5
   ...: from skglm.datafits import Quadratic, Logistic
   ...: from skglm.solvers import ProxNewton, AndersonCD
   ...: import numpy as np

In [2]: X = np.random.normal(0, 1, (30, 50))
   ...: y = np.random.normal(0, 1, (30,))

In [3]: clf = GeneralizedLinearEstimator(Quadratic(), L0_5(1.), ProxNewton())

In [4]: clf.fit(X, y)
---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
Input In [4], in <cell line: 1>()
----> 1 clf.fit(X, y)

File ~/Documents/skglm/skglm/estimators.py:241, in GeneralizedLinearEstimator.fit(self, X, y)
    238 self.datafit = self.datafit if self.datafit else Quadratic()
    239 self.solver = self.solver if self.solver else AndersonCD()
--> 241 return _glm_fit(X, y, self, self.datafit, self.penalty, self.solver)

File ~/Documents/skglm/skglm/estimators.py:29, in _glm_fit(X, y, model, datafit, penalty, solver)
     27 is_classif = isinstance(datafit, (Logistic, QuadraticSVC))
     28 fit_intercept = solver.fit_intercept
---> 29 validate_solver(solver, datafit, penalty)
     31 if is_classif:
     32     check_classification_targets(y)

File ~/Documents/skglm/skglm/utils/dispatcher.py:21, in validate_solver(solver, datafit, penalty)
      6 """Ensure the solver is suited for the `datafit` + `penalty` problem.
      7
      8 Parameters
   (...)
     17     Penalty.
     18 """
     19 if (isinstance(solver, ProxNewton)
     20     and not set(("raw_grad", "raw_hessian")) <= set(dir(datafit))):
---> 21     raise Exception(
     22         f"ProxNewton cannot optimize {datafit.__class__.__name__}, since `raw_grad`"
     23         " and `raw_hessian` are not implemented.")
     24 if ("ws_strategy" in dir(solver) and solver.ws_strategy == "subdiff"
     25     and isinstance(penalty, (L0_5, L2_3))):
     26     raise Exception(
     27         "ws_strategy=`subdiff` is not available for Lp penalties (p < 1). "
     28         "Set ws_strategy to `fixpoint`.")

Exception: ProxNewton cannot optimize Quadratic, since `raw_grad` and `raw_hessian` are not implemented.

@PABannier changed the title from "POC Add validation logic passing datafit, penalty and solver" to "POC Add validation logic when passing datafit, penalty and solver to _glm_fit" on Dec 10, 2022
@PABannier PABannier marked this pull request as draft December 11, 2022 19:02
@mathurinm (Collaborator)

Looks nice @PABannier, this will definitely improve UX!

From an API point of view, shouldn't this check be delegated to each solver? This way we don't have one big function, but Solver.validate(datafit, penalty), in the spirit of what @Badr-MOUFAD implemented here: https://github.com/scikit-learn-contrib/skglm/blob/main/skglm/experimental/pdcd_ws.py#L201

Such functions could also take care of the initialization (e.g. stepsize computation), which is done on a per-solver basis. WDYT?
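
To make the suggestion concrete, here is a minimal sketch of what such a delegated check could look like (the validate name, its signature, and the error text are illustrative, not the final API):

    class ProxNewton:
        """Sketch: a solver that owns its own compatibility rules."""

        def validate(self, datafit, penalty):
            # ProxNewton needs second-order information from the datafit.
            for attr in ("raw_grad", "raw_hessian"):
                if not hasattr(datafit, attr):
                    raise ValueError(
                        f"ProxNewton cannot optimize {datafit.__class__.__name__}"
                        f" since `{attr}` is not implemented.")
            # penalty-specific rules (e.g. ws_strategy vs L0_5/L2_3)
            # would live here as well

_glm_fit would then reduce to calling solver.validate(datafit, penalty) before solving, and the same hook could host per-solver initialization such as stepsize computation.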

@PABannier (Collaborator, Author)

@mathurinm Yes, I think it's cleaner; I'm currently refining the POC.

@mathurinm (Collaborator)

This would be a nice addition if we can ship it in the 0.3 release @Badr-MOUFAD, given that we have added a few datafits, penalties and solvers!

@mathurinm (Collaborator)

@Badr-MOUFAD the issue popped up in #188, do you have time to take this over? A simple check, at the beginning of each solver, that the datafit and penalty are supported (e.g. AndersonCD does not support the Gamma datafit).

@Badr-MOUFAD (Collaborator)

> @Badr-MOUFAD the issue popped up in #188, do you have time to take this over? A simple check, at the beginning of each solver, that the datafit and penalty are supported (e.g. AndersonCD does not support the Gamma datafit).

Sure, I will resume this PR.

@Badr-MOUFAD changed the title from "POC Add validation logic when passing datafit, penalty and solver to _glm_fit" to "ENH - check datafit + penalty compatibility with solver" on Oct 18, 2023
@mathurinm (Collaborator)

Requires #191 to be implemented first to allow for better checks.

@QB3 (Collaborator) left a comment

What a massive piece of work! Congrats @PABannier @Badr-MOUFAD!

My complaint is that I did not fully understand the need to call multiple functions in the _validate function in solvers/base.py.

skglm/penalties/non_separable.py (resolved)
X : array, shape (n_samples, n_features)
Training data.

y : array, shape (n_samples,)
@QB3 (Collaborator) commented Nov 23, 2023

custom_compatibility_check currently depends on the target y, but the target is never used in the check.
Should we remove y from this function, or do you see cases where it will be needed?

@Badr-MOUFAD (Collaborator)

Since the validation of the datafit/penalty depends on the data (for instance, when X is sparse we should check that the datafit implements the _sparse methods), IMO it is better to pass in both X and y.

For now, we can settle for X only, but that means adding y later if we need it, which would alter the API.
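
A minimal sketch of the kind of data-dependent check described above, assuming the skglm convention that sparse support comes through dedicated methods such as initialize_sparse (the helper name is hypothetical):

    from scipy.sparse import issparse

    def check_sparse_support(X, datafit):
        # When X is sparse, the datafit must implement its sparse code path.
        if issparse(X) and not hasattr(datafit, "initialize_sparse"):
            raise AttributeError(
                f"{datafit.__class__.__name__} does not implement "
                "`initialize_sparse`, so it cannot be fit on sparse data.")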

Collaborator

I agree with @Badr-MOUFAD here: even if not needed at the moment, it's not hard to see cases where this would happen, and an API change would be painful.


def _validate(self, X, y, datafit, penalty):
    # execute both: attributes checks and `custom_compatibility_check`
    self.custom_compatibility_check(X, y, datafit, penalty)
@QB3 (Collaborator) commented Nov 23, 2023

From a conceptual point of view, I am a bit confused.

The doc says custom_compatibility_check already checks the compatibility between the solver, the datafit and the penalty. Thus I do not understand the need to call multiple functions, check_obj_solver_attr and custom_compatibility_check. Does one check generic compatibility and the other check custom compatibility?

@Badr-MOUFAD (Collaborator)

> Does one check generic compatibility and the other check custom compatibility?

Yes, indeed. Generic checks verify that the datafit/penalty implement the required attributes listed in _datafit_required_attr/_penalty_required_attr. On the other hand, custom_compatibility_check performs other checks; for instance, in GramCD it checks that the datafit is an instance of Quadratic.

@mathurinm (Collaborator)

One source of confusion for me is the name: check_obj_solver_attr could be just check_attr; we pass solver just to get its name, and though useful, I don't think it is worth it.

skglm/solvers/fista.py (resolved)
skglm/experimental/pdcd_ws.py (outdated, resolved)
skglm/solvers/group_bcd.py (outdated, resolved)
skglm/solvers/group_prox_newton.py (outdated, resolved)
Badr-MOUFAD and others added 5 commits November 24, 2023 09:59

Attributes
----------
_datafit_required_attr : list
Collaborator
I think we can make these public @Badr-MOUFAD? Maybe I'm missing a specific reason.

Collaborator
Also, a typo: missing "that must BE" here and below.

@Badr-MOUFAD (Collaborator)

> I think we can make these public @Badr-MOUFAD? Maybe I'm missing a specific reason.

I think these two attributes should be read-only.

While there is a way to make attributes read-only, namely the property decorator, I believe it adds too much complexity to the code and hence doesn't serve our goal of keeping component implementation user-friendly.

I opted for the "start with an underscore" naming convention to signal to the user that these attributes are private and not to be messed with.
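
For reference, the property-based alternative weighed here would look roughly like this (class and attribute names are only for illustration):

    class BaseSolverSketch:
        _datafit_required_attr = ("gradient_scalar",)

        @property
        def datafit_required_attr(self):
            # read-only view: no setter is defined, so
            # `solver.datafit_required_attr = ...` raises AttributeError
            return self._datafit_required_attr

The underscore convention conveys the same intent with less machinery, at the cost of relying on convention rather than enforcement.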


missing_attrs = []
suffix = SPARSE_SUFFIX if support_sparse else ""

# if `attr` is a list check that at least one of them
Collaborator
why "at least one of them" ?

if issparse(X):
    raise ValueError("Sparse matrices are not yet supported in PDCD_WS solver.")
# jit compile classes
datafit = compiled_clone(datafit_)
Collaborator
I think this is done only for this solver, is there a particular reason?

@Badr-MOUFAD (Collaborator)

Yes, you are right.

I removed it. One thing is that it might have some side effects (for the user), as the compilation was done in _validate_init.


Notes
-----
For required attributes, if an attribute is given as a list of attributes
@mathurinm (Collaborator)
I had trouble understanding this; an example would help (in which case do we want to check that one of several attributes is present?).

@Badr-MOUFAD (Collaborator)

I 100% agree with you @mathurinm; I should have accompanied the docs with an example.

In the FISTA solver, this

    _datafit_required_attr = ("get_global_lipschitz", ("gradient", "gradient_scalar"))
    _penalty_required_attr = (("prox_1d", "prox_vec"),)

is interpreted as

  • datafit is required to have get_global_lipschitz and (gradient or gradient_scalar)
  • penalty is required to have prox_1d or prox_vec

This is the way I implemented the check_obj_solver_attr function: attributes wrapped in parentheses are combined with the "or" operator, and a comma acts as the "and" operator.
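
A standalone sketch of these semantics (not the PR's exact check_obj_solver_attr, which also takes the solver and handles a sparse suffix):

    def check_attrs(obj, required_attrs):
        """Check `obj` against an and/or attribute spec.

        Top-level entries are combined with "and"; an entry that is itself
        a tuple is satisfied when at least one of its members is present ("or").
        """
        missing = []
        for attr in required_attrs:
            if isinstance(attr, tuple):
                # "or" group: at least one alternative must be implemented
                if not any(hasattr(obj, a) for a in attr):
                    missing.append(" or ".join(attr))
            elif not hasattr(obj, attr):
                missing.append(attr)
        if missing:
            raise AttributeError(
                f"{obj.__class__.__name__} is missing required attributes: "
                + "; ".join(missing))

With the FISTA specs above, check_attrs(penalty, (("prox_1d", "prox_vec"),)) passes as soon as the penalty implements either prox.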

@@ -27,6 +28,9 @@ class FISTA(BaseSolver):
https://epubs.siam.org/doi/10.1137/080716542
"""

_datafit_required_attr = ("get_global_lipschitz", ("gradient", "gradient_scalar"))
_penalty_required_attr = (("prox_1d", "prox_vec"),)
Collaborator
Why does FISTA need prox_1d? It is not used in the code below.

@Badr-MOUFAD (Collaborator)

It is used in _prox_vec, which takes the penalty as an argument (cf. https://github.com/PABannier/skglm/blob/e687bc2ecaacfe920b9aaad3e33e1f0cbdbac683/skglm/solvers/fista.py#L83).

The algorithm works if the penalty has either prox_1d or prox_vec (for reference: #137 (comment)).
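
The fallback logic is roughly the following sketch (the helper name follows the linked code; the signature and body here are simplified, assuming skglm's prox_1d(value, stepsize, j) convention):

    import numpy as np

    def _prox_vec(penalty, z, step):
        # prefer the vectorized prox when the penalty provides one ...
        if hasattr(penalty, "prox_vec"):
            return penalty.prox_vec(z, step)
        # ... otherwise fall back to the coordinate-wise prox
        w = np.zeros_like(z)
        for j in range(z.shape[0]):
            w[j] = penalty.prox_1d(z[j], step, j)
        return w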

def custom_compatibility_check(self, X, y, datafit):
    if not isinstance(datafit, Quadratic):
        raise AttributeError(
            f"`GramCD` supports only `Quadratic` datafit, got {datafit}"
Collaborator

This is very clean.

@mathurinm (Collaborator)

Trying to revive this to release v0.4. Some comments for discussion @QB3 @Badr-MOUFAD

  • I feel strongly against having both solver.solve() and solver(). "There should be one-- and preferably only one --obvious way to do it."
    The way I see it, all solve methods should be renamed _solve, and BaseSolver should implement a solve that runs the checks and then calls _solve (see the sketch after this list). What do you think? This is non-breaking from the API point of view; it just requires people implementing their own solver to adapt (I have no such example in mind; we can implement a test that no solver overrides solve by checking that A.solve == BaseSolver.solve).
    BaseSolver.solve can take a run_checks argument, True by default, and disabled if one wants to be faster (any idea on the impact of the checks @Badr-MOUFAD?).
  • I feel like checking for the sparse support is heavy at the moment and could be done in another PR, but that may just be me :)
  • I have mixed feelings about the name check_obj_solver_attr; see my comment above: I'd go for check_attr and not pass solver, which is used only for its name.
  • custom_compatibility_check could be custom_checks (?)
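
A sketch of the template-method design proposed in the first bullet, with illustrative signatures (the real solve in skglm takes more state, e.g. an initial coefficient vector):

    class BaseSolver:
        def solve(self, X, y, datafit, penalty, run_checks=True):
            # single public entry point: validate, then defer to the
            # subclass implementation; run_checks=False skips validation
            if run_checks:
                self._validate(X, y, datafit, penalty)
            return self._solve(X, y, datafit, penalty)

        def _solve(self, X, y, datafit, penalty):
            raise NotImplementedError  # each concrete solver implements this

The suggested regression test then reduces to asserting, for every concrete solver A, that A.solve is BaseSolver.solve.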

@Badr-MOUFAD (Collaborator)

  • I'm +1 on having a _solve method and adding a run_checks argument to the solve method.
    I feel it gives us more freedom to standardize the behavior and makes the solve method less verbose.
    I don't think the checks have a big overhead, though I didn't measure it in practice.
  • The support of sparse data is covered in the check_obj_solver_attr function, so I don't think it hurts us to cover it in this PR.
  • I have no strong opinion about the names; I'm +1 on your proposed names @mathurinm.
