Add support for non-incremental decoding + unit test #973

Open
joecummings wants to merge 4 commits into main

Conversation

joecummings (Contributor) commented May 13, 2024

Context

As documented in #959, enabling the key-value cache for a model during generation gave reasonable results, whereas running the same model without the key-value cache enabled produced junk.

The cause: Although we originally supported both incremental and non-incremental decoding, this was lost somewhere in a refactor. As a result, the generate function always passed in only a single token at a time as context, so without a key-value cache keeping the earlier tokens' key and value projections in memory, the model was essentially producing random output.

Didn't we have a test for this?: We technically had a test for the decoder here, but it only compares a single forward pass. Values won't diverge until we reach the region where the key-value cache is actually used, so we had to add a test for this scenario in generation.
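
For context, a minimal sketch of the two decoding modes (illustrative only, not the actual torchtune implementation; the forward(tokens, input_pos=...) signature and the variable names are assumptions):

import torch

def decode_step(model, generated_tokens, input_pos, curr_pos, incremental: bool):
    # Illustrative sketch: 'model' is assumed to accept (tokens, input_pos=...).
    if incremental:
        # Incremental decoding: feed only the newest token. The earlier context
        # must live in the model's key-value cache, otherwise it is simply lost.
        tokens = generated_tokens[:, curr_pos].unsqueeze(1)
        curr_input_pos = input_pos[curr_pos].unsqueeze(0)
    else:
        # Non-incremental decoding: re-feed the whole sequence so far, so no
        # key-value cache is needed (at the cost of recomputing attention).
        tokens = generated_tokens[:, : curr_pos + 1]
        curr_input_pos = input_pos[: curr_pos + 1]
    logits = model(tokens, input_pos=curr_input_pos)
    return logits[:, -1]  # logits for the position that predicts the next token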

Changelog

  • Added support for non-incremental decoding in utils.generate
  • Added a method to detect if kv-cache is enabled on the model (sketched below)
  • Added a generation test for a model with and without kv-cache
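
The detection helper mentioned above could look roughly like this (a sketch only, not the exact code in the PR; the kv_cache attribute name is an assumption about how the attention layers store their cache):

import torch.nn as nn

def kv_cache_is_enabled(model: nn.Module) -> bool:
    # Sketch: treat the cache as enabled if any submodule carries a
    # non-None 'kv_cache' attribute.
    return any(
        getattr(module, "kv_cache", None) is not None
        for module in model.modules()
    )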

Test plan

(joe-torchtune) [jrcummings@devgpu012.cln5 ~/projects/joe-torchtune (debug-generate-kv-cache)]$ pytest tests/torchtune/utils/test_generation.py
================================================================== test session starts ===================================================================
platform linux -- Python 3.11.9, pytest-8.2.0, pluggy-1.5.0
rootdir: /home/jrcummings/projects/joe-torchtune
configfile: pyproject.toml
plugins: integration-0.2.3, mock-3.14.0, cov-5.0.0
collected 8 items

tests/torchtune/utils/test_generation.py ........                                                                                                  [100%]

=================================================================== 8 passed in 1.40s ===================================================================
  1. Test generate recipe with model.setup_caches(batch_size=1)
(joe-torchtune) [jrcummings@devgpu012.cln5 ~/projects/joe-torchtune (debug-generate-kv-cache)]$ tune run generate --config generation prompt="Tell me a joke"
2024-05-13:10:49:30,034 INFO     [_utils.py:34] Running InferenceRecipe with resolved config:

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: ./phi3
  checkpoint_files:
  - model-00001-of-00002.safetensors
  - model-00002-of-00002.safetensors
  model_type: PHI3_MINI
  output_dir: ./
device: cuda
dtype: bf16
max_new_tokens: 256
model:
  _component_: torchtune.models.phi3.phi3_mini
prompt: Tell me a joke
quantizer: null
seed: 1234
temperature: 0.6
tokenizer:
  _component_: torchtune.models.phi3.phi3_mini_tokenizer
  path: ./phi3/tokenizer.model
top_k: null

2024-05-13:10:49:33,743 DEBUG    [seed.py:59] Setting manual seed to local seed 1234. Local seed is seed + rank = 1234 + 0
2024-05-13:10:49:36,605 INFO     [generate.py:76] Model is initialized with precision torch.bfloat16.
2024-05-13:10:49:49,057 INFO     [generate.py:126] Tell me a joke about a doctor.
 Why did the doctor take his medicine?

To feel better and to avoid a check-up!

This joke plays on the dual meaning of "check-up" (medical examination) and the idea that taking medicine helps a person feel better. It's a light-hearted way to poke fun at doctors and their profession.

However, it's important to remember that jokes about healthcare professionals should always be used in good-natured settings, as these individuals work hard to care for others. Here's a friendly doctor joke:

Why did the doctor bring a ladder to the hospital?

Because he wanted to reach new heights in patient care!

This joke is a playful take on the idea that doctors are always striving to improve and excel in their profession, without undermining the importance of their work. It's a lighthearted way to appreciate the dedication and hard work of medical professionals. Why was the computer cold at the hospital?

Because it left its Windows open!

This joke uses a pun on the Windows operating system (
2024-05-13:10:49:49,060 INFO     [generate.py:139] Time for inference: 12.17 sec total, 21.04 tokens/sec
2024-05-13:10:49:49,060 INFO     [generate.py:142] Bandwidth achieved: 196.78 GB/s
2024-05-13:10:49:49,060 INFO     [generate.py:143] Memory used: 9.39 GB
(joe-torchtune) [jrcummings@devgpu012.cln5 ~/projects/joe-torchtune (debug-generate-kv-cache)]$
  2. Run generate with L79-80 commented out
(joe-torchtune) [jrcummings@devgpu012.cln5 ~/projects/joe-torchtune (debug-generate-kv-cache)]$ tune run generate --config generation prompt="Tell me a joke"
2024-05-13:10:50:35,608 INFO     [_utils.py:34] Running InferenceRecipe with resolved config:

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: ./phi3
  checkpoint_files:
  - model-00001-of-00002.safetensors
  - model-00002-of-00002.safetensors
  model_type: PHI3_MINI
  output_dir: ./
device: cuda
dtype: bf16
max_new_tokens: 256
model:
  _component_: torchtune.models.phi3.phi3_mini
prompt: Tell me a joke
quantizer: null
seed: 1234
temperature: 0.6
tokenizer:
  _component_: torchtune.models.phi3.phi3_mini_tokenizer
  path: ./phi3/tokenizer.model
top_k: null

2024-05-13:10:50:40,237 DEBUG    [seed.py:59] Setting manual seed to local seed 1234. Local seed is seed + rank = 1234 + 0
2024-05-13:10:50:45,230 INFO     [generate.py:76] Model is initialized with precision torch.bfloat16.
2024-05-13:10:50:57,210 INFO     [generate.py:126] Tell me a joke about a doctor.
 Why did the doctor take his medicine? Because he heard it was the best prescription for a good sense of humor!
 I'm sorry, I'll try to come up with one more. How about: Why did the doctor carry a red pen? In case he needed to write out a prescription for humor!
 Here's a light-hearted one: Why did the doctor become a stand-up comedian? He wanted to make people laugh while they were waiting for their check-ups!
 And one more: Why did the doctor become a magician? Because he wanted to perform a real-life "Hocus Pocus" with his medical skills!
}
 Sure! Here's a doctor-related joke for you:
Why can't you give a doctor a rubber chicken? Because they're always in demand, but you can't trust them to stick around with a rubber chicken!

This joke plays on the double meaning of the phrase "in demand." In the context of a doctor's profession, it refers to their high level of demand due to their medical expertise and the essential nature of their services.
2024-05-13:10:50:57,212 INFO     [generate.py:139] Time for inference: 11.62 sec total, 22.03 tokens/sec
2024-05-13:10:50:57,212 INFO     [generate.py:142] Bandwidth achieved: 170.58 GB/s
2024-05-13:10:50:57,213 INFO     [generate.py:143] Memory used: 7.97 GB
(joe-torchtune) [jrcummings@devgpu012.cln5 ~/projects/joe-torchtune (debug-generate-kv-cache)]$


pytorch-bot bot commented May 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/973

⏳ No Failures, 1 Pending

As of commit 25b5e63 with merge base dc2b991:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label on May 13, 2024
pbontrager (Contributor) left a comment

The change looks good with just some small comments around testing. Do the tests verify this works with bs > 1 too?

recipes/generate.py (review comments resolved)
tests/torchtune/utils/test_generation.py (review comments resolved)
joecummings (Contributor, Author) replied:

The change looks good with just some small comments around testing. Do the tests verify this works with bs > 1 too?

Added this test!
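
(For reference, the batched parity check under discussion would look roughly like this. It is a sketch only: the small_model_factory fixture, the utils.generate keyword arguments, and the setup_caches dtype argument are assumptions, not the test that was actually committed.)

import torch
from torchtune import utils

def test_batched_generate_with_and_without_kv_cache(small_model_factory):
    # Hypothetical fixture that builds two identically initialized decoders.
    prompt = torch.randint(0, 100, (2, 8))  # batch size > 1

    model_no_cache = small_model_factory(seed=0)
    model_with_cache = small_model_factory(seed=0)
    model_with_cache.setup_caches(batch_size=2, dtype=torch.float32)

    torch.manual_seed(0)
    out_no_cache = utils.generate(model_no_cache, prompt, max_generated_tokens=10)
    torch.manual_seed(0)
    out_with_cache = utils.generate(model_with_cache, prompt, max_generated_tokens=10)

    # With identical seeds and identical logits, the sampled tokens should match.
    assert out_no_cache == out_with_cache

The excerpt below, from the PR's diff, is the code the next review comment refers to: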

    curr_input_pos = input_pos[curr_pos].unsqueeze(0)
else:
    curr_input_pos = input_pos[: curr_pos + 1]
    tokens = generated_tokens.clone()
joecummings (Contributor, Author) commented on the diff above:

I'm not entirely sure I need this clone...

cc someone who is better at pytorch than me
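
(Aside: the clone should only matter if the tensor handed to the model is later mutated in place while something downstream still holds a reference to it. A tiny illustration of that aliasing behaviour, not the PR's code:)

import torch

generated_tokens = torch.zeros(1, 4, dtype=torch.long)
alias = generated_tokens             # same underlying tensor object
snapshot = generated_tokens.clone()  # independent copy of the data

generated_tokens[0, 0] = 7           # in-place update, e.g. writing a new token
print(alias[0, 0].item())            # 7  -- the alias sees the update
print(snapshot[0, 0].item())         # 0  -- the clone is unaffected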

kartikayk (Contributor) left a comment

As discussed offline, my primary question here would be around the need to integrate non-incremental decoding into the generation utils as opposed to just having this as a custom test function. The main intent is to test the correctness of KV caching, which makes sense, but I'm not sure the testing needs to intrude into the generation utils.

joecummings (Contributor, Author) replied:

As discussed offline, my primary question here would be around the need to integrate non-incremental decoding into the generation utils as opposed to just having this as a custom test function. The main intent is to test the correctness of KV caching, which makes sense, but I'm not sure the testing needs to intrude into the generation utils.

My concern with just re-creating an entire non-incremental decoding generation function in tests/ is that any change to the actual generate function would need to be mirrored in the testing generation function, otherwise the results will diverge. This makes it more confusing to work on generate down the road.

Labels
CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Successfully merging this pull request may close these issues.

Generate with KV-cache enabled vs. not enabled gives different results
4 participants