MultiLogisticRegression panics on normalized data. #334

Open
KNOXDEV opened this issue Mar 18, 2024 · 0 comments
KNOXDEV commented Mar 18, 2024

Firstly, let me say I'm very new to data science / ML, so my understanding and terminology may be wrong. Please bear with me, and thanks in advance.

I'm using a relatively small dataset (769 features over 6893 samples) and a small number of categories (8). All of my feature values are normalized to [0, 1], though most are zero. I'm using the default configuration:

```rust
let model = MultiLogisticRegression::default().fit(&dataset).unwrap();
```

This panics with:

```
thread 'main' panicked at src/train.rs:121:22:
called `Result::unwrap()` on an `Err` value: ArgMinError(Condition violated: "`MoreThuenteLineSearch`: Search direction must be a descent direction.")
```

I've observed that if I set alpha to a larger value, like 10, I get a result without panicking. I've also noticed that if I limit the number of iterations to a very small number, say 20, I also get a good result without panicking. Therefore, I suspect the culprit is overfitting / divergence during optimization (I'm uncertain of the proper terminology here).
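To illustrate why a larger alpha helps, here is a minimal, self-contained sketch (not linfa's actual implementation) of 1-D logistic regression on linearly separable data. Without L2 regularization the optimal weight is unbounded, so gradient descent keeps growing it with every extra iteration; with alpha > 0 the penalty term pins the weight to a finite optimum, which is consistent with both workarounds above (raising alpha, or capping iterations before the weights blow up).

```rust
fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

/// Gradient descent on the L2-regularized negative log-likelihood of a
/// single-weight logistic model. Returns the final weight.
fn fit(alpha: f64, iterations: usize) -> f64 {
    // Perfectly separable data: x < 0 => class 0, x > 0 => class 1.
    let data = [(-2.0, 0.0), (-1.0, 0.0), (1.0, 1.0), (2.0, 1.0)];
    let lr = 0.5;
    let mut w: f64 = 0.0;
    for _ in 0..iterations {
        let mut grad = 0.0;
        for &(x, y) in &data {
            grad += (sigmoid(w * x) - y) * x;
        }
        grad += alpha * w; // L2 penalty; zero when alpha = 0
        w -= lr * grad;
    }
    w
}

fn main() {
    // Unregularized: the weight keeps growing with the iteration budget.
    println!("alpha = 0: w = {:.2}", fit(0.0, 2000));
    // Regularized: the weight settles at a small finite value.
    println!("alpha = 1: w = {:.2}", fit(1.0, 2000));
}
```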

I will say that I was able to use smartcore's multinomial logistic regression routines with alpha = 0 without issue. Notably, smartcore uses a backtracking line search implementation instead of More-Thuente. I don't know whether that's relevant.

I think this situation is related to this comment on the original MultiLogisticRegression PR regarding divergence. If linfa can address it by using a different line search algorithm with different numerical requirements, great. If changing the default line search algorithm is undesirable, then at least letting users configure the algorithm would be greatly appreciated. Most of all, though, I would suggest emitting a significantly more helpful error message when this divergence happens. If linfa could catch the error returned by argmin and translate it into something along the lines of "your dataset diverged, please increase alpha or reduce the number of iterations", I imagine you would save a ton of developer troubleshooting hours.
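The error translation I have in mind could be as simple as the following hypothetical sketch; `translate` and the string match are my own invention, standing in for wherever linfa currently bubbles up the raw argmin error.

```rust
/// Hypothetical helper: map argmin's opaque line-search failure onto an
/// actionable message. The raw string here mirrors the panic shown above.
fn translate(fit_error: &str) -> String {
    if fit_error.contains("must be a descent direction") {
        "optimization diverged: try increasing `alpha` or reducing the \
         maximum number of iterations"
            .to_string()
    } else {
        // Any other error passes through unchanged.
        fit_error.to_string()
    }
}

fn main() {
    let raw = "`MoreThuenteLineSearch`: Search direction must be a descent direction.";
    println!("{}", translate(raw));
}
```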

Thanks again, and please let me know if there are any other details I should provide here.
