Made masked losses compatible with masked nans #18829

Closed
wants to merge 3 commits into from

Conversation

@jackd (Contributor) commented Nov 25, 2023

NaNs are a great way to ensure certain values aren't used (e.g. those associated with masked values). This change ensures that masked values are correctly masked (set to zero, even when NaN) rather than multiplied by zero (which leaves NaNs as NaNs).

This PR also (IMO) greatly simplifies the masking / weighting loss implementation. Test coverage is also improved.
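
For illustration only (this is not code from the PR), a minimal sketch using the public keras.ops API of why multiplication-based masking leaks NaNs while a select-style mask does not:

    from keras import ops

    values = ops.convert_to_tensor([1.0, 2.0, float("nan")])  # NaN in a masked position
    mask = ops.convert_to_tensor([True, True, False])

    # Multiply-by-zero masking: 0 * NaN is still NaN, so the NaN survives.
    masked_by_mul = values * ops.cast(mask, "float32")
    print(ops.convert_to_numpy(masked_by_mul))  # [ 1.  2. nan]

    # Select-style masking zeroes the entry regardless of its value.
    masked_by_where = ops.where(mask, values, ops.zeros_like(values))
    print(ops.convert_to_numpy(masked_by_where))  # [1. 2. 0.]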

@codecov-commenter commented Nov 25, 2023

Codecov Report

Attention: 3 lines in your changes are missing coverage. Please review.

Comparison is base (9620d23) 79.30% compared to head (936dd13) 79.37%.
Report is 11 commits behind head on master.

Files Patch % Lines
keras/losses/loss.py 86.95% 2 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #18829      +/-   ##
==========================================
+ Coverage   79.30%   79.37%   +0.06%     
==========================================
  Files         336      336              
  Lines       34549    34775     +226     
  Branches     6799     6841      +42     
==========================================
+ Hits        27400    27603     +203     
- Misses       5567     5590      +23     
  Partials     1582     1582              
Flag Coverage Δ
keras 79.23% <90.00%> (+0.06%) ⬆️
keras-jax 61.07% <90.00%> (-0.28%) ⬇️
keras-numpy 55.90% <40.00%> (-0.19%) ⬇️
keras-tensorflow 63.26% <90.00%> (-0.08%) ⬇️
keras-torch 63.83% <90.00%> (-0.27%) ⬇️


@fchollet (Member) left a comment

Thanks for the PR.

allowed = {"sum_over_batch_size", "sum", None, "none"}
raise ValueError(
    "Invalid value for argument `reduction`. "
    f"Expected on of {allowed}. Received: "
Member

one of

@@ -13,8 +13,11 @@ def reduce_to_samplewise_values(values, sample_weight, reduce_fn, dtype):
    if sample_weight is not None:
        sample_weight = ops.cast(sample_weight, dtype=dtype)
        if mask is not None:
            sample_weight = losses.loss.apply_mask(
                sample_weight, mask, dtype=dtype, reduction="sum"
            sample_weight, mask = losses.loss.squeeze_to_same_rank(
Member

Why not refactor apply_mask instead?
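
For context, one hypothetical shape such a refactor could take (a sketch only, not the actual Keras implementation): an apply_mask helper that zeroes masked entries with a select instead of a multiply, reusing the squeeze_to_same_rank utility quoted in the diff above.

    from keras import ops

    def apply_mask(sample_weight, mask, dtype="float32"):
        # Hypothetical sketch only: align ranks, then zero out masked entries
        # with a select so NaNs in masked positions cannot leak through.
        mask = ops.cast(mask, dtype)
        if sample_weight is None:
            return mask
        sample_weight = ops.cast(sample_weight, dtype)
        sample_weight, mask = squeeze_to_same_rank(sample_weight, mask)  # helper quoted above
        return ops.where(ops.cast(mask, "bool"), sample_weight, ops.zeros_like(sample_weight))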

batch_size = ops.count_nonzero(mask)
values_sum = ops.sum(values)
# safe divide
return ops.cond(
Member

This will add significant overhead to the computation. Is there a better way that doesn't involve a cond?

Contributor Author

There are plenty of ways to do essentially this, but is a cond on scalars really that expensive? What about a where? Those seemed the clearest, but other options would be (see the sketch below):

  • use batch_size = maximum(batch_size, 1) - should give identical results (batch_size is an integer, and when it's 0 the numerator should also be zero)
  • use batch_size = batch_size + epsilon - closer to the current implementation, though IMO wrong.
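
For concreteness, a rough sketch of the cond version quoted in the review next to the maximum(batch_size, 1) option (illustrative only, assuming float32 values; not the PR's final code):

    from keras import ops

    def mean_with_cond(values, mask):
        # Branch on the scalar count to avoid dividing 0 by 0.
        batch_size = ops.count_nonzero(mask)
        values_sum = ops.sum(values)
        return ops.cond(
            batch_size == 0,
            lambda: values_sum,  # the sum is already 0.0 when everything is masked
            lambda: values_sum / ops.cast(batch_size, "float32"),
        )

    def mean_with_max(values, mask):
        # maximum(batch_size, 1): when the count is 0 the numerator is also 0,
        # so dividing by 1 gives the same result without any branching.
        batch_size = ops.maximum(ops.count_nonzero(mask), 1)
        return ops.sum(values) / ops.cast(batch_size, "float32")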

@jackd (Contributor, Author) commented Nov 29, 2023

@fchollet refactored to use a re-introduced apply_mask. Can you elaborate on the overhead introduced by a scalar cond? Is there something about the specific circumstances here that makes it expensive? I would have thought it would be cheaper than a where, which itself would be about the cost of any simple binary arithmetic operation.

@fchollet (Member):
The reason has to do with parallelization. Conditional branches are much harder to handle than a single serial op even if it has more flops.

@haifeng-jin can you advise here on how to test the performance impact of this change for a couple of standard models on GPU? I'd like to compare the cond implementation, the where implementation, and the baseline (current code).
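
For reference, a hedged sketch of what the where implementation could look like (not committed code; assumes float32 values): guard the denominator so the division is always well defined, then select 0 when everything is masked.

    from keras import ops

    def mean_with_where(values, mask):
        batch_size = ops.cast(ops.count_nonzero(mask), "float32")
        values_sum = ops.sum(values)
        # The clamped denominator keeps the division finite; where then
        # picks 0.0 when no entries were left unmasked.
        safe_mean = values_sum / ops.maximum(batch_size, 1.0)
        return ops.where(batch_size > 0.0, safe_mean, ops.zeros_like(values_sum))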

@sachinprasadhs added stat:awaiting keras-eng (Awaiting response from Keras engineer) and removed awaiting review labels Dec 5, 2023
@haifeng-jin (Member):
@jackd here is what I did for benchmarking a PR.
You can find six Colab notebooks here:
https://github.com/haifeng-jin/keras-benchmarking

Each of the notebooks is for one of the backends, either before the change or after the change.
You can just swap out the model part of the code and use your own model.

You will see the performance numbers at the end of the notebook.

@fchollet (Member) commented Dec 7, 2023

Thanks, Haifeng! @jackd can you use the same code to benchmark the impact of this change?

@gbaned added this to Assigned Reviewer in PR Queue via automation Dec 8, 2023
@jackd (Contributor, Author) commented Dec 13, 2023

This isn't a priority for me, and I've already spent a lot longer on this than I intended to. If anyone else wants to take this up, feel free; otherwise, would a modified PR with just the simplified version (re-replacing masking with multiplication by zero) be accepted?

@gbaned (Collaborator) commented Jan 5, 2024

Hi @fchollet, can you please assist with the above comments from @jackd? Thank you!

@dugujiujian1999 (Contributor):
Running on Colab T4

  • batch_norm_op_jax_after (HEAD -> 936dd13)
    102967424/102967424 ━━━━━━━━━━━━━━━━━━━━ 6s 0us/step
    101/101 ━━━━━━━━━━━━━━━━━━━━ 165s 1s/step - loss: 0.5308
    training: 1111 ms/step
    101/101 ━━━━━━━━━━━━━━━━━━━━ 33s 269ms/step
    inferencing: 267 ms/step
  • batch_norm_op_jax_before (HEAD -> 9620d23)
    102967424/102967424 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step
    101/101 ━━━━━━━━━━━━━━━━━━━━ 202s 1s/step - loss: 0.5062
    training: 1076 ms/step
    101/101 ━━━━━━━━━━━━━━━━━━━━ 41s 373ms/step
    inferencing: 255 ms/step

  • batch_norm_op_oom_torch_after (HEAD -> 936dd13)
    101/101 ━━━━━━━━━━━━━━━━━━━━ 108s 1s/step - loss: 0.6557
    414.0576171875
  • batch_norm_op_oom_torch_before (HEAD -> 9620d23)
    101/101 ━━━━━━━━━━━━━━━━━━━━ 109s 1s/step - loss: 0.5071
    414.0576171875

  • batch_norm_op_torch_after (HEAD -> 936dd13)
    101/101 ━━━━━━━━━━━━━━━━━━━━ 71s 684ms/step - loss: 0.3952
    training: 683 ms/step
    101/101 ━━━━━━━━━━━━━━━━━━━━ 19s 189ms/step
    inferencing: 189 ms/step
  • batch_norm_op_torch_before (HEAD -> 9620d23)
    101/101 ━━━━━━━━━━━━━━━━━━━━ 93s 545ms/step - loss: 0.5390
    training: 543 ms/step
    101/101 ━━━━━━━━━━━━━━━━━━━━ 18s 132ms/step
    inferencing: 132 ms/step

@jackd (Contributor, Author) commented Jan 15, 2024

@dugujiujian1999 maybe I'm missing something, but why do runs with lower total times have higher times per step? e.g. batch_norm_op_torch training shows 71s at 684 ms/step in one run and 93s at 545 ms/step in the other. Surely total time and per-step time should be proportional?
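
As a quick sanity check on those numbers (assuming the 101/101 in the progress bars means 101 steps per epoch):

    71 / 101 * 1000  # ~703 ms/step, close to the reported 684 ms/step
    93 / 101 * 1000  # ~921 ms/step, well above the reported 545 ms/step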

@dugujiujian1999 (Contributor) commented Jan 15, 2024

@jackd I don't know. It takes less time after the patch. I used the code here:
https://github.com/haifeng-jin/keras-benchmarking/tree/main/prs

@sachinprasadhs (Collaborator):
@jackd, can you please rebase the code to follow the latest code structure (e.g. /keras/src/..)?

@sachinprasadhs added stat:awaiting response from contributor and removed stat:awaiting keras-eng (Awaiting response from Keras engineer) labels May 1, 2024
@jackd (Contributor, Author) commented May 1, 2024

CBF at this point, someone else can take over if they'd like.

@jackd closed this May 1, 2024
PR Queue automation moved this from Assigned Reviewer to Closed/Rejected May 1, 2024