ENH multiclass/multinomial newton cholesky for LogisticRegression #28840

Open: wants to merge 21 commits into base: main

Changes from 14 commits

Commits (21)
d065683  ENH add full multinomial hessian matrix (lorentzenchr, Apr 14, 2024)
0b39369  TST improve test_loss_dtype of common losses (lorentzenchr, Apr 15, 2024)
2e44c91  ENH improve newton solver verbosity (lorentzenchr, Apr 15, 2024)
94dfa17  FIX linear loss gradient_hessian contiguity (lorentzenchr, Apr 15, 2024)
ee9b306  ENH extend newton-cholesky for multinomial loss (lorentzenchr, Apr 15, 2024)
e4eef4d  ENH multinomial newton-cholesky for LogisticRegression (lorentzenchr, Apr 15, 2024)
92adb03  DOC add whatsnew (lorentzenchr, Apr 15, 2024)
4062d8a  Merge branch 'main' into multiclass_newton_cholesky (lorentzenchr, Apr 15, 2024)
2963f32  TST better test e.g. for fit_intercept=True (lorentzenchr, Apr 15, 2024)
c8460db  DOC add PR number to whatsnew (lorentzenchr, Apr 15, 2024)
37c9f64  TST fix test_logistic_regression_solvers_multiclass (lorentzenchr, Apr 16, 2024)
666a638  ENH set coefficients of reference class to zero (lorentzenchr, Apr 18, 2024)
1de85b7  TST fix and extend test_multinomial_identifiability_on_iris (lorentzenchr, Apr 18, 2024)
604165e  TST improve test coverage in linear_loss (lorentzenchr, Apr 19, 2024)
1108d45  CLN comments from review suggestions (lorentzenchr, Apr 25, 2024)
5f79028  CLN better comment fot sandwich_dot (lorentzenchr, Apr 25, 2024)
bc60467  Merge branch 'main' into multiclass_newton_cholesky (lorentzenchr, May 8, 2024)
425cc09  TST 2 more filterwarnings for Deprecation (lorentzenchr, May 10, 2024)
1a8d4f9  trigger CI (lorentzenchr, May 10, 2024)
96d6d6b  Merge branch 'main' into multiclass_newton_cholesky (lorentzenchr, May 18, 2024)
88457e4  DOC move to 1.6 changelog (lorentzenchr, May 18, 2024)
6 changes: 6 additions & 0 deletions doc/whats_new/v1.5.rst
@@ -264,6 +264,12 @@ Changelog
:mod:`sklearn.linear_model`
...........................

- |Enhancement| The `solver="newton-cholesky"` in
:class:`linear_model.LogisticRegression` and
:class:`linear_model.LogisticRegressionCV` is extended to support the full
multinomial loss in a multiclass setting.
:pr:`28840` by :user:`Christian Lorentzen <lorentzenchr>`.

- |Enhancement| Solver `"newton-cg"` in :class:`linear_model.LogisticRegression` and
:class:`linear_model.LogisticRegressionCV` now emits information when `verbose` is
set to positive values.
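As a quick illustration of the changelog entry above, the following minimal sketch (not part of the diff) fits a multiclass problem with the extended solver, assuming a build that includes this branch; on released versions the same call instead falls back to a one-vs-rest fit.

# Minimal sketch: multiclass fit with solver="newton-cholesky".
# With this branch the solver minimizes the full multinomial loss directly.
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)
clf = LogisticRegression(solver="newton-cholesky").fit(X, y)
print(clf.coef_.shape)                       # (3, 4): one coefficient row per class
print(clf.predict_proba(X[:2]).sum(axis=1))  # rows of predicted probabilities sum to 1
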
18 changes: 13 additions & 5 deletions sklearn/_loss/tests/test_loss.py
@@ -419,52 +419,60 @@ def test_loss_dtype(
if sample_weight is not None:
sample_weight = create_memmap_backed_data(sample_weight)

loss.loss(
l = loss.loss(
y_true=y_true,
raw_prediction=raw_prediction,
sample_weight=sample_weight,
loss_out=out1,
n_threads=n_threads,
)
loss.gradient(
assert l is out1 if out1 is not None else True
g = loss.gradient(
y_true=y_true,
raw_prediction=raw_prediction,
sample_weight=sample_weight,
gradient_out=out2,
n_threads=n_threads,
)
loss.loss_gradient(
assert g is out2 if out2 is not None else True
l, g = loss.loss_gradient(
y_true=y_true,
raw_prediction=raw_prediction,
sample_weight=sample_weight,
loss_out=out1,
gradient_out=out2,
n_threads=n_threads,
)
assert l is out1 if out1 is not None else True
assert g is out2 if out2 is not None else True
if out1 is not None and loss.is_multiclass:
out1 = np.empty_like(raw_prediction, dtype=dtype_out)
loss.gradient_hessian(
g, h = loss.gradient_hessian(
y_true=y_true,
raw_prediction=raw_prediction,
sample_weight=sample_weight,
gradient_out=out1,
hessian_out=out2,
n_threads=n_threads,
)
assert g is out1 if out1 is not None else True
assert h is out2 if out2 is not None else True
loss(y_true=y_true, raw_prediction=raw_prediction, sample_weight=sample_weight)
loss.fit_intercept_only(y_true=y_true, sample_weight=sample_weight)
loss.constant_to_optimal_zero(y_true=y_true, sample_weight=sample_weight)
if hasattr(loss, "predict_proba"):
loss.predict_proba(raw_prediction=raw_prediction)
if hasattr(loss, "gradient_proba"):
loss.gradient_proba(
g, p = loss.gradient_proba(
y_true=y_true,
raw_prediction=raw_prediction,
sample_weight=sample_weight,
gradient_out=out1,
proba_out=out2,
n_threads=n_threads,
)
assert g is out1 if out1 is not None else True
assert p is out2 if out2 is not None else True


@pytest.mark.parametrize("loss", LOSS_INSTANCES, ids=loss_instance_name)
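A note on the pattern "assert g is out2 if out2 is not None else True" added above: the conditional expression makes the assertion trivially true when no output buffer was supplied, and otherwise checks that the method returned the very buffer it was given. A self-contained sketch of that check, using the private sklearn._loss API this test exercises (not part of the diff):

import numpy as np
from sklearn._loss.loss import HalfSquaredError

loss = HalfSquaredError()
y_true = np.array([0.0, 1.0, 2.0])
raw_prediction = np.array([0.5, 0.5, 0.5])
out = np.empty_like(raw_prediction)
g = loss.gradient(y_true=y_true, raw_prediction=raw_prediction, gradient_out=out)
# Per the assertions added in this PR's test: the provided buffer is written
# into and returned, not a freshly allocated copy.
assert g is out
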
116 changes: 104 additions & 12 deletions sklearn/linear_model/_glm/_newton_solver.py
@@ -298,6 +298,11 @@ def line_search(self, X, y, sample_weight):
return

self.raw_prediction = raw
if is_verbose:
print(
f" line search successful after {i+1} iterations with "
f"loss={self.loss_value}."
)

def check_convergence(self, X, y, sample_weight):
"""Check for convergence.
@@ -310,24 +315,27 @@ def check_convergence(self, X, y, sample_weight):
# convergence criterion because even a large step could have brought us close
# to the true minimum.
# coef_step = self.coef - self.coef_old
# check = np.max(np.abs(coef_step) / np.maximum(1, np.abs(self.coef_old)))
# change = np.max(np.abs(coef_step) / np.maximum(1, np.abs(self.coef_old)))
# check = change <= tol

# 1. Criterion: maximum |gradient| <= tol
# The gradient was already updated in line_search()
check = np.max(np.abs(self.gradient))
g_max_abs = np.max(np.abs(self.gradient))
check = g_max_abs <= self.tol
if self.verbose:
print(f" 1. max |gradient| {check} <= {self.tol}")
if check > self.tol:
print(f" 1. max |gradient| {g_max_abs} <= {self.tol} {check}")
if not check:
return

# 2. Criterion: For Newton decrement d, check 1/2 * d^2 <= tol
# d = sqrt(grad @ hessian^-1 @ grad)
# = sqrt(coef_newton @ hessian @ coef_newton)
# See Boyd, Vanderberghe (2009) "Convex Optimization" Chapter 9.5.1.
d2 = self.coef_newton @ self.hessian @ self.coef_newton
check = 0.5 * d2 <= self.tol
if self.verbose:
print(f" 2. Newton decrement {0.5 * d2} <= {self.tol}")
if 0.5 * d2 > self.tol:
print(f" 2. Newton decrement {0.5 * d2} <= {self.tol} {check}")
if not check:
return

if self.verbose:
@@ -442,11 +450,23 @@ class NewtonCholeskySolver(NewtonSolver):

def setup(self, X, y, sample_weight):
super().setup(X=X, y=y, sample_weight=sample_weight)
n_dof = X.shape[1]
if self.linear_loss.fit_intercept:
n_dof += 1
if self.linear_loss.base_loss.is_multiclass:
# Easier with ravelled arrays, e.g., for scipy.linalg.solve.
# As with LinearModelLoss, we are always contiguous in n_classes.
self.coef = self.coef.ravel(order="F")
# Note that the computation of gradient in LinearModelLoss follows the shape of
# coef.
self.gradient = np.empty_like(self.coef)
self.hessian = np.empty_like(self.coef, shape=(n_dof, n_dof))
# But the hessian is always 2d.
n = self.coef.size
self.hessian = np.empty_like(self.coef, shape=(n, n))
# To help case distinctions.
self.is_multinomial_with_intercept = (
self.linear_loss.base_loss.is_multiclass and self.linear_loss.fit_intercept
)
self.is_multinomial_no_penalty = (
self.linear_loss.base_loss.is_multiclass and self.l2_reg_strength == 0
)

def update_gradient_hessian(self, X, y, sample_weight):
_, _, self.hessian_warning = self.linear_loss.gradient_hessian(
@@ -479,12 +499,70 @@ def inner_solve(self, X, y, sample_weight):
self.use_fallback_lbfgs_solve = True
return

# Note: The following case distinction could also be shifted to the
# implementation of HalfMultinomialLoss instead of here within the solver.
if self.is_multinomial_no_penalty:
# The multinomial loss is overparametrized for each unpenalized feature, so
# at least the intercepts. This can be seen by noting that predicted
# probabilities are invariant under shifting all coefficients of a single
# feature j for all classes by the same amount c:
# coef[:, j] -> coef[:, j] + c => proba stays the same
# where we have assumed coef.shape = (n_classes, n_features).
# Therefore, the loss (negative log-likelihood), gradient and hessian also
# stay the same; see
# Noah Simon and Jerome Friedman and Trevor Hastie. (2013) "A Blockwise
# Descent Algorithm for Group-penalized Multiresponse and Multinomial
# Regression". https://doi.org/10.48550/arXiv.1311.6529
#
# We choose the standard approach and set all the coefficients of the last
# class to zero, for all features including the intercept.
n_classes = self.linear_loss.base_loss.n_classes
n_dof = self.coef.size // n_classes # degree of freedom per class
n = self.coef.size - n_dof # effective size
self.coef[n_classes - 1 :: n_classes] = 0
self.gradient[n_classes - 1 :: n_classes] = 0
self.hessian[n_classes - 1 :: n_classes, :] = 0
self.hessian[:, n_classes - 1 :: n_classes] = 0
# We also need the reduced variants of gradient and hessian where the
# entries set to zero are removed. For 2 features and 3 classes with
# arbitrary values, "x" means removed:
# gradient = [0, 1, x, 3, 4, x]
#
# hessian = [0, 1, x, 3, 4, x]
# [1, 7, x, 9, 10, x]
# [x, x, x, x, x, x]
# [3, 9, x, 21, 22, x]
# [4, 10, x, 22, 28, x]
# [x, x, x, x, x, x]
# The following slicing triggers copies of gradient and hessian.
gradient = self.gradient.reshape(-1, n_classes)[:, :-1].flatten()
hessian = self.hessian.reshape(n_dof, n_classes, n_dof, n_classes)[
:, :-1, :, :-1
].reshape(n, n)
elif self.is_multinomial_with_intercept:
# Here, only intercepts are unpenalized. We again choose the last class and
# set its intercept to zero.
self.coef[-1] = 0
self.gradient[-1] = 0
self.hessian[-1, :] = 0
self.hessian[:, -1] = 0
gradient, hessian = self.gradient[:-1], self.hessian[:-1, :-1]
else:
gradient, hessian = self.gradient, self.hessian

try:
with warnings.catch_warnings():
warnings.simplefilter("error", scipy.linalg.LinAlgWarning)
self.coef_newton = scipy.linalg.solve(
self.hessian, -self.gradient, check_finite=False, assume_a="sym"
hessian, -gradient, check_finite=False, assume_a="sym"
)
if self.is_multinomial_no_penalty:
self.coef_newton = np.c_[
self.coef_newton.reshape(n_dof, n_classes - 1), np.zeros(n_dof)
].reshape(-1)
assert self.coef_newton.flags.f_contiguous
elif self.is_multinomial_with_intercept:
self.coef_newton = np.r_[self.coef_newton, 0]
self.gradient_times_newton = self.gradient @ self.coef_newton
if self.gradient_times_newton > 0:
if self.verbose:
@@ -498,7 +576,7 @@ def inner_solve(self, X, y, sample_weight):
warnings.warn(
f"The inner solver of {self.__class__.__name__} stumbled upon a "
"singular or very ill-conditioned Hessian matrix at iteration "
f"#{self.iteration}. It will now resort to lbfgs instead.\n"
f"{self.iteration}. It will now resort to lbfgs instead.\n"
"Further options are to use another solver or to avoid such situation "
"in the first place. Possible remedies are removing collinear features"
" of X or increasing the penalization strengths.\n"
Expand All @@ -522,3 +600,17 @@ def inner_solve(self, X, y, sample_weight):
)
self.use_fallback_lbfgs_solve = True
return
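The reduced gradient and Hessian construction in the unpenalized multinomial branch above can be reproduced in isolation. The following self-contained NumPy sketch (not part of the PR, with arbitrary toy values) mirrors the 2-features/3-classes diagram in the comments; it relies on the order="F" ravel chosen in setup(), i.e. entries are contiguous in classes.

import numpy as np

# Toy sizes matching the ASCII diagram above: 2 degrees of freedom, 3 classes.
n_dof, n_classes = 2, 3
n = n_dof * (n_classes - 1)  # effective size after dropping the last class

# order="F" ravel of coef (and hence gradient): contiguous in classes,
# flat index = dof * n_classes + class.
gradient = np.arange(n_dof * n_classes, dtype=float)  # [0, 1, 2, 3, 4, 5]
hessian = np.outer(gradient, gradient)                # any symmetric matrix works

# Zero the entries of the last (reference) class, then slice them away.
gradient[n_classes - 1 :: n_classes] = 0
hessian[n_classes - 1 :: n_classes, :] = 0
hessian[:, n_classes - 1 :: n_classes] = 0
reduced_gradient = gradient.reshape(-1, n_classes)[:, :-1].flatten()
reduced_hessian = hessian.reshape(n_dof, n_classes, n_dof, n_classes)[
    :, :-1, :, :-1
].reshape(n, n)
print(reduced_gradient)        # [0. 1. 3. 4.]  -> the "x" entries are gone
print(reduced_hessian.shape)   # (4, 4)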

def finalize(self, X, y, sample_weight):
if self.is_multinomial_no_penalty:
# Our convention is usually the symmetric parametrization where
# sum(coef[classes, features], axis=0) = 0.
# We now convert to this convention. Note that it does not change
# the predicted probabilities.
n_classes = self.linear_loss.base_loss.n_classes
self.coef = self.coef.reshape(n_classes, -1, order="F")
self.coef -= np.mean(self.coef, axis=0)
elif self.is_multinomial_with_intercept:
# Only the intercept needs an update to the symmetric parametrization.
n_classes = self.linear_loss.base_loss.n_classes
self.coef[-n_classes:] -= np.mean(self.coef[-n_classes:])
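
To close, a small self-contained check (a sketch with random data, not part of the diff) of the invariance that finalize() relies on: subtracting the per-feature mean over classes from the coefficients, i.e. moving to the symmetric parametrization, leaves the softmax probabilities unchanged.

import numpy as np
from scipy.special import softmax

rng = np.random.default_rng(0)
n_classes, n_features, n_samples = 3, 4, 5
coef = rng.normal(size=(n_classes, n_features))  # arbitrary multinomial coefficients
X = rng.normal(size=(n_samples, n_features))

# Symmetric parametrization: per-feature mean over classes is zero.
coef_sym = coef - coef.mean(axis=0)

proba = softmax(X @ coef.T, axis=1)
proba_sym = softmax(X @ coef_sym.T, axis=1)
np.testing.assert_allclose(proba, proba_sym)  # predicted probabilities are unchanged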