New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Add array-api support to metrics.confusion_matrix #28867
base: main
Are you sure you want to change the base?
Changes from 4 commits
9166240
85163a5
ae9d207
5cfb85f
e258ba1
864ffea
aa58127
f5565f3
2ea2722
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -40,8 +40,15 @@ | |||||
) | ||||||
from ..utils._array_api import ( | ||||||
_average, | ||||||
_is_numpy_namespace, | ||||||
_nan_to_num, | ||||||
_union1d, | ||||||
get_namespace, | ||||||
isdtype, | ||||||
size, | ||||||
) | ||||||
from ..utils._array_api import ( | ||||||
device as array_device, | ||||||
) | ||||||
from ..utils._param_validation import ( | ||||||
Hidden, | ||||||
|
@@ -111,10 +118,10 @@ def _check_targets(y_true, y_pred): | |||||
y_type = y_type.pop() | ||||||
|
||||||
# No metrics support "multiclass-multioutput" format | ||||||
if y_type not in ["binary", "multiclass", "multilabel-indicator"]: | ||||||
if y_type not in {"binary", "multiclass", "multilabel-indicator"}: | ||||||
raise ValueError("{0} is not supported".format(y_type)) | ||||||
|
||||||
if y_type in ["binary", "multiclass"]: | ||||||
if y_type in {"binary", "multiclass"}: | ||||||
xp, _ = get_namespace(y_true, y_pred) | ||||||
y_true = column_or_1d(y_true) | ||||||
y_pred = column_or_1d(y_pred) | ||||||
|
@@ -128,8 +135,8 @@ def _check_targets(y_true, y_pred): | |||||
# strings. So we raise a meaningful error | ||||||
raise TypeError( | ||||||
"Labels in y_true and y_pred should be of the same type. " | ||||||
f"Got y_true={xp.unique(y_true)} and " | ||||||
f"y_pred={xp.unique(y_pred)}. Make sure that the " | ||||||
f"Got y_true={xp.unique_values(y_true)} and " | ||||||
f"y_pred={xp.unique_values(y_pred)}. Make sure that the " | ||||||
"predictions provided by the classifier coincides with " | ||||||
"the true labels." | ||||||
) from e | ||||||
|
@@ -322,6 +329,7 @@ def confusion_matrix( | |||||
(0, 2, 1, 1) | ||||||
""" | ||||||
y_type, y_true, y_pred = _check_targets(y_true, y_pred) | ||||||
xp, _ = get_namespace(y_true, y_pred) | ||||||
if y_type not in ("binary", "multiclass"): | ||||||
raise ValueError("%s is not supported" % y_type) | ||||||
|
||||||
|
@@ -332,60 +340,72 @@ def confusion_matrix( | |||||
n_labels = labels.size | ||||||
if n_labels == 0: | ||||||
raise ValueError("'labels' should contains at least one label.") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
unrelated typo fix |
||||||
elif y_true.size == 0: | ||||||
elif size(y_true) == 0: | ||||||
return np.zeros((n_labels, n_labels), dtype=int) | ||||||
elif len(np.intersect1d(y_true, labels)) == 0: | ||||||
raise ValueError("At least one label specified must be in y_true") | ||||||
|
||||||
if sample_weight is None: | ||||||
sample_weight = np.ones(y_true.shape[0], dtype=np.int64) | ||||||
sample_weight = xp.ones( | ||||||
y_true.shape[0], dtype=xp.int64, device=array_device(y_true) | ||||||
) | ||||||
else: | ||||||
sample_weight = np.asarray(sample_weight) | ||||||
sample_weight = xp.asarray(sample_weight) | ||||||
|
||||||
check_consistent_length(y_true, y_pred, sample_weight) | ||||||
|
||||||
n_labels = labels.size | ||||||
n_labels = size(labels) | ||||||
# If labels are not consecutive integers starting from zero, then | ||||||
# y_true and y_pred must be converted into index form | ||||||
need_index_conversion = not ( | ||||||
labels.dtype.kind in {"i", "u", "b"} | ||||||
and np.all(labels == np.arange(n_labels)) | ||||||
and y_true.min() >= 0 | ||||||
and y_pred.min() >= 0 | ||||||
xp.isdtype(labels.dtype, ("bool", "integral")) | ||||||
and xp.all( | ||||||
labels | ||||||
== xp.arange(n_labels, dtype=labels.dtype, device=array_device(labels)) | ||||||
) | ||||||
and xp.min(y_true) >= 0 | ||||||
and xp.min(y_pred) >= 0 | ||||||
) | ||||||
if need_index_conversion: | ||||||
label_to_ind = {y: x for x, y in enumerate(labels)} | ||||||
y_pred = np.array([label_to_ind.get(x, n_labels + 1) for x in y_pred]) | ||||||
y_true = np.array([label_to_ind.get(x, n_labels + 1) for x in y_true]) | ||||||
y_pred = xp.asarray([label_to_ind.get(x, n_labels + 1) for x in y_pred]) | ||||||
y_true = xp.asarray([label_to_ind.get(x, n_labels + 1) for x in y_true]) | ||||||
|
||||||
# intersect y_pred, y_true with labels, eliminate items not in labels | ||||||
ind = np.logical_and(y_pred < n_labels, y_true < n_labels) | ||||||
if not np.all(ind): | ||||||
ind = xp.logical_and(y_pred < n_labels, y_true < n_labels) | ||||||
if not xp.all(ind): | ||||||
y_pred = y_pred[ind] | ||||||
y_true = y_true[ind] | ||||||
# also eliminate weights of eliminated items | ||||||
sample_weight = sample_weight[ind] | ||||||
|
||||||
# Choose the accumulator dtype to always have high precision | ||||||
if sample_weight.dtype.kind in {"i", "u", "b"}: | ||||||
dtype = np.int64 | ||||||
if isdtype(sample_weight.dtype, ("bool", "integral"), xp=xp): | ||||||
dtype = xp.int64 | ||||||
else: | ||||||
dtype = np.float64 | ||||||
|
||||||
cm = coo_matrix( | ||||||
(sample_weight, (y_true, y_pred)), | ||||||
shape=(n_labels, n_labels), | ||||||
dtype=dtype, | ||||||
).toarray() | ||||||
dtype = xp.float64 # can't use float64 for "torch on mps". | ||||||
|
||||||
# This will be a challenge to do in a "friendly" way. | ||||||
if _is_numpy_namespace(xp): | ||||||
cm = coo_matrix( | ||||||
(sample_weight, (y_true, y_pred)), | ||||||
shape=(n_labels, n_labels), | ||||||
dtype=dtype, | ||||||
).toarray() | ||||||
else: | ||||||
# Harder and slower way. | ||||||
cm = xp.zeros((n_labels, n_labels), dtype=dtype, device=array_device(y_true)) | ||||||
for r, c, weight in zip(y_true, y_pred, sample_weight): | ||||||
cm[r, c] += weight | ||||||
|
||||||
with np.errstate(all="ignore"): | ||||||
if normalize == "true": | ||||||
cm = cm / cm.sum(axis=1, keepdims=True) | ||||||
cm = cm / xp.sum(cm, axis=1, keepdims=True) | ||||||
elif normalize == "pred": | ||||||
cm = cm / cm.sum(axis=0, keepdims=True) | ||||||
cm = cm / xp.sum(cm, axis=0, keepdims=True) | ||||||
elif normalize == "all": | ||||||
cm = cm / cm.sum() | ||||||
cm = np.nan_to_num(cm) | ||||||
cm = cm / xp.sum(cm) | ||||||
cm = _nan_to_num(cm, xp=xp) | ||||||
|
||||||
if cm.shape == (1, 1): | ||||||
warnings.warn( | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -60,6 +60,20 @@ | |
yield array_namespace, None, None | ||
|
||
|
||
def yield_namespace_device_combinations(include_numpy_namespaces=True):
    """Yield ``(namespace_name, device)`` pairs for array namespaces.

    ``"torch"`` is paired with each of ``"cpu"``, ``"cuda"`` and ``"mps"``;
    every other namespace is paired with ``None``. Nothing here checks that
    the namespace or the device is actually available.

    Parameters
    ----------
    include_numpy_namespaces : bool, default=True
        When False, names found in ``_NUMPY_NAMESPACE_NAMES`` are skipped.

    Yields
    ------
    namespace : str
        Name of the array namespace.
    device : str or None
        Device name for ``"torch"``, ``None`` otherwise.
    """
    candidate_namespaces = (
        "numpy",
        "array_api_strict",
        "cupy",
        "cupy.array_api",
        "torch",
    )
    torch_devices = ("cpu", "cuda", "mps")

    for namespace in candidate_namespaces:
        skip_numpy_like = (
            not include_numpy_namespaces and namespace in _NUMPY_NAMESPACE_NAMES
        )
        if skip_numpy_like:
            continue

        devices = torch_devices if namespace == "torch" else (None,)
        for device_name in devices:
            yield namespace, device_name
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. +1 for this. Once this is in, we should do a follow-up PR for occurrences of |
||
|
||
|
||
def _check_array_api_dispatch(array_api_dispatch): | ||
"""Check that array_api_compat is installed and NumPy version is compatible. | ||
|
||
|
@@ -255,6 +269,12 @@ | |
return arrays | ||
|
||
|
||
def _check_common_namespace_device(*arrays):
    """Validate that all arrays share one namespace and one device.

    The helpers are called only for the errors they raise; their return
    values are discarded.
    """
    # Namespace is checked first, then device: the first helper throws on
    # multiple namespaces, the second throws on multiple devices.
    for raise_on_mismatch in (get_namespace, device):
        raise_on_mismatch(*arrays)
|
||
|
||
class _ArrayAPIWrapper: | ||
"""sklearn specific Array API compatibility wrapper | ||
|
||
|
@@ -521,14 +541,19 @@ | |
# message in case it is missing. | ||
import array_api_compat | ||
|
||
namespace, is_array_api_compliant = array_api_compat.get_namespace(*arrays), True | ||
# Convert lists and tuple to numpy arrays. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we reword the comment so that it explains why it is a good idea to do this? I think that would be helpful to have here; at least for me, it is not 100% clear that we should do this.
||
arrays = [ | ||
numpy.array(arr) if isinstance(arr, (list, tuple)) else arr for arr in arrays | ||
] | ||
|
||
namespace = array_api_compat.get_namespace(*arrays) | ||
|
||
# These namespaces need additional wrapping to smooth out small differences | ||
# between implementations | ||
if namespace.__name__ in {"cupy.array_api"}: | ||
namespace = _ArrayAPIWrapper(namespace) | ||
|
||
return namespace, is_array_api_compliant | ||
return namespace, True | ||
|
||
|
||
def _expit(X, xp=None): | ||
|
@@ -690,6 +715,41 @@ | |
return X | ||
|
||
|
||
def _nan_to_num(X, *, xp=None, copy=True, nan=0.0, posinf=None, neginf=None):
    """Port of np.nan_to_num for array-api.

    Replace NaN entries with ``nan`` and positive / negative infinities with
    ``posinf`` / ``neginf``. Inputs whose dtype is neither real nor complex
    floating are returned without modification.

    Parameters
    ----------
    X : array
        Input data.
    xp : module, default=None
        Array namespace to use. When None it is inferred from ``X``.
    copy : bool, default=True
        Forwarded to ``xp.asarray`` when converting ``X``.
    nan : float, default=0.0
        Replacement for NaN entries.
    posinf : float, default=None
        Replacement for positive infinity. When None, the ``max`` reported by
        ``xp.finfo`` for the dtype of ``X`` is used.
    neginf : float, default=None
        Replacement for negative infinity. When None, the ``min`` reported by
        ``xp.finfo`` for the dtype of ``X`` is used.

    Returns
    -------
    out : array
        The converted array; for 0-d input, ``X[()]`` is returned.
    """
    # Respect a namespace supplied by the caller; previously ``xp=None`` was
    # hard-coded here, so the ``xp`` argument was silently ignored.
    xp, _ = get_namespace(X, xp=xp)
    X = xp.asarray(X, copy=copy)
    dtype = X.dtype
    is_scalar = X.ndim == 0
    is_complex = isdtype(dtype, "complex floating", xp=xp)

    # If the input isn't floating, then no changes are made.
    if not (isdtype(dtype, "real floating", xp=xp) or is_complex):
        return X[()] if is_scalar else X

    dtype_info = xp.finfo(dtype)
    maxf = dtype_info.max if posinf is None else posinf
    minf = dtype_info.min if neginf is None else neginf

    # NOTE(review): the in-place writes below only reach ``X`` for complex
    # input if ``xp.real`` / ``xp.imag`` return writable views -- TODO confirm
    # this holds for every supported namespace.
    components = (xp.real(X), xp.imag(X)) if is_complex else (X,)
    for component in components:
        nan_mask = xp.isnan(component)
        inf_mask = xp.isinf(component)
        posinf_mask = xp.logical_and(inf_mask, component > 0)
        # Infinite entries that are not positive must be negative infinity.
        neginf_mask = xp.logical_xor(inf_mask, posinf_mask)
        component[nan_mask] = xp.asarray(nan)
        component[posinf_mask] = xp.asarray(maxf)
        component[neginf_mask] = xp.asarray(minf)
    return X[()] if is_scalar else X
|
||
|
||
def _asarray_with_order(array, dtype=None, order=None, copy=None, *, xp=None): | ||
"""Helper to support the order kwarg only for NumPy-backed arrays | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unrelated grammar fix (I think)