Made masked losses compatible with masked nans #18829

Closed · wants to merge 3 commits
87 changes: 30 additions & 57 deletions keras/losses/loss.py
@@ -97,23 +97,6 @@ def squeeze_to_same_rank(x1, x2):
    return x1, x2


def reduce_values(values, reduction="sum_over_batch_size"):
    if (
        reduction is None
        or reduction == "none"
        or tuple(values.shape) == ()
        or tuple(values.shape) == (0,)
    ):
        return values
    loss = ops.sum(values)
    if reduction == "sum_over_batch_size":
        loss /= ops.cast(
            ops.prod(ops.convert_to_tensor(ops.shape(values), dtype="int32")),
            loss.dtype,
        )
    return loss


def reduce_weighted_values(
    values,
    sample_weight=None,
@@ -126,48 +109,38 @@ def reduce_weighted_values(
    values = ops.convert_to_tensor(values, dtype=dtype)
    if sample_weight is not None:
        sample_weight = ops.convert_to_tensor(sample_weight, dtype=dtype)
    if mask is not None:
        mask = ops.convert_to_tensor(mask, dtype=dtype)
        sample_weight, values = squeeze_to_same_rank(sample_weight, values)
        values = values * sample_weight

    # Merge mask and sample weight into sample weight.
    sample_weight = apply_mask(
        sample_weight, mask, dtype=values.dtype, reduction=reduction
    )
    if mask is not None:
        mask = ops.cast(mask, "bool")
        mask, values = squeeze_to_same_rank(mask, values)
        values = ops.where(mask, values, ops.zeros_like(values))

    if sample_weight is not None:
        sample_weight = ops.cast(sample_weight, values.dtype)
        # Update dimensions of `sample_weight` to match `losses`.
        values, sample_weight = squeeze_to_same_rank(values, sample_weight)
        values = values * sample_weight
    if reduction is None or reduction == "none":
        return values

    # Apply reduction function to the individual weighted losses.
    loss = reduce_values(values, reduction)
    return loss
    if reduction == "sum":
        return ops.sum(values)

    if reduction == "sum_over_batch_size":
        if mask is None:
            # batch_size is the total number of elements
            return ops.mean(values)
        batch_size = ops.count_nonzero(mask)
        values_sum = ops.sum(values)
        # safe divide
        return ops.cond(
Member: This will add significant overhead to the computation. Is there a better way that doesn't involve a cond?

Contributor Author: There are plenty of ways to do essentially this, but is a cond on scalars that expensive? What about a where? They seemed the clearest, but other options would be:

  • use batch_size = maximum(batch_size, 1) - should give identical results (batch size is an integer, and when it's 0 the numerator should be zero); a sketch of this variant follows the hunk below
  • use batch_size = batch_size + epsilon - closer to the current implementation, though IMO wrong.

            batch_size == 0,
            lambda: values_sum,  # will necessarily be all zeros
            lambda: values_sum / ops.cast(batch_size, dtype),
        )
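
For comparison, a minimal sketch of the maximum-based alternative the author mentions above. It is illustrative only (not part of this diff) and assumes the same values, mask, and dtype variables as the hunk:

        # Hypothetical variant of the safe divide above, without a cond.
        # Masked entries of `values` were already zeroed via `ops.where`, so
        # when `batch_size` is 0 the numerator is 0 as well; clamping the
        # divisor to 1 therefore returns the same all-zero result.
        batch_size = ops.count_nonzero(mask)
        values_sum = ops.sum(values)
        return values_sum / ops.maximum(ops.cast(batch_size, dtype), 1)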

def apply_mask(sample_weight, mask, dtype, reduction):
    """Applies any mask on predictions to sample weights."""
    if mask is not None:
        mask = ops.cast(mask, dtype=dtype)
        if reduction == "sum_over_batch_size":
            # Valid entries have weight `total/valid`, while invalid ones
            # have 0. When summed over batch, they will be reduced to:
            #
            # mean(loss * sample_weight * total / valid)
            # = sum(loss * sample_weight * total / valid) / total
            # = sum(loss * sample_weight) / total * total / valid
            # = sum(loss * sample_weight) / valid
            total = ops.cast(
                ops.prod(ops.convert_to_tensor(ops.shape(mask), dtype="int32")),
                dtype,
            )
            valid = ops.sum(mask)  # May be 0!
            mask *= total / (valid + backend.epsilon())

        if sample_weight is not None:
            sample_weight = ops.cast(sample_weight, dtype=dtype)
            mask, sample_weight = squeeze_to_same_rank(mask, sample_weight)
            sample_weight *= mask
        else:
            sample_weight = mask
    return sample_weight
    # we shouldn't get here because the call to `standardize_reduction`
    # at the top of this function should raise the exact error as below.
    allowed = {"sum_over_batch_size", "sum", None, "none"}
    raise ValueError(
        "Invalid value for argument `reduction`. "
        f"Expected on of {allowed}. Received: "
        f"reduction={reduction}"
    )

Member: Typo: "Expected on of" should read "Expected one of".
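
The effect of switching from mask-as-weight multiplication to ops.where shows up when the masked entries are NaN: 0 * nan is still nan, so the old rescaling path lets masked NaNs poison the reduction. A small NumPy sketch, illustrative only and not part of the diff:

import numpy as np

values = np.array([1.0, 2.0, np.nan, np.nan], dtype="float32")
mask = np.array([True, True, False, False])

# Mask folded into a multiplicative weight: the NaNs survive (0.0 * nan == nan).
weighted = values * mask.astype("float32")
print(weighted.sum())  # nan

# Mask applied with where: masked NaNs never enter the reduction.
selected = np.where(mask, values, 0.0)
print(selected.sum() / max(np.count_nonzero(mask), 1))  # 1.5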
83 changes: 81 additions & 2 deletions keras/losses/loss_test.py
@@ -64,9 +64,23 @@ def test_mask(self):
            np.sum((masked_y_true - masked_y_pred) ** 2) / 3, loss
        )

        # no reduction
        loss_fn = ExampleLoss(reduction=None)
        loss = loss_fn(y_true, y_pred)
        self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
        expected = (y_true - y_pred) ** 2
        expected = ops.where(mask, expected, ops.zeros_like(expected))
        self.assertAllClose(expected, loss)

        # sum reduction
        loss_fn = ExampleLoss(reduction="sum")
        loss = loss_fn(y_true, y_pred)
        self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
        self.assertAllClose(np.sum((masked_y_true - masked_y_pred) ** 2), loss)

        # Test edge case where everything is masked.
        mask = np.array([False, False, False, False])
        y_pred._keras_mask = mask
        loss_fn = ExampleLoss()
        y_pred._keras_mask = np.array([False, False, False, False])
        loss = loss_fn(y_true, y_pred)
        self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
        self.assertAllClose(loss, 0)  # No NaN.
@@ -83,7 +97,22 @@ def test_sample_weight(self):
            np.sum(sample_weight * (y_true - y_pred) ** 2) / 4, loss
        )

        # no reduction
        loss_fn = ExampleLoss(reduction=None)
        loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
        self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
        self.assertAllClose(sample_weight * (y_true - y_pred) ** 2, loss)

        # sum reduction
        loss_fn = ExampleLoss(reduction="sum")
        loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
        self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
        self.assertAllClose(
            ops.sum(sample_weight * (y_true - y_pred) ** 2), loss
        )

        # Test edge case where every weight is 0.
        loss_fn = ExampleLoss()
        sample_weight = np.array([0.0, 0.0, 0.0, 0.0])
        loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
        self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
@@ -117,6 +146,56 @@ def test_mask_and_sample_weight(self):
            loss,
        )

        # ensure the result is the same if `y_pred` has masked nans.
        y_pred_with_nans = ops.where(
            mask, y_pred, ops.full_like(y_pred, np.nan)
        )
        y_pred_with_nans._keras_mask = mask
        loss_with_y_pred_nans = loss_fn(
            y_true, y_pred_with_nans, sample_weight=sample_weight
        )
        self.assertEqual(
            backend.standardize_dtype(loss_with_y_pred_nans.dtype), "float32"
        )
        self.assertAllClose(loss, loss_with_y_pred_nans)

        # ensure the result is the same if `sample_weights` has masked nans.
        sample_weight_with_nans = ops.where(
            mask, sample_weight, ops.full_like(sample_weight, np.nan)
        )
        loss_with_sample_weight_nans = loss_fn(
            y_true, y_pred, sample_weight=sample_weight_with_nans
        )
        self.assertEqual(
            backend.standardize_dtype(loss_with_sample_weight_nans.dtype),
            "float32",
        )
        self.assertAllClose(loss, loss_with_sample_weight_nans)

        # reduction is None
        loss_fn = ExampleLoss(reduction="none")
        loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
        self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
        self.assertAllClose(
            ops.cast(mask, sample_weight.dtype)
            * sample_weight
            * (y_true - y_pred) ** 2,
            loss,
        )

        # reduction is 'sum'
        loss_fn = ExampleLoss(reduction="sum")
        loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
        self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
        self.assertAllClose(
            ops.sum(
                ops.cast(mask, sample_weight.dtype)
                * sample_weight
                * (y_true - y_pred) ** 2
            ),
            loss,
        )

    @pytest.mark.skipif(
        backend.backend() == "numpy",
        reason="Numpy backend does not support masking.",
14 changes: 10 additions & 4 deletions keras/metrics/reduction_metrics.py
@@ -13,8 +13,11 @@ def reduce_to_samplewise_values(values, sample_weight, reduce_fn, dtype):
    if sample_weight is not None:
        sample_weight = ops.cast(sample_weight, dtype=dtype)
        if mask is not None:
            sample_weight = losses.loss.apply_mask(
                sample_weight, mask, dtype=dtype, reduction="sum"
            sample_weight, mask = losses.loss.squeeze_to_same_rank(
                sample_weight, mask
            )
            sample_weight = ops.where(
                mask, sample_weight, ops.zeros_like(sample_weight)
            )

Member: Why not refactor apply_mask instead?
        # Update dimensions of weights to match with values if possible.
        values, sample_weight = losses.loss.squeeze_to_same_rank(
@@ -201,8 +204,11 @@ def update_state(self, y_true, y_pred, sample_weight=None):
        mask = getattr(y_pred, "_keras_mask", None)
        values = self._fn(y_true, y_pred, **self._fn_kwargs)
        if sample_weight is not None and mask is not None:
            sample_weight = losses.loss.apply_mask(
                sample_weight, mask, dtype=self.dtype, reduction="sum"
            sample_weight, mask = losses.loss.squeeze_to_same_rank(
                sample_weight, mask
            )
            sample_weight = ops.where(
                mask, sample_weight, ops.zeros_like(sample_weight)
            )
        return super().update_state(values, sample_weight=sample_weight)
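
One possible shape for the refactor asked about in the comment on the first hunk: a small shared helper that both call sites delegate to. This is a hypothetical sketch; the helper name and placement are assumptions, not part of this PR:

# Hypothetical helper: zero the sample weight wherever the mask is False,
# which is what both hunks above now do inline.
def apply_mask_by_zeroing(sample_weight, mask):
    sample_weight, mask = losses.loss.squeeze_to_same_rank(sample_weight, mask)
    return ops.where(mask, sample_weight, ops.zeros_like(sample_weight))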
