Made masked losses compatible with masked nans #18829

Closed · wants to merge 3 commits
98 changes: 39 additions & 59 deletions keras/losses/loss.py
@@ -76,7 +76,7 @@ def standardize_reduction(reduction):
    if reduction not in allowed:
        raise ValueError(
            "Invalid value for argument `reduction`. "
-            f"Expected on of {allowed}. Received: "
+            f"Expected one of {allowed}. Received: "
            f"reduction={reduction}"
        )
    return reduction
@@ -97,23 +97,6 @@ def squeeze_to_same_rank(x1, x2):
    return x1, x2


-def reduce_values(values, reduction="sum_over_batch_size"):
-    if (
-        reduction is None
-        or reduction == "none"
-        or tuple(values.shape) == ()
-        or tuple(values.shape) == (0,)
-    ):
-        return values
-    loss = ops.sum(values)
-    if reduction == "sum_over_batch_size":
-        loss /= ops.cast(
-            ops.prod(ops.convert_to_tensor(ops.shape(values), dtype="int32")),
-            loss.dtype,
-        )
-    return loss
-
-
def reduce_weighted_values(
    values,
    sample_weight=None,
@@ -126,48 +109,45 @@ def reduce_weighted_values(
    values = ops.convert_to_tensor(values, dtype=dtype)
    if sample_weight is not None:
        sample_weight = ops.convert_to_tensor(sample_weight, dtype=dtype)
-    if mask is not None:
-        mask = ops.convert_to_tensor(mask, dtype=dtype)
+        sample_weight, values = squeeze_to_same_rank(sample_weight, values)
+        values = values * sample_weight
+
+    values = apply_mask(values, mask)
+
+    if reduction is None or reduction == "none":
+        return values

-    # Merge mask and sample weight into sample weight.
-    sample_weight = apply_mask(
-        sample_weight, mask, dtype=values.dtype, reduction=reduction
-    )
+    if reduction == "sum":
+        return ops.sum(values)
+
+    if reduction == "sum_over_batch_size":
+        if mask is None:
+            # batch_size is the total number of elements
+            return ops.mean(values)
+        batch_size = ops.count_nonzero(mask)
+        values_sum = ops.sum(values)
+        # safe divide
+        return ops.cond(
+            batch_size == 0,
+            lambda: values_sum,  # will necessarily be all zeros
+            lambda: values_sum / ops.cast(batch_size, dtype),
+        )

[Review thread on the ops.cond above]

Member: This will add significant overhead to the computation. Is there a better way that doesn't involve a cond?

Contributor Author: There are plenty of ways to do essentially this, but is a cond on scalars that expensive? What about a where? They seemed the clearest, but other options would be:

  • use batch_size = maximum(batch_size, 1) - should give identical results (batch size is an integer, and when it's 0 the numerator should be zero)
  • use batch_size = batch_size + epsilon - closer to the current implementation, though IMO wrong.
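A minimal runnable sketch of the maximum()-based alternative floated in the thread (an illustration only, not what this PR merges; it assumes `values` has already been zeroed at masked positions by `apply_mask`):

    from keras import ops

    def masked_mean(values, mask, dtype="float32"):
        # Clamp the divisor to 1: when every entry is masked, `values` is
        # all zeros, so the numerator is 0 and the result is 0 rather than
        # NaN, matching the ops.cond version without a branch.
        batch_size = ops.count_nonzero(mask)
        return ops.sum(values) / ops.cast(ops.maximum(batch_size, 1), dtype)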

+    # we shouldn't get here because the call to `standardize_reduction`
+    # at the top of this function should raise the exact error as below.
+    allowed = {"sum_over_batch_size", "sum", None, "none"}
+    raise ValueError(
+        "Invalid value for argument `reduction`. "
+        f"Expected one of {allowed}. Received: "
+        f"reduction={reduction}"
+    )

-    if sample_weight is not None:
-        sample_weight = ops.cast(sample_weight, values.dtype)
-        # Update dimensions of `sample_weight` to match `losses`.
-        values, sample_weight = squeeze_to_same_rank(values, sample_weight)
-        values = values * sample_weight
-
-    # Apply reduction function to the individual weighted losses.
-    loss = reduce_values(values, reduction)
-    return loss


-def apply_mask(sample_weight, mask, dtype, reduction):
-    """Applies any mask on predictions to sample weights."""
-    if mask is not None:
-        mask = ops.cast(mask, dtype=dtype)
-        if reduction == "sum_over_batch_size":
-            # Valid entries have weight `total/valid`, while invalid ones
-            # have 0. When summed over batch, they will be reduced to:
-            #
-            # mean(loss * sample_weight * total / valid)
-            #   = sum(loss * sample_weight * total / valid) / total
-            #   = sum(loss * sample_weight) / total * total / valid
-            #   = sum(loss * sample_weight) / valid
-            total = ops.cast(
-                ops.prod(ops.convert_to_tensor(ops.shape(mask), dtype="int32")),
-                dtype,
-            )
-            valid = ops.sum(mask)  # May be 0!
-            mask *= total / (valid + backend.epsilon())
-
-        if sample_weight is not None:
-            sample_weight = ops.cast(sample_weight, dtype=dtype)
-            mask, sample_weight = squeeze_to_same_rank(mask, sample_weight)
-            sample_weight *= mask
-        else:
-            sample_weight = mask
-    return sample_weight
+def apply_mask(values, mask):
+    if mask is None:
+        return values
+    mask = ops.cast(mask, "bool")
+    while len(mask.shape) < len(values.shape):
+        mask = ops.expand_dims(mask, axis=-1)
+    values, mask = squeeze_to_same_rank(values, mask)
+    return ops.where(mask, values, ops.zeros_like(values))
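A quick illustration of the new `apply_mask` (assuming the `keras.losses.loss` module path shown in this diff): masked positions are zeroed via `ops.where`, so a NaN hiding behind the mask is dropped outright. The old weight-scaling route kept such NaNs alive, because `nan * 0` is still `nan`.

    import numpy as np
    from keras import ops
    from keras.losses.loss import apply_mask

    values = ops.convert_to_tensor(np.array([1.0, np.nan, 3.0, np.nan], "float32"))
    mask = ops.convert_to_tensor(np.array([True, False, True, False]))

    print(apply_mask(values, mask))  # [1., 0., 3., 0.] -- the NaNs are gone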
83 changes: 81 additions & 2 deletions keras/losses/loss_test.py
@@ -64,9 +64,23 @@ def test_mask(self):
            np.sum((masked_y_true - masked_y_pred) ** 2) / 3, loss
        )

+        # no reduction
+        loss_fn = ExampleLoss(reduction=None)
+        loss = loss_fn(y_true, y_pred)
+        self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
+        expected = (y_true - y_pred) ** 2
+        expected = ops.where(mask, expected, ops.zeros_like(expected))
+        self.assertAllClose(expected, loss)
+
+        # sum reduction
+        loss_fn = ExampleLoss(reduction="sum")
+        loss = loss_fn(y_true, y_pred)
+        self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
+        self.assertAllClose(np.sum((masked_y_true - masked_y_pred) ** 2), loss)
+
        # Test edge case where everything is masked.
+        mask = np.array([False, False, False, False])
+        y_pred._keras_mask = mask
        loss_fn = ExampleLoss()
-        y_pred._keras_mask = np.array([False, False, False, False])
        loss = loss_fn(y_true, y_pred)
        self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
        self.assertAllClose(loss, 0)  # No NaN.
@@ -83,7 +97,22 @@ def test_sample_weight(self):
            np.sum(sample_weight * (y_true - y_pred) ** 2) / 4, loss
        )

+        # no reduction
+        loss_fn = ExampleLoss(reduction=None)
+        loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
+        self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
+        self.assertAllClose(sample_weight * (y_true - y_pred) ** 2, loss)
+
+        # sum reduction
+        loss_fn = ExampleLoss(reduction="sum")
+        loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
+        self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
+        self.assertAllClose(
+            ops.sum(sample_weight * (y_true - y_pred) ** 2), loss
+        )
+
        # Test edge case where every weight is 0.
        loss_fn = ExampleLoss()
        sample_weight = np.array([0.0, 0.0, 0.0, 0.0])
        loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
        self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
@@ -117,6 +146,56 @@ def test_mask_and_sample_weight(self):
            loss,
        )

+        # ensure the result is the same if `y_pred` has masked nans.
+        y_pred_with_nans = ops.where(
+            mask, y_pred, ops.full_like(y_pred, np.nan)
+        )
+        y_pred_with_nans._keras_mask = mask
+        loss_with_y_pred_nans = loss_fn(
+            y_true, y_pred_with_nans, sample_weight=sample_weight
+        )
+        self.assertEqual(
+            backend.standardize_dtype(loss_with_y_pred_nans.dtype), "float32"
+        )
+        self.assertAllClose(loss, loss_with_y_pred_nans)
+
+        # ensure the result is the same if `sample_weight` has masked nans.
+        sample_weight_with_nans = ops.where(
+            mask, sample_weight, ops.full_like(sample_weight, np.nan)
+        )
+        loss_with_sample_weight_nans = loss_fn(
+            y_true, y_pred, sample_weight=sample_weight_with_nans
+        )
+        self.assertEqual(
+            backend.standardize_dtype(loss_with_sample_weight_nans.dtype),
+            "float32",
+        )
+        self.assertAllClose(loss, loss_with_sample_weight_nans)
+
+        # reduction is None
+        loss_fn = ExampleLoss(reduction="none")
+        loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
+        self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
+        self.assertAllClose(
+            ops.cast(mask, sample_weight.dtype)
+            * sample_weight
+            * (y_true - y_pred) ** 2,
+            loss,
+        )
+
+        # reduction is 'sum'
+        loss_fn = ExampleLoss(reduction="sum")
+        loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
+        self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
+        self.assertAllClose(
+            ops.sum(
+                ops.cast(mask, sample_weight.dtype)
+                * sample_weight
+                * (y_true - y_pred) ** 2
+            ),
+            loss,
+        )
+
    @pytest.mark.skipif(
        backend.backend() == "numpy",
        reason="Numpy backend does not support masking.",
    )
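These tests rely on the ExampleLoss helper defined near the top of loss_test.py, outside the lines shown here. Judging from the assertions (all built on (y_true - y_pred) ** 2), a minimal sketch consistent with them would be:

    from keras import losses

    class ExampleLoss(losses.Loss):
        # Element-wise squared error; masking, weighting, and reduction are
        # handled by the `Loss` base class under test.
        def call(self, y_true, y_pred):
            return (y_true - y_pred) ** 2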
16 changes: 8 additions & 8 deletions keras/metrics/reduction_metrics.py
@@ -10,12 +10,10 @@
def reduce_to_samplewise_values(values, sample_weight, reduce_fn, dtype):
    mask = getattr(values, "_keras_mask", None)
    values = ops.cast(values, dtype=dtype)
+    values = losses.loss.apply_mask(values, mask)
    if sample_weight is not None:
        sample_weight = ops.cast(sample_weight, dtype=dtype)
-        if mask is not None:
-            sample_weight = losses.loss.apply_mask(
-                sample_weight, mask, dtype=dtype, reduction="sum"
-            )
+        sample_weight = losses.loss.apply_mask(sample_weight, mask)
        # Update dimensions of weights to match with values if possible.
        values, sample_weight = losses.loss.squeeze_to_same_rank(
            values, sample_weight
@@ -200,10 +198,12 @@ def __init__(self, fn, name=None, dtype=None, **kwargs):
    def update_state(self, y_true, y_pred, sample_weight=None):
        mask = getattr(y_pred, "_keras_mask", None)
        values = self._fn(y_true, y_pred, **self._fn_kwargs)
-        if sample_weight is not None and mask is not None:
-            sample_weight = losses.loss.apply_mask(
-                sample_weight, mask, dtype=self.dtype, reduction="sum"
-            )
+        if mask is not None:
+            values = losses.loss.apply_mask(values, mask)
+            if sample_weight is None:
+                sample_weight = ops.cast(mask, values.dtype)
+            else:
+                sample_weight = losses.loss.apply_mask(sample_weight, mask)
        return super().update_state(values, sample_weight=sample_weight)

    def get_config(self):
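The net effect of the new update_state: values are zeroed at masked positions, and when no sample_weight is passed the mask itself becomes the weight vector, so masked samples drop out of both the numerator and denominator of the running mean. A small numeric illustration (plain NumPy, names hypothetical):

    import numpy as np

    values = np.array([0.2, 0.5, 7.0, 0.1], "float32")  # per-sample metric values
    mask = np.array([True, True, False, True])          # third sample is masked

    masked_values = np.where(mask, values, 0.0)         # what apply_mask produces
    weights = mask.astype("float32")                    # sample_weight = cast(mask)
    mean = masked_values.sum() / weights.sum()          # (0.2+0.5+0.1)/3 = 0.2667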
57 changes: 57 additions & 0 deletions keras/metrics/reduction_metrics_test.py
@@ -1,5 +1,8 @@
import numpy as np
+import pytest

+from keras import backend
+from keras import ops
from keras import testing
from keras.metrics import reduction_metrics
from keras.saving import register_keras_serializable
@@ -121,3 +124,57 @@ def test_weighted(self):
        sample_weight = np.array([1.0, 1.5, 2.0, 2.5])
        result = mse_obj(y_true, y_pred, sample_weight=sample_weight)
        self.assertAllClose(0.54285, result, atol=1e-5)
+
+    @pytest.mark.skipif(
+        backend.backend() == "numpy",
+        reason="Numpy backend does not support masking.",
+    )
+    def test_masked(self):
+        mse_obj = reduction_metrics.MeanMetricWrapper(
+            fn=mse, name="mse", dtype="float32"
+        )
+        y_true = np.array(
+            [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]]
+        )
+        y_pred = np.array(
+            [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]]
+        )
+        mask = np.array([True, True, False, True])
+        expected = ((y_true[mask] - y_pred[mask]) ** 2).mean()
+
+        y_pred = ops.convert_to_tensor(y_pred)
+        y_true = ops.convert_to_tensor(y_true)
+        mask = ops.convert_to_tensor(mask)
+
+        y_pred._keras_mask = mask
+        result = mse_obj(y_true, y_pred)
+        self.assertAllClose(expected, result, atol=1e-5)
+
+    @pytest.mark.skipif(
+        backend.backend() == "numpy",
+        reason="Numpy backend does not support masking.",
+    )
+    def test_masked_and_weighted(self):
+        mse_obj = reduction_metrics.MeanMetricWrapper(
+            fn=mse, name="mse", dtype="float32"
+        )
+        y_true = np.array(
+            [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]]
+        )
+        y_pred = np.array(
+            [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]]
+        )
+        mask = np.array([True, True, False, True])
+        sample_weight = np.array([1.0, 1.5, 2.0, 2.5])
+        expected = (
+            (y_true[mask] - y_pred[mask]) ** 2 * sample_weight[mask][:, None]
+        ).mean(1).sum() / sample_weight[mask].sum()
+
+        y_pred = ops.convert_to_tensor(y_pred)
+        y_true = ops.convert_to_tensor(y_true)
+        sample_weight = ops.convert_to_tensor(sample_weight)
+        mask = ops.convert_to_tensor(mask)
+
+        y_pred._keras_mask = mask
+        result = mse_obj(y_true, y_pred, sample_weight=sample_weight)
+        self.assertAllClose(expected, result, atol=1e-5)
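The wrapped mse function is defined earlier in reduction_metrics_test.py, outside this diff; given the register_keras_serializable import added above and the expected values these tests compute, a consistent sketch is:

    from keras.saving import register_keras_serializable

    @register_keras_serializable()
    def mse(y_true, y_pred):
        # Element-wise squared error; MeanMetricWrapper reduces each sample
        # to a scalar and keeps a weighted running mean across batches.
        return (y_true - y_pred) ** 2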