Add support for non-incremental decoding + unit test #973
base: main
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/973
CI status as of commit 25b5e63 (merge base dc2b991): no failures, 1 pending.
The change looks good, with just some small comments around testing. Do the tests verify this works with bs > 1 too?
Added this test!
    curr_input_pos = input_pos[curr_pos].unsqueeze(0)
else:
    curr_input_pos = input_pos[: curr_pos + 1]
tokens = generated_tokens.clone()
I'm not entirely sure I need this clone...
cc someone who is better at pytorch than me
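A quick sketch of what the clone buys (a minimal illustration; the variable names mirror the snippet above, but the later in-place mutation is an assumption about how the loop uses `tokens`): without `.clone()`, `tokens` would be the same tensor object as `generated_tokens`, so an in-place write to one would show up in the other.

```python
import torch

# Illustrative only: `generated_tokens` stands in for the running
# output buffer in the generation loop.
generated_tokens = torch.tensor([[1, 2, 3]])

aliased = generated_tokens            # no copy: same object, same storage
copied = generated_tokens.clone()     # independent storage

aliased[0, 0] = 99                    # in-place write mutates generated_tokens too
print(generated_tokens[0, 0].item())  # 99
print(copied[0, 0].item())            # 1: the clone is unaffected
```

So the clone is only needed if `tokens` is mutated in place afterwards; if it is only read, the copy could be dropped.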
As discussed offline, my primary question is whether non-incremental decoding needs to be integrated into the generation utils, as opposed to just living in a custom test function. The main intent is to test the correctness of KV caching, which makes sense, but I'm not sure the testing needs to intrude into the generation utils.
My concern with re-creating an entire non-incremental decoding generation function in tests/ is that any change to the actual generate function would also need to be incorporated into the test-only version, or the results would diverge. That makes generate more confusing to work on down the road.
Context
As documented in #959, enabling the key-value cache for a model in generation gave reasonable results, whereas passing in a model without the key-value cache enabled produced junk.
The cause: Although we originally supported both incremental and non-incremental decoding, support for the latter was lost in a refactor at some point. As a result, the generate function always passed in ONLY a single token at a time; without a key-value cache to keep the earlier context in memory, the model was essentially producing random output.
Didn't we have a test for this?: We technically had a test for the decoder here, but it only compares a single forward pass. Values won't diverge until the key-value cache actually comes into play, so we had to add a generation-level test for this scenario.
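The failure mode above can be sketched with a toy model (a minimal illustration; none of these names are torchtune's actual APIs). A stateless toy model stands in for "KV cache disabled": if the generation loop feeds it only the newest token each step, all earlier context is silently lost, which is exactly the bug.

```python
import torch


class ToyModel(torch.nn.Module):
    """Stateless toy decoder: logits favor (sum of visible tokens) mod vocab."""

    vocab = 16

    def forward(self, x):
        bsz, seq_len = x.shape
        tgt = x.sum(dim=1) % self.vocab
        logits = torch.zeros(bsz, seq_len, self.vocab)
        logits[torch.arange(bsz), -1, tgt] = 1.0
        return logits


def decode(model, prompt, max_new_tokens, full_context):
    tokens = prompt.clone()
    for _ in range(max_new_tokens):
        # Non-incremental decoding must re-feed the whole sequence each step.
        # Feeding only the last token is only valid when a KV cache holds the
        # rest of the context; with a cache-less model it drops the context.
        inp = tokens if full_context else tokens[:, -1:]
        logits = model(inp)
        next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_tok], dim=1)
    return tokens


model = ToyModel()
prompt = torch.tensor([[1, 2]])
print(decode(model, prompt, 3, full_context=True).tolist())   # [[1, 2, 3, 6, 12]]
print(decode(model, prompt, 3, full_context=False).tolist())  # [[1, 2, 2, 2, 2]]
```

With full context the outputs depend on the whole sequence; with last-token-only feeding and no cache, the model sees one token of context and the generations diverge immediately, matching the "junk output" symptom in #959.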
Changelog
Test plan