ENH - check datafit + penalty compatibility with solver #137
Conversation
With this PR, the errors are more verbose:
In [1]: from skglm.estimators import GeneralizedLinearEstimator
   ...: from skglm.penalties import L0_5
   ...: from skglm.datafits import Quadratic, Logistic
   ...: from skglm.solvers import ProxNewton, AndersonCD
   ...: import numpy as np
In [2]: X = np.random.normal(0, 1, (30, 50))
   ...: y = np.random.normal(0, 1, (30,))
In [3]: clf = GeneralizedLinearEstimator(Quadratic(), L0_5(1.), ProxNewton())
In [4]: clf.fit(X, y)
---------------------------------------------------------------------------
Exception Traceback (most recent call last)
Input In [4], in <cell line: 1>()
----> 1 clf.fit(X, y)
File ~/Documents/skglm/skglm/estimators.py:241, in GeneralizedLinearEstimator.fit(self, X, y)
238 self.datafit = self.datafit if self.datafit else Quadratic()
239 self.solver = self.solver if self.solver else AndersonCD()
--> 241 return _glm_fit(X, y, self, self.datafit, self.penalty, self.solver)
File ~/Documents/skglm/skglm/estimators.py:29, in _glm_fit(X, y, model, datafit, penalty, solver)
27 is_classif = isinstance(datafit, (Logistic, QuadraticSVC))
28 fit_intercept = solver.fit_intercept
---> 29 validate_solver(solver, datafit, penalty)
31 if is_classif:
32 check_classification_targets(y)
File ~/Documents/skglm/skglm/utils/dispatcher.py:21, in validate_solver(solver, datafit, penalty)
6 """Ensure the solver is suited for the `datafit` + `penalty` problem.
7
8 Parameters
(...)
17 Penalty.
18 """
19 if (isinstance(solver, ProxNewton)
20 and not set(("raw_grad", "raw_hessian")) <= set(dir(datafit))):
---> 21 raise Exception(
22 f"ProwNewton cannot optimize {datafit.__class__.__name__}, since `raw_grad`"
23 " and `raw_hessian` are not implemented.")
24 if ("ws_strategy" in dir(solver) and solver.ws_strategy == "subdiff"
25 and isinstance(penalty, (L0_5, L2_3))):
26 raise Exception(
27 "ws_strategy=`subdiff` is not available for Lp penalties (p < 1). "
28 "Set ws_strategy to `fixpoint`.")
Exception: ProwNewton cannot optimize Quadratic, since `raw_grad` and `raw_hessian` are not implemented.
Looks nice @PABannier, this will definitely improve UX! From an API point of view, shouldn't this check be delegated to each solver? This way we don't have one big function. Such functions could also take care of the initialization (e.g. stepsize computation), which is done on a per-solver basis. WDYT?
@mathurinm Yes, I think it's cleaner; currently refining the POC.
This would be a nice addition if we can ship it in the 0.3 release @Badr-MOUFAD, given that we added a few datafits, penalties and solvers!
@Badr-MOUFAD the issue popped up in #188, do you have time to take this over? A simple check, at the beginning of each solver, that the datafit and penalty are supported (e.g. AndersonCD does not support the Gamma datafit).
Sure, I will resume this PR.
Requires #191 to be implemented to allow for better checks.
What a massive piece of work! Congrats @PABannier @Badr-MOUFAD!
My complaint would be that I did not fully understand the need for calling multiple functions in the `_validate` function in solvers/base.py.
X : array, shape (n_samples, n_features)
    Training data.

y : array, shape (n_samples,)
`custom_compatibility_check` currently depends on the target `y`, but the target is never used in the check. Should we remove `y` from this function, or do you see cases where it will be needed?
Since the validation of datafit/penalty depends on the data, for instance when `X` is sparse we should check that the datafit implements `_sparse` methods, IMO it is better to pass in both `X, y`. For now, we can settle for `X` only, but that means adding `y` later if we need it, which would alter the API.
I agree with @Badr-MOUFAD here: even if not needed at the moment, it's not too hard to see cases where this would happen, and an API change would be painful.
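To make the motivation concrete, here is a minimal sketch of a data-dependent check along the lines Badr describes. The names `has_sparse_methods`, `validate_for_data`, and the `_sparse` suffix convention are illustrative assumptions based on this thread, not skglm's actual implementation:

```python
import numpy as np
from scipy import sparse


def has_sparse_methods(datafit, required=("gradient",)):
    # hypothetical helper: check the `_sparse` variants exist
    return all(hasattr(datafit, f"{name}_sparse") for name in required)


def validate_for_data(X, datafit):
    # only require the sparse code paths when X is actually sparse,
    # which is why the check needs X and not just the datafit object
    if sparse.issparse(X) and not has_sparse_methods(datafit):
        raise AttributeError(
            f"{type(datafit).__name__} does not implement the `_sparse` "
            "methods required for sparse X")


class DenseOnlyDatafit:
    """Toy datafit with only a dense gradient."""
    def gradient(self, X, y, w):
        return X.T @ (X @ w - y)


X_dense = np.eye(3)
X_sparse = sparse.csc_matrix(X_dense)
validate_for_data(X_dense, DenseOnlyDatafit())  # fine: X is dense
try:
    validate_for_data(X_sparse, DenseOnlyDatafit())
except AttributeError as e:
    print(e)
```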
def _validate(self, X, y, datafit, penalty):
    # execute both: attributes checks and `custom_compatibility_check`
    self.custom_compatibility_check(X, y, datafit, penalty)
From a conceptual point of view I am a bit confused. The doc says `custom_compatibility_check` already checks the compatibility between the solver, the datafit and the penalty. Thus I do not understand the need to call multiple functions (`check_obj_solver_attr` and `custom_compatibility_check`). Does one check generic compatibility and the other check custom compatibility?
Does one check generic compatibility and the other check custom compatibility?
Yes indeed. Generic checks verify that the datafit/penalty implement the required attributes listed in `_datafit_required_attr`/`_penalty_required_attr`. On the other hand, `custom_compatibility_check` performs other checks; for instance, in GramCD it checks that the datafit is an instance of `Quadratic`.
One of my confusions comes from the name: `check_obj_solver_attr` could be just `check_attr`; we pass `solver` just to get its name, and though useful, I don't think it is worth it.
Co-authored-by: Quentin Bertrand <quentin.bertrand@mila.quebec>
Attributes
----------
_datafit_required_attr : list
I think we can make these public @Badr-MOUFAD? Maybe I'm missing a specific reason.
Also a typo: missing "that must BE" here and below.
I think we can make these public @Badr-MOUFAD ? Maybe I'm missing a specific reason
I think these two attributes should be read-only.
While there is a way to make attributes read-only, namely using the `property` decorator, I believe it adds too much complexity to the code and hence doesn't serve our goal of making component implementation user-friendly.
I opted for the "start with underscore" naming convention to signal to the user that these are attributes not to mess with.
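For contrast, here is the `property` route mentioned above: it makes the attribute genuinely read-only, at the cost of boilerplate in every solver. The attribute name and values are taken from the FISTA example later in this thread, but this snippet is only an illustration of the trade-off:

```python
class FISTA:
    @property
    def datafit_required_attr(self):
        # no setter is defined, so assignment raises AttributeError
        return ("get_global_lipschitz", ("gradient", "gradient_scalar"))


solver = FISTA()
print(solver.datafit_required_attr)
try:
    solver.datafit_required_attr = ()  # read-only: this fails
except AttributeError as e:
    print(e)
```

The underscore convention chosen in the PR conveys the same "do not touch" intent with a single class attribute and no extra method per solver.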
missing_attrs = []
suffix = SPARSE_SUFFIX if support_sparse else ""

# if `attr` is a list check that at least one of them
Why "at least one of them"?
skglm/experimental/pdcd_ws.py (outdated)
if issparse(X):
    raise ValueError("Sparse matrices are not yet supported in PDCD_WS solver.")
# jit compile classes
datafit = compiled_clone(datafit_)
I think this is done only for this solver, is there a particular reason?
Yes, you are right. I removed it. One thing is that it might have some side effects (for the user), as the compilation was done in `_validate_init`.
Notes
-----
For required attributes, if an attribute is given as a list of attributes
I had trouble understanding this; an example would help (in which case do we want to check that one of several attributes is present?).
I 100% agree with you @mathurinm, I should have accompanied the docs with an example.
In the Fista solver, this
_datafit_required_attr = ("get_global_lipschitz", ("gradient", "gradient_scalar"))
_penalty_required_attr = (("prox_1d", "prox_vec"),)
is interpreted as:
- datafit is required to have `get_global_lipschitz` and (`gradient` or `gradient_scalar`)
- penalty is required to have `prox_1d` or `prox_vec`
This is the way I implemented the `check_obj_solver_attr` function: whenever attributes are wrapped in parentheses, it is interpreted as the "or" operator, and a comma is interpreted as the "and" operator.
@@ -27,6 +28,9 @@ class FISTA(BaseSolver):
     https://epubs.siam.org/doi/10.1137/080716542
     """

+    _datafit_required_attr = ("get_global_lipschitz", ("gradient", "gradient_scalar"))
+    _penalty_required_attr = (("prox_1d", "prox_vec"),)
Why does FISTA need `prox_1d`? It is not used in the code below.
It is used in `_prox_vec`, which takes the penalty as an argument (cf. https://github.com/PABannier/skglm/blob/e687bc2ecaacfe920b9aaad3e33e1f0cbdbac683/skglm/solvers/fista.py#L83). The algorithm works if the penalty has either of `prox_1d` or `prox_vec`. (For reference: #137 (comment))
def custom_compatibility_check(self, X, y, datafit):
    if not isinstance(datafit, Quadratic):
        raise AttributeError(
            f"`GramCD` supports only `Quadratic` datafit, got {datafit}"
this is very clean
Trying to revive this to release.
A quick proof-of-concept of a function that checks whether the combination (solver, datafit, penalty) is supported. Currently we have some edge cases where one can pass the ProxNewton solver with the L0_5 penalty without any error being raised.
Pros of this design: the validation rules are centralized, and validating a 3-tuple is a one-liner in glm_fit.
Cons: we have to update the rules as we enhance the capabilities of the solvers.
All in all, I think it is very valuable to have more verbose errors when fitting estimators (e.g. Ali Rahimi initially passed a combination Quadratic, L2_3, ProxNewton which cannot be optimized at the moment of writing).
Closes #101
Closes #90
Closes #109