[WIP] Add array-api support to metrics.confusion_matrix #28867

Draft
wants to merge 9 commits into main
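Not part of the diff below, just a hedged sketch of the usage this draft appears to target (inferred from the test changes): with `array_api_dispatch` enabled, `confusion_matrix` would accept inputs from an array API namespace such as PyTorch and return the matrix in that namespace. Assumes this draft is merged and PyTorch is installed.

```python
import torch

from sklearn import config_context
from sklearn.metrics import confusion_matrix

# Hypothetical usage: with array API dispatch enabled, confusion_matrix should
# accept torch tensors and return the confusion matrix in the same namespace
# and on the same device as the inputs.
y_true = torch.tensor([0, 1, 1, 0, 1])
y_pred = torch.tensor([0, 1, 0, 0, 1])

with config_context(array_api_dispatch=True):
    cm = confusion_matrix(y_true, y_pred)

print(cm)  # expected: a 2x2 torch tensor of counts
```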
81 changes: 48 additions & 33 deletions sklearn/metrics/_classification.py
@@ -40,8 +40,13 @@
)
from ..utils._array_api import (
_average,
_intersect1d,
_is_numpy_namespace,
_nan_to_num,
_union1d,
get_namespace,
get_namespace_and_device,
size,
)
from ..utils._param_validation import (
Hidden,
@@ -111,10 +116,10 @@ def _check_targets(y_true, y_pred):
y_type = y_type.pop()

# No metrics support "multiclass-multioutput" format
if y_type not in ["binary", "multiclass", "multilabel-indicator"]:
if y_type not in {"binary", "multiclass", "multilabel-indicator"}:
raise ValueError("{0} is not supported".format(y_type))

if y_type in ["binary", "multiclass"]:
if y_type in {"binary", "multiclass"}:
xp, _ = get_namespace(y_true, y_pred)
y_true = column_or_1d(y_true)
y_pred = column_or_1d(y_pred)
@@ -128,8 +133,8 @@ def _check_targets(y_true, y_pred):
# strings. So we raise a meaningful error
raise TypeError(
"Labels in y_true and y_pred should be of the same type. "
f"Got y_true={xp.unique(y_true)} and "
f"y_pred={xp.unique(y_pred)}. Make sure that the "
f"Got y_true={xp.unique_values(y_true)} and "
f"y_pred={xp.unique_values(y_pred)}. Make sure that the "
"predictions provided by the classifier coincides with "
Member

Suggested change
"predictions provided by the classifier coincides with "
"predictions provided by the classifier coincide with "

unrelated grammar fix (I think)

"the true labels."
) from e
@@ -322,70 +327,80 @@ def confusion_matrix(
(0, 2, 1, 1)
"""
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
xp, _, device = get_namespace_and_device(y_true, y_pred)
if y_type not in ("binary", "multiclass"):
raise ValueError("%s is not supported" % y_type)

if labels is None:
labels = unique_labels(y_true, y_pred)
else:
labels = np.asarray(labels)
n_labels = labels.size
n_labels = size(labels)
if n_labels == 0:
raise ValueError("'labels' should contains at least one label.")
Member

Suggested change
raise ValueError("'labels' should contains at least one label.")
raise ValueError("'labels' should contain at least one label.")

unrelated typo fix

elif y_true.size == 0:
return np.zeros((n_labels, n_labels), dtype=int)
elif len(np.intersect1d(y_true, labels)) == 0:

if size(y_true) == 0:
return xp.zeros((n_labels, n_labels), dtype=xp.int32, device=device)

if size(_intersect1d(y_true, labels, xp=xp)) == 0:
Member

What happens here if labels is a Python list of strings and y_true is, say, a torch array?

As a user I'd probably provide the labels as a Python list/tuple, mostly because it is convenient and not performance critical.

Is there a downside to being helpful to the callers and allowing list/tuple here?
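One possible answer, sketched here purely as an illustration: coerce list/tuple labels into the caller's namespace when they are numeric and fall back to NumPy for string labels. The helper name `_coerce_labels` and its behaviour are assumptions, not part of this diff.

```python
import numpy as np


def _coerce_labels(labels, xp, device=None):
    # Hypothetical helper (not in this PR): accept list/tuple labels and move
    # them into the caller's array namespace when possible.
    labels = np.asarray(labels)
    if labels.dtype.kind in {"U", "S", "O"}:
        # String/object labels cannot be represented in most array API
        # namespaces, so keep them as NumPy; the dict-based index conversion
        # further down would then map them to integer positions.
        return labels
    return xp.asarray(labels, device=device)
```

With something like this, a plain `[0, 1, 2]` would end up on the same namespace/device as `y_true`, while string labels would keep going through the existing `label_to_ind` conversion.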

raise ValueError("At least one label specified must be in y_true")

if sample_weight is None:
sample_weight = np.ones(y_true.shape[0], dtype=np.int64)
sample_weight = xp.ones(y_true.shape[0], dtype=xp.int64, device=device)
else:
sample_weight = np.asarray(sample_weight)
sample_weight = xp.asarray(sample_weight, device=device)

check_consistent_length(y_true, y_pred, sample_weight)

n_labels = labels.size
n_labels = size(labels)
# If labels are not consecutive integers starting from zero, then
# y_true and y_pred must be converted into index form
need_index_conversion = not (
labels.dtype.kind in {"i", "u", "b"}
and np.all(labels == np.arange(n_labels))
and y_true.min() >= 0
and y_pred.min() >= 0
xp.isdtype(labels.dtype, ("bool", "integral"))
and xp.all(labels == xp.arange(n_labels, dtype=labels.dtype, device=device))
and xp.min(y_true) >= 0
and xp.min(y_pred) >= 0
)
if need_index_conversion:
label_to_ind = {y: x for x, y in enumerate(labels)}
y_pred = np.array([label_to_ind.get(x, n_labels + 1) for x in y_pred])
y_true = np.array([label_to_ind.get(x, n_labels + 1) for x in y_true])
mapped_preds = [label_to_ind.get(x, n_labels + 1) for x in y_pred]
mapped_truth = [label_to_ind.get(x, n_labels + 1) for x in y_true]
y_pred = xp.asarray(mapped_preds, device=device)
y_true = xp.asarray(mapped_truth, device=device)

# intersect y_pred, y_true with labels, eliminate items not in labels
ind = np.logical_and(y_pred < n_labels, y_true < n_labels)
if not np.all(ind):
ind = xp.logical_and(y_pred < n_labels, y_true < n_labels)
if not xp.all(ind):
y_pred = y_pred[ind]
y_true = y_true[ind]
# also eliminate weights of eliminated items
sample_weight = sample_weight[ind]

# Choose the accumulator dtype to always have high precision
if sample_weight.dtype.kind in {"i", "u", "b"}:
dtype = np.int64
if xp.isdtype(sample_weight.dtype, ("bool", "integral")):
dtype = xp.int64
else:
dtype = np.float64

cm = coo_matrix(
(sample_weight, (y_true, y_pred)),
shape=(n_labels, n_labels),
dtype=dtype,
).toarray()
dtype = xp.float64

if _is_numpy_namespace(xp):
cm = coo_matrix(
(sample_weight, (y_true, y_pred)),
shape=(n_labels, n_labels),
dtype=dtype,
).toarray()
else:
# Harder and slower way.
cm = xp.zeros((n_labels, n_labels), dtype=dtype, device=device)
for r, c, weight in zip(y_true, y_pred, sample_weight):
cm[r, c] += weight

with np.errstate(all="ignore"):
if normalize == "true":
cm = cm / cm.sum(axis=1, keepdims=True)
cm = cm / xp.sum(cm, axis=1, keepdims=True)
elif normalize == "pred":
cm = cm / cm.sum(axis=0, keepdims=True)
cm = cm / xp.sum(cm, axis=0, keepdims=True)
elif normalize == "all":
cm = cm / cm.sum()
cm = np.nan_to_num(cm)
cm = cm / xp.sum(cm)
cm = _nan_to_num(cm, xp=xp)

if cm.shape == (1, 1):
warnings.warn(
35 changes: 27 additions & 8 deletions sklearn/metrics/tests/test_classification.py
@@ -9,7 +9,7 @@
from scipy.spatial.distance import hamming as sp_hamming
from scipy.stats import bernoulli

from sklearn import datasets, svm
from sklearn import config_context, datasets, svm
from sklearn.datasets import make_multilabel_classification
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import (
@@ -40,17 +40,25 @@
from sklearn.preprocessing import LabelBinarizer, label_binarize
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils._mocking import MockDataFrame
from sklearn.utils._testing import (
from sklearn.utils.extmath import _nanaverage
from sklearn.utils.fixes import CSC_CONTAINERS, CSR_CONTAINERS
from sklearn.utils.validation import check_random_state

from ...utils._array_api import (
_convert_to_numpy,
_is_numpy_namespace,
get_namespace_and_device,
yield_namespace_device_combinations,
)
from ...utils._testing import (
_array_api_for_tests,
assert_allclose,
assert_almost_equal,
assert_array_almost_equal,
assert_array_equal,
assert_no_warnings,
ignore_warnings,
)
from sklearn.utils.extmath import _nanaverage
from sklearn.utils.fixes import CSC_CONTAINERS, CSR_CONTAINERS
from sklearn.utils.validation import check_random_state

###############################################################################
# Utilities for testing
@@ -480,12 +488,21 @@ def test_precision_recall_f_unused_pos_label():
)


def test_confusion_matrix_binary():
@pytest.mark.parametrize("namespace, device", yield_namespace_device_combinations())
def test_confusion_matrix_binary(namespace, device):
xp = _array_api_for_tests(namespace, device)

# Test confusion matrix - binary classification case
y_true, y_pred, _ = make_prediction(binary=True)

y_true = xp.asarray(y_true, device=device)
y_pred = xp.asarray(y_pred, device=device)

def test(y_true, y_pred):
cm = confusion_matrix(y_true, y_pred)
with config_context(array_api_dispatch=True):
cm = confusion_matrix(y_true, y_pred)
_ = get_namespace_and_device(y_true, y_pred, cm)
cm = _convert_to_numpy(cm, xp=xp)
assert_array_equal(cm, [[22, 3], [8, 17]])

tp, fp, fn, tn = cm.flatten()
@@ -498,7 +515,9 @@ def test(y_true, y_pred):
assert_array_almost_equal(mcc, 0.57, decimal=2)

test(y_true, y_pred)
test([str(y) for y in y_true], [str(y) for y in y_pred])

if _is_numpy_namespace(xp):
test([str(y) for y in y_true], [str(y) for y in y_pred])


def test_multilabel_confusion_matrix_binary():