Skip to content

Commit

Permalink
ENH Use Array API in mean_tweedie_deviance (#28106)
Browse files Browse the repository at this point in the history
  • Loading branch information
lithomas1 committed May 8, 2024
1 parent 1fa3c75 commit e12f192
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 11 deletions.
1 change: 1 addition & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ Metrics
-------

- :func:`sklearn.metrics.accuracy_score`
- :func:`sklearn.metrics.mean_tweedie_deviance`
- :func:`sklearn.metrics.r2_score`
- :func:`sklearn.metrics.zero_one_loss`

Expand Down
18 changes: 18 additions & 0 deletions doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,24 @@ Version 1.6.0

**In Development**

Support for Array API
---------------------

Additional estimators and functions have been updated to include support for all
`Array API <https://data-apis.org/array-api/latest/>`_ compliant inputs.

See :ref:`array_api` for more details.

**Functions:**

- :func:`sklearn.metrics.mean_tweedie_deviance` now supports Array API compatible
inputs.
:pr:`28106` by :user:`Thomas Li <lithomas1>`

**Classes:**

-

Changelog
---------

Expand Down
21 changes: 11 additions & 10 deletions sklearn/metrics/_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,13 +1276,14 @@ def max_error(y_true, y_pred):

def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
"""Mean Tweedie deviance regression loss."""
xp, _ = get_namespace(y_true, y_pred)
p = power
if p < 0:
# 'Extreme stable', y any real number, y_pred > 0
dev = 2 * (
np.power(np.maximum(y_true, 0), 2 - p) / ((1 - p) * (2 - p))
- y_true * np.power(y_pred, 1 - p) / (1 - p)
+ np.power(y_pred, 2 - p) / (2 - p)
xp.pow(xp.where(y_true > 0, y_true, 0), 2 - p) / ((1 - p) * (2 - p))
- y_true * xp.pow(y_pred, 1 - p) / (1 - p)
+ xp.pow(y_pred, 2 - p) / (2 - p)
)
elif p == 0:
# Normal distribution, y and y_pred any real number
Expand All @@ -1292,15 +1293,14 @@ def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
dev = 2 * (xlogy(y_true, y_true / y_pred) - y_true + y_pred)
elif p == 2:
# Gamma distribution
dev = 2 * (np.log(y_pred / y_true) + y_true / y_pred - 1)
dev = 2 * (xp.log(y_pred / y_true) + y_true / y_pred - 1)
else:
dev = 2 * (
np.power(y_true, 2 - p) / ((1 - p) * (2 - p))
- y_true * np.power(y_pred, 1 - p) / (1 - p)
+ np.power(y_pred, 2 - p) / (2 - p)
xp.pow(y_true, 2 - p) / ((1 - p) * (2 - p))
- y_true * xp.pow(y_pred, 1 - p) / (1 - p)
+ xp.pow(y_pred, 2 - p) / (2 - p)
)

return np.average(dev, weights=sample_weight)
return float(_average(dev, weights=sample_weight))


@validate_params(
Expand Down Expand Up @@ -1363,8 +1363,9 @@ def mean_tweedie_deviance(y_true, y_pred, *, sample_weight=None, power=0):
>>> mean_tweedie_deviance(y_true, y_pred, power=1)
1.4260...
"""
xp, _ = get_namespace(y_true, y_pred)
y_type, y_true, y_pred, _ = _check_reg_targets(
y_true, y_pred, None, dtype=[np.float64, np.float32]
y_true, y_pred, None, dtype=[xp.float64, xp.float32]
)
if y_type == "continuous-multioutput":
raise ValueError("Multioutput not supported in mean_tweedie_deviance")
Expand Down
35 changes: 34 additions & 1 deletion sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,35 @@ def check_array_api_multiclass_classification_metric(


def check_array_api_regression_metric(metric, array_namespace, device, dtype_name):
    """Exercise ``metric`` through ``check_array_api_metric`` on 1D targets.

    The same pair of single-output ``y_true`` / ``y_pred`` arrays is checked
    twice: first with ``sample_weight=None`` and then with an explicit,
    non-uniform weight vector. All arrays are created with ``dtype_name``.
    """
    targets = np.array([2, 0, 1, 4], dtype=dtype_name)
    predictions = np.array([0.5, 0.5, 2, 2], dtype=dtype_name)

    def _run_check(weights):
        # Forward everything unchanged; only the sample weights vary per call.
        check_array_api_metric(
            metric,
            array_namespace,
            device,
            dtype_name,
            y_true_np=targets,
            y_pred_np=predictions,
            sample_weight=weights,
        )

    # Unweighted pass.
    _run_check(None)

    # Weighted pass with distinct per-sample weights.
    _run_check(np.array([0.1, 2.0, 1.5, 0.5], dtype=dtype_name))


def check_array_api_regression_metric_multioutput(
metric, array_namespace, device, dtype_name
):
y_true_np = np.array([[1, 3], [1, 2]], dtype=dtype_name)
y_pred_np = np.array([[1, 4], [1, 1]], dtype=dtype_name)

Expand Down Expand Up @@ -1859,7 +1888,11 @@ def check_array_api_regression_metric(metric, array_namespace, device, dtype_nam
check_array_api_binary_classification_metric,
check_array_api_multiclass_classification_metric,
],
r2_score: [check_array_api_regression_metric],
mean_tweedie_deviance: [check_array_api_regression_metric],
r2_score: [
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
],
}


Expand Down
3 changes: 3 additions & 0 deletions sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,9 @@ def reshape(self, x, shape, *, copy=None):
def isdtype(self, dtype, kind):
return isdtype(dtype, kind, xp=self)

def pow(self, x1, x2):
    """Return ``x1`` raised to the power ``x2`` by delegating to ``numpy.power``."""
    powered = numpy.power(x1, x2)
    return powered


_NUMPY_API_WRAPPER_INSTANCE = _NumPyAPIWrapper()

Expand Down

0 comments on commit e12f192

Please sign in to comment.