-
-
Notifications
You must be signed in to change notification settings - Fork 25k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Add array-api support to metrics.confusion_matrix #28867
base: main
Are you sure you want to change the base?
Changes from all commits
9166240
85163a5
ae9d207
5cfb85f
e258ba1
864ffea
aa58127
f5565f3
2ea2722
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -40,8 +40,13 @@ | |||||
) | ||||||
from ..utils._array_api import ( | ||||||
_average, | ||||||
_intersect1d, | ||||||
_is_numpy_namespace, | ||||||
_nan_to_num, | ||||||
_union1d, | ||||||
get_namespace, | ||||||
get_namespace_and_device, | ||||||
size, | ||||||
) | ||||||
from ..utils._param_validation import ( | ||||||
Hidden, | ||||||
|
@@ -111,10 +116,10 @@ def _check_targets(y_true, y_pred): | |||||
y_type = y_type.pop() | ||||||
|
||||||
# No metrics support "multiclass-multioutput" format | ||||||
if y_type not in ["binary", "multiclass", "multilabel-indicator"]: | ||||||
if y_type not in {"binary", "multiclass", "multilabel-indicator"}: | ||||||
raise ValueError("{0} is not supported".format(y_type)) | ||||||
|
||||||
if y_type in ["binary", "multiclass"]: | ||||||
if y_type in {"binary", "multiclass"}: | ||||||
xp, _ = get_namespace(y_true, y_pred) | ||||||
y_true = column_or_1d(y_true) | ||||||
y_pred = column_or_1d(y_pred) | ||||||
|
@@ -128,8 +133,8 @@ def _check_targets(y_true, y_pred): | |||||
# strings. So we raise a meaningful error | ||||||
raise TypeError( | ||||||
"Labels in y_true and y_pred should be of the same type. " | ||||||
f"Got y_true={xp.unique(y_true)} and " | ||||||
f"y_pred={xp.unique(y_pred)}. Make sure that the " | ||||||
f"Got y_true={xp.unique_values(y_true)} and " | ||||||
f"y_pred={xp.unique_values(y_pred)}. Make sure that the " | ||||||
"predictions provided by the classifier coincides with " | ||||||
"the true labels." | ||||||
) from e | ||||||
|
@@ -322,70 +327,80 @@ def confusion_matrix( | |||||
(0, 2, 1, 1) | ||||||
""" | ||||||
y_type, y_true, y_pred = _check_targets(y_true, y_pred) | ||||||
xp, _, device = get_namespace_and_device(y_true, y_pred) | ||||||
if y_type not in ("binary", "multiclass"): | ||||||
raise ValueError("%s is not supported" % y_type) | ||||||
|
||||||
if labels is None: | ||||||
labels = unique_labels(y_true, y_pred) | ||||||
else: | ||||||
labels = np.asarray(labels) | ||||||
n_labels = labels.size | ||||||
n_labels = size(labels) | ||||||
if n_labels == 0: | ||||||
raise ValueError("'labels' should contains at least one label.") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
unrelated typo fix |
||||||
elif y_true.size == 0: | ||||||
return np.zeros((n_labels, n_labels), dtype=int) | ||||||
elif len(np.intersect1d(y_true, labels)) == 0: | ||||||
|
||||||
if size(y_true) == 0: | ||||||
return xp.zeros((n_labels, n_labels), dtype=xp.int32, device=device) | ||||||
|
||||||
if size(_intersect1d(y_true, labels, xp=xp)) == 0: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What happens here if `labels` is a Python list or tuple? As a user I'd probably provide the labels as a Python list/tuple, mostly because it is convenient and not performance critical. Is there a downside to being helpful to the callers and allowing list/tuple here? |
||||||
raise ValueError("At least one label specified must be in y_true") | ||||||
|
||||||
if sample_weight is None: | ||||||
sample_weight = np.ones(y_true.shape[0], dtype=np.int64) | ||||||
sample_weight = xp.ones(y_true.shape[0], dtype=xp.int64, device=device) | ||||||
else: | ||||||
sample_weight = np.asarray(sample_weight) | ||||||
sample_weight = xp.asarray(sample_weight, device=device) | ||||||
|
||||||
check_consistent_length(y_true, y_pred, sample_weight) | ||||||
|
||||||
n_labels = labels.size | ||||||
n_labels = size(labels) | ||||||
# If labels are not consecutive integers starting from zero, then | ||||||
# y_true and y_pred must be converted into index form | ||||||
need_index_conversion = not ( | ||||||
labels.dtype.kind in {"i", "u", "b"} | ||||||
and np.all(labels == np.arange(n_labels)) | ||||||
and y_true.min() >= 0 | ||||||
and y_pred.min() >= 0 | ||||||
xp.isdtype(labels.dtype, ("bool", "integral")) | ||||||
and xp.all(labels == xp.arange(n_labels, dtype=labels.dtype, device=device)) | ||||||
and xp.min(y_true) >= 0 | ||||||
and xp.min(y_pred) >= 0 | ||||||
) | ||||||
if need_index_conversion: | ||||||
label_to_ind = {y: x for x, y in enumerate(labels)} | ||||||
y_pred = np.array([label_to_ind.get(x, n_labels + 1) for x in y_pred]) | ||||||
y_true = np.array([label_to_ind.get(x, n_labels + 1) for x in y_true]) | ||||||
mapped_preds = [label_to_ind.get(x, n_labels + 1) for x in y_pred] | ||||||
mapped_truth = [label_to_ind.get(x, n_labels + 1) for x in y_true] | ||||||
y_pred = xp.asarray(mapped_preds, device=device) | ||||||
y_true = xp.asarray(mapped_truth, device=device) | ||||||
|
||||||
# intersect y_pred, y_true with labels, eliminate items not in labels | ||||||
ind = np.logical_and(y_pred < n_labels, y_true < n_labels) | ||||||
if not np.all(ind): | ||||||
ind = xp.logical_and(y_pred < n_labels, y_true < n_labels) | ||||||
if not xp.all(ind): | ||||||
y_pred = y_pred[ind] | ||||||
y_true = y_true[ind] | ||||||
# also eliminate weights of eliminated items | ||||||
sample_weight = sample_weight[ind] | ||||||
|
||||||
# Choose the accumulator dtype to always have high precision | ||||||
if sample_weight.dtype.kind in {"i", "u", "b"}: | ||||||
dtype = np.int64 | ||||||
if xp.isdtype(sample_weight.dtype, ("bool", "integral")): | ||||||
dtype = xp.int64 | ||||||
else: | ||||||
dtype = np.float64 | ||||||
|
||||||
cm = coo_matrix( | ||||||
(sample_weight, (y_true, y_pred)), | ||||||
shape=(n_labels, n_labels), | ||||||
dtype=dtype, | ||||||
).toarray() | ||||||
dtype = xp.float64 | ||||||
|
||||||
if _is_numpy_namespace(xp): | ||||||
cm = coo_matrix( | ||||||
(sample_weight, (y_true, y_pred)), | ||||||
shape=(n_labels, n_labels), | ||||||
dtype=dtype, | ||||||
).toarray() | ||||||
else: | ||||||
# Harder and slower way. | ||||||
cm = xp.zeros((n_labels, n_labels), dtype=dtype, device=device) | ||||||
for r, c, weight in zip(y_true, y_pred, sample_weight): | ||||||
cm[r, c] += weight | ||||||
|
||||||
with np.errstate(all="ignore"): | ||||||
if normalize == "true": | ||||||
cm = cm / cm.sum(axis=1, keepdims=True) | ||||||
cm = cm / xp.sum(cm, axis=1, keepdims=True) | ||||||
elif normalize == "pred": | ||||||
cm = cm / cm.sum(axis=0, keepdims=True) | ||||||
cm = cm / xp.sum(cm, axis=0, keepdims=True) | ||||||
elif normalize == "all": | ||||||
cm = cm / cm.sum() | ||||||
cm = np.nan_to_num(cm) | ||||||
cm = cm / xp.sum(cm) | ||||||
cm = _nan_to_num(cm, xp=xp) | ||||||
|
||||||
if cm.shape == (1, 1): | ||||||
warnings.warn( | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unrelated grammar fix (I think)