New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Made masked losses compatible with masked nans #18829
Closed
Closed
Changes from 1 commit
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -97,23 +97,6 @@ def squeeze_to_same_rank(x1, x2): | |
return x1, x2 | ||
|
||
|
||
def reduce_values(values, reduction="sum_over_batch_size"): | ||
if ( | ||
reduction is None | ||
or reduction == "none" | ||
or tuple(values.shape) == () | ||
or tuple(values.shape) == (0,) | ||
): | ||
return values | ||
loss = ops.sum(values) | ||
if reduction == "sum_over_batch_size": | ||
loss /= ops.cast( | ||
ops.prod(ops.convert_to_tensor(ops.shape(values), dtype="int32")), | ||
loss.dtype, | ||
) | ||
return loss | ||
|
||
|
||
def reduce_weighted_values( | ||
values, | ||
sample_weight=None, | ||
|
@@ -126,48 +109,38 @@ def reduce_weighted_values( | |
values = ops.convert_to_tensor(values, dtype=dtype) | ||
if sample_weight is not None: | ||
sample_weight = ops.convert_to_tensor(sample_weight, dtype=dtype) | ||
if mask is not None: | ||
mask = ops.convert_to_tensor(mask, dtype=dtype) | ||
sample_weight, values = squeeze_to_same_rank(sample_weight, values) | ||
values = values * sample_weight | ||
|
||
# Merge mask and sample weight into sample weight. | ||
sample_weight = apply_mask( | ||
sample_weight, mask, dtype=values.dtype, reduction=reduction | ||
) | ||
if mask is not None: | ||
mask = ops.cast(mask, "bool") | ||
mask, values = squeeze_to_same_rank(mask, values) | ||
values = ops.where(mask, values, ops.zeros_like(values)) | ||
|
||
if sample_weight is not None: | ||
sample_weight = ops.cast(sample_weight, values.dtype) | ||
# Update dimensions of `sample_weight` to match `losses`. | ||
values, sample_weight = squeeze_to_same_rank(values, sample_weight) | ||
values = values * sample_weight | ||
if reduction is None or reduction == "none": | ||
return values | ||
|
||
# Apply reduction function to the individual weighted losses. | ||
loss = reduce_values(values, reduction) | ||
return loss | ||
if reduction == "sum": | ||
return ops.sum(values) | ||
|
||
if reduction == "sum_over_batch_size": | ||
if mask is None: | ||
# batch_size is the total number of elements | ||
return ops.mean(values) | ||
batch_size = ops.count_nonzero(mask) | ||
values_sum = ops.sum(values) | ||
# safe divide | ||
return ops.cond( | ||
batch_size == 0, | ||
lambda: values_sum, # will necessarily be all zeros | ||
lambda: values_sum / ops.cast(batch_size, dtype), | ||
) | ||
|
||
def apply_mask(sample_weight, mask, dtype, reduction): | ||
"""Applies any mask on predictions to sample weights.""" | ||
if mask is not None: | ||
mask = ops.cast(mask, dtype=dtype) | ||
if reduction == "sum_over_batch_size": | ||
# Valid entries have weight `total/valid`, while invalid ones | ||
# have 0. When summed over batch, they will be reduced to: | ||
# | ||
# mean(loss * sample_weight * total / valid) | ||
# = sum(loss * sample_weight * total / valid) / total | ||
# = sum(loss * sample_weight) / total * total / valid | ||
# = sum(loss * sample_weight) / valid | ||
total = ops.cast( | ||
ops.prod(ops.convert_to_tensor(ops.shape(mask), dtype="int32")), | ||
dtype, | ||
) | ||
valid = ops.sum(mask) # May be 0! | ||
mask *= total / (valid + backend.epsilon()) | ||
|
||
if sample_weight is not None: | ||
sample_weight = ops.cast(sample_weight, dtype=dtype) | ||
mask, sample_weight = squeeze_to_same_rank(mask, sample_weight) | ||
sample_weight *= mask | ||
else: | ||
sample_weight = mask | ||
return sample_weight | ||
# we shouldn't get here because the call to `standardize_reduction` | ||
# at the top of this function should raise the exact error as below. | ||
allowed = {"sum_over_batch_size", "sum", None, "none"} | ||
raise ValueError( | ||
"Invalid value for argument `reduction`. " | ||
f"Expected on of {allowed}. Received: " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. one of |
||
f"reduction={reduction}" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,8 +13,11 @@ def reduce_to_samplewise_values(values, sample_weight, reduce_fn, dtype): | |
if sample_weight is not None: | ||
sample_weight = ops.cast(sample_weight, dtype=dtype) | ||
if mask is not None: | ||
sample_weight = losses.loss.apply_mask( | ||
sample_weight, mask, dtype=dtype, reduction="sum" | ||
sample_weight, mask = losses.loss.squeeze_to_same_rank( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not refactor |
||
sample_weight, mask | ||
) | ||
sample_weight = ops.where( | ||
mask, sample_weight, ops.zeros_like(sample_weight) | ||
) | ||
# Update dimensions of weights to match with values if possible. | ||
values, sample_weight = losses.loss.squeeze_to_same_rank( | ||
|
@@ -201,8 +204,11 @@ def update_state(self, y_true, y_pred, sample_weight=None): | |
mask = getattr(y_pred, "_keras_mask", None) | ||
values = self._fn(y_true, y_pred, **self._fn_kwargs) | ||
if sample_weight is not None and mask is not None: | ||
sample_weight = losses.loss.apply_mask( | ||
sample_weight, mask, dtype=self.dtype, reduction="sum" | ||
sample_weight, mask = losses.loss.squeeze_to_same_rank( | ||
sample_weight, mask | ||
) | ||
sample_weight = ops.where( | ||
mask, sample_weight, ops.zeros_like(sample_weight) | ||
) | ||
return super().update_state(values, sample_weight=sample_weight) | ||
|
||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will add significant overhead to the computation. Is there a better way that doesn't involve a cond?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are plenty of ways to do essentially this, but is a `cond` on scalars that expensive? What about a `where`? They seemed the clearest, but other options would be:
- `batch_size = maximum(batch_size, 1)` — should give identical results (batch size is an integer, and when it is 0 the numerator is necessarily zero as well).
- `batch_size = batch_size + epsilon` — closer to the current implementation, though IMO wrong.