
Liblinear convergence failure everywhere? #11536

Open
amueller opened this issue Jul 15, 2018 · 70 comments · May be fixed by #13317

Comments

@amueller
Member

Recently we made liblinear report convergence failures.
Now this is reported in lots of places. I expect our users will start to see it everywhere. Should we change something? It's weird if most uses of LinearSVC result in a "failure" now.

@amueller
Member Author

Also, it's pretty weird that SVC has max_iter=-1 and LinearSVC has max_iter=1000. Is there a reason for that?

@agramfort
Member

Good question. Looking at the liblinear code, it appears we don't expose the different stopping criteria they have, and we added a max_iter parameter they don't seem to have.

I have no idea why it was set to 1000. Was there any benchmark done?

@amueller
Member Author

Not that I can remember...

@amueller amueller modified the milestones: 0.21, 0.20 Jul 16, 2018
@GaelVaroquaux
Member

No strong opinion. Too many warnings means that users don't look at warnings.

The only thing that I can suggest is adding an option to control this behavior. I don't really like adding options (the danger is having too many), but it seems that there is no one-size-fits-all choice here.

@amueller
Member Author

We could also increase the tol? @ogrisel asked whether we have this warning for logistic regression as well, or whether we ignore it there.

@ogrisel
Member

ogrisel commented Jul 16, 2018

Does this issue also happen with the LogisticRegression class?

@GaelVaroquaux
Member

I am -1 on increasing the tol: it will mean that many users will wait longer.

I think that there should be an option to control convergence warnings.
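Until such an option exists, the standard warnings machinery can already silence these; a minimal sketch (the dataset and max_iter here are chosen only to trigger the warning, not from the thread):

```python
import warnings

from sklearn.datasets import make_classification
from sklearn.exceptions import ConvergenceWarning
from sklearn.svm import LinearSVC

X, y = make_classification(n_samples=100, random_state=0)

# Silence only convergence warnings, not everything else.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=ConvergenceWarning)
    clf = LinearSVC(max_iter=5).fit(X, y)  # too few iterations on purpose
```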

@amueller
Member Author

Increasing the tol means a larger tol, so if anything people will wait less.

@GaelVaroquaux
Member

GaelVaroquaux commented Jul 17, 2018 via email

@samronsin
Contributor

Working on this.

@samronsin
Contributor

@ogrisel indeed LogisticRegression is also potentially affected.

@samronsin
Contributor

As discussed with @agramfort I am a bit skeptical regarding bumping the tol as there are a lot of different defaults around:

  • LogisticRegression: max_iter=100, tol=0.0001
  • LinearSV{C,R}: max_iter=1000, tol=0.0001
  • SV{C,R}: max_iter=-1, tol=0.001
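These defaults can be read off the estimators directly; a sketch that prints the values listed above:

```python
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC, LinearSVC

# Print each estimator's default max_iter and tol.
for est in (LogisticRegression(), LinearSVC(), SVC()):
    params = est.get_params()
    print(type(est).__name__, params["max_iter"], params["tol"])
```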

@amueller
Member Author

This is only about liblinear, so where the tol is 0.0001 right now. So it would be making it more consistent. We should probably run some benchmarks, though?

@lesteve lesteve added this to Issues tagged in scikit-learn 0.20 Jul 17, 2018
@samronsin
Contributor

Ah indeed, so maybe this is not as complex as I first thought.
Should the tol be 0.001 by default for all these liblinear calls?
I agree regarding benchmarks!

@amueller
Member Author

@samronsin yeah, that would be good I think. This seems to be one of the last release blockers?

@amueller
Member Author

BTW, the change that prompted all this is #10881, which was basically just a change in verbosity :-/

@amueller
Member Author

BTW, with the default solver, liblinear's own default tol is 0.1 (!). https://github.com/cjlin1/liblinear/blob/master/README

@jnothman
Member

jnothman commented Jul 22, 2018 via email

@amueller
Member Author

@jnothman this is mostly our wrapper that has the surprises, I think?

@jnothman
Member

jnothman commented Aug 17, 2018 via email

@NicolasHug
Member

The liblinear command line actually has various tolerance defaults, depending on the sub-solver that is used.

Do we want to use those? That would probably require switching the default to 'auto'.

@jnothman
Member

jnothman commented Aug 20, 2018 via email

@amueller
Member Author

@jnothman I think we need to benchmark but possibly?

@hermidalc
Contributor

hermidalc commented Nov 27, 2019

I don't think we warn about scaling anywhere?
Also, this is a linear model. It should really converge without scaling.

SVC(kernel='linear'), i.e. libsvm, will also fail to converge, and even worse will hang at 100% CPU (since max_iter=-1), for many datasets I have if you don't scale the data first. So I disagree here: if you have features with wildly different scales from the others, fitting the optimal hyperplane at a reasonable tolerance gets difficult.

@hermidalc
Contributor

Ok so this is pretty bad:

from sklearn.datasets import load_digits
from sklearn.svm import LinearSVC
digits = load_digits()
svm = LinearSVC(tol=1, max_iter=10000)
svm.fit(digits.data, digits.target)

If the data is not scaled, the dual solver (which is the default) will never converge on the digits dataset.

This can't really be solved with tol and max_iter, I think :(
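Whether liblinear hit its iteration cap can also be checked programmatically by catching the warning; a sketch on the unscaled digits data (dual=True and a small max_iter are set explicitly here, since the defaults vary across versions):

```python
import warnings

from sklearn.datasets import load_digits
from sklearn.exceptions import ConvergenceWarning
from sklearn.svm import LinearSVC

X, y = load_digits(return_X_y=True)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    LinearSVC(dual=True, tol=1e-4, max_iter=50).fit(X, y)

# On unscaled digits the dual solver hits the cap, so a warning is recorded.
hit_cap = any(issubclass(w.category, ConvergenceWarning) for w in caught)
print("convergence warning raised:", hit_cap)
```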

Everywhere in the sklearn docs you specifically warn users that they need to scale data before use with many classifiers, etc. If one sets tol and max_iter to the correct defaults for the liblinear L2-penalized dual solver, then digits converges:

from sklearn.datasets import load_digits
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC
digits = load_digits()
p = Pipeline([('s', StandardScaler()),
              ('c', LinearSVC(tol=1e-1, max_iter=1000))])
p.fit(digits.data, digits.target)

@smarie
Contributor

smarie commented Nov 27, 2019

@hermidalc just to be sure, are you running Windows or a Unix-like? Indeed there is a known issue with Windows (#13511), but it happens only when the number of features or samples is very large, so I guess this is not the issue you're facing.

@hermidalc
Contributor

@hermidalc just to be sure, are you running Windows or a Unix-like? Indeed there is a known issue with Windows (#13511), but it happens only when the number of features or samples is very large, so I guess this is not the issue you're facing.

Linux. The only issue I've faced is the LinearSVC convergence warnings, because sklearn's default tol=1e-4 is not the 1e-1 that liblinear itself uses for the default L2 dual solver. When you set tol=1e-1 and standardize your data first (which is a must for SVM and many other classifiers), these convergence issues go away.

@blacknred0

Don't want to add more to the pot... but is the convergence warning OS-specific? Should it behave differently on each OS? I assumed not, but based on my findings it seems to. I've tested on macOS 10.15.2 (Catalina) vs. Linux Fedora 30.

I ran the snippet from -> #11536 (comment) by @amueller, and as you can see below, on macOS the warning does not show, but on Linux it does (same code!!!). I'm not sure why. Is it because there might be different versions of liblinear on macOS than on Linux?

Tested in both major Python versions with old and recent libs, and the results were the same.

  • macos -> py2.7 with libs numpy==1.16.3 scikit-learn==0.20.3 scipy==1.2.1
  • fedora -> py2.7 with libs numpy==1.16.3 scikit-learn==0.20.3 scipy==1.2.1
  • fedora -> py3.7 with libs numpy==1.17.4 scikit-learn==0.22 scipy==1.3.3

mac result

python test/svc.py
LinearSVC(C=1.0, class_weight=None, dual=True, fit_intercept=True,
     intercept_scaling=1, loss='squared_hinge', max_iter=10000,
     multi_class='ovr', penalty='l2', random_state=None, tol=1, verbose=0)

fedora result

python /vagrant/test/svc.py
/home/vagrant/.local/lib/python2.7/site-packages/sklearn/svm/base.py:931: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
  "the number of iterations.", ConvergenceWarning)
LinearSVC(C=1.0, class_weight=None, dual=True, fit_intercept=True,
     intercept_scaling=1, loss='squared_hinge', max_iter=10000,
     multi_class='ovr', penalty='l2', random_state=None, tol=1, verbose=0)

Any thoughts?

@amueller
Member Author

amueller commented Dec 16, 2019

if you have features with wildly different scales to others fitting the optimal hyperplane at a reasonable tolerance will get difficult.

It depends a bit on what you mean by "difficult". You could probably do something like #15583 and solve the original optimization problem quite well. I'm not saying it's a good idea to not scale your data, I'm just saying it's totally possible to solve the optimization problem well despite the user giving you badly scaled data if your optimization algorithm is robust enough.

@hermidalc
Contributor

hermidalc commented Dec 16, 2019

if you have features with wildly different scales to others fitting the optimal hyperplane at a reasonable tolerance will get difficult.

It depends a bit on what you mean by "difficult".

Sorry, what I meant by "difficult" is relevant to this thread's topic: solving the optimization problem below a specific tolerance at or before a maximum number of iterations. Features that aren't scaled make this harder to do with SVM unless, as you said, you use a very robust algorithm to solve the optimization problem. I thought LIBLINEAR uses coordinate descent; isn't that pretty robust?

@agramfort
Member

agramfort commented Dec 17, 2019 via email

@amueller
Member Author

Liblinear has several solvers. I think they use their own TRON (trust-region Newton) by default.

@amueller
Member Author

amueller commented Dec 23, 2019

Also: we just changed our default away from liblinear...

The question of which kinds of problems are "hard" likely depends on the solver, I think, or on how you formulate the problem.

@smarie
Contributor

smarie commented Jan 16, 2020

Also: we just changed our default away from liblinear...

@amueller could you please point me to the corresponding issue/PR? I did not see that in the master codebase. Thanks!

@amueller
Member Author

@smarie
Contributor

smarie commented Jan 17, 2020

Ah ok I mistakenly thought that this was about SVC. Thanks!

@hermidalc
Contributor

if you have features with wildly different scales to others fitting the optimal hyperplane at a reasonable tolerance will get difficult.

It depends a bit on what you mean by "difficult". You could probably do something like #15583 and solve the original optimization problem quite well. I'm not saying it's a good idea to not scale your data, I'm just saying it's totally possible to solve the optimization problem well despite the user giving you badly scaled data if your optimization algorithm is robust enough.

To come back to this here’s some additional evidence challenging this belief when it comes to practical usage:

From the creators of LIBSVM and LIBLINEAR:
https://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf

Section 2.2 Scaling
Scaling before applying SVM is very important. Part 2 of Sarle's Neural Networks FAQ Sarle (1997) explains the importance of this and most of considerations also apply to SVM. The main advantage of scaling is to avoid attributes in greater numeric ranges dominating those in smaller numeric ranges. Another advantage is to avoid numerical difficulties during the calculation. Because kernel values usually depend on the inner products of feature vectors, e.g. the linear kernel and the polynomial kernel, large attribute values might cause numerical problems. We recommend linearly scaling each attribute to the range [−1, +1] or [0, 1].
Of course we have to use the same method to scale both training and testing data. For example, suppose that we scaled the first attribute of training data from [−10, +10] to [−1, +1]. If the first attribute of testing data lies in the range [−11, +8], we must scale the testing data to [−1.1, +0.8]. See Appendix B for some real examples.
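The guide's train/test example corresponds to fitting the scaler on the training data only and reusing it; a minimal sketch with MinMaxScaler:

```python
import numpy as np

from sklearn.preprocessing import MinMaxScaler

X_train = np.array([[-10.0], [10.0]])
X_test = np.array([[-11.0], [8.0]])

# Fit on training data only; apply the same affine map to the test data.
scaler = MinMaxScaler(feature_range=(-1, 1)).fit(X_train)
print(scaler.transform(X_train).ravel())  # [-1.  1.]
print(scaler.transform(X_test).ravel())   # [-1.1  0.8]
```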

@amueller
Member Author

@hermidalc I observed it to be a bit more stable than lbfgs in some settings I tried, see the preconditioning issue & PR.

I'm not entirely sure how we can make the user experience better here :-/ I've seen plenty of convergence issues even after scaling, but I haven't had the time to compose them.

@adrinjalali
Member

I'm trying to remove issues which have been around for more than 2 releases from the milestones. But this one seems to be pressing and you really care about it @amueller. Leaving it in the milestone for 0.24, but we really should be better at following up on these.

@adrinjalali adrinjalali modified the milestones: 0.23, 0.24 Apr 20, 2020
@hermidalc
Contributor

@hermidalc I observed it to be a bit more stable than lbfgs in some settings I tried, see the preconditioning issue & pr.

I'm not entirely sure how we can make the user experience better here :-/ I've seen plenty of convergence issues even after scaling, but I haven't had the time to compose them.

I have to say @amueller I do agree with you more now. With various high-dimensional datasets I've been working with these last few months, I've been seeing frequent convergence issues with LinearSVC after properly transforming and scaling the data beforehand, even after setting tol=1e-1 (which is what LIBLINEAR uses) and max_iter=10000 or greater. The optimization algorithm appears to have particular convergence trouble when performing model selection over a range of C at higher values like 1e2 or greater.

The exact same workflows with SVC(kernel='linear') generally do not have any convergence problems. While the scores from both models are usually somewhat similar, even with LinearSVC failing to converge, they aren't the same, and for some datasets they're really different. So for L2-penalized linear classification where I previously used LinearSVC, I'm now going back to SVC and SGDClassifier.

The problem is that only LinearSVC can solve penalty='l1' and dual=False problems for e.g. SelectFromModel feature selection, so it would be important for scikit-learn to fix the issue with the implementation. Possibly SGDClassifier with penalty='l1' can be used instead?
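A sketch of that suggested substitution (parameter values like alpha=1e-3 and the synthetic dataset are illustrative, not tuned; whether the selected features would match LinearSVC's needs benchmarking):

```python
from sklearn.datasets import make_classification
from sklearn.feature_selection import SelectFromModel
from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

X, y = make_classification(n_samples=200, n_features=50, n_informative=5,
                           random_state=0)

# L1-penalized SGD as the sparsity-inducing selector instead of
# LinearSVC(penalty='l1', dual=False).
selector = Pipeline([
    ("scale", StandardScaler()),
    ("select", SelectFromModel(
        SGDClassifier(penalty="l1", alpha=1e-3, random_state=0))),
])
X_reduced = selector.fit_transform(X, y)
print(X_reduced.shape)  # (200, n_selected) with n_selected <= 50
```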

Maybe the latest LIBLINEAR code has updates/fixes that have corrected what is the underlying problem? Looks like the main liblinear code in sklearn is from back in 2014.
