Free generation utils in Eleuther #975

Merged
merged 13 commits into pytorch:main from add-free-gen on May 19, 2024

Conversation

joecummings (Contributor) commented May 13, 2024

Context

In order to support free-generation tasks like GSM8K, which is part of the suite comprising the OpenLLM Leaderboard, we need to implement our own _model_generate function. This PR adds that capability and tests it on GSM8K.

Why do you set up caches before each generate call? Doesn't that somewhat defeat the purpose? Great question, dear reader! A totally valid approach would be to call setup_caches right after model instantiation and then call reset_caches after each batched generate call. However, this would also mean that the default model would have kv_caches enabled, and these would be used for the normal _model_call. This is overkill and also requires some additional logic, because if the batch size changes (common for the last batch in a dataset), we have to call setup_caches again. Therefore, it's simpler in code to just call setup_caches before each generate call. This still provides a performance benefit, as we are not recomputing the KV values across every sequence within a batch.
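
A minimal sketch of that pattern, assuming a torchtune-style model that exposes setup_caches (the helper name and surrounding loop below are illustrative, not the exact recipe code):

```python
import torch

def generate_for_batch(model, prompts: torch.Tensor, dtype: torch.dtype):
    """Illustrative sketch: rebuild KV caches for each batched generate call."""
    # The last batch of a dataset is often smaller, so sizing the caches to the
    # current batch here avoids extra shape-tracking logic, and the model used
    # for the normal _model_call never has caches enabled at all.
    curr_batch_size = prompts.size(0)
    with prompts.device:
        model.setup_caches(batch_size=curr_batch_size, dtype=dtype)
    # ... run incremental decoding; the caches are only reused within this call ...
```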

Why is there a discrepancy between Eleuther's Phi3 eval on GSM8K and ours? We already acknowledged that there are slight computational differences between our Phi3 implementation and HF's due to the split QKV dimension. I think this is the closest we can get, but I'll check with the Eleuther team that it looks reasonable. If this is a common issue, we should maybe consider defaulting to a delta calculation between the original model and the user's finetuned model. This would be the most accurate way to represent how the finetuning went and wouldn't confuse users as to why the numbers don't match Eleuther or the OpenLLM Leaderboard. The downside is that this would take twice as long. Curious to hear thoughts from @ebsmothers and @kartikayk.
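
Concretely, the delta idea could look something like the hypothetical sketch below; run_eval is a stand-in for whatever wraps the Eleuther harness, and none of these names are existing torchtune APIs:

```python
from typing import Callable, Dict

def eval_delta(
    run_eval: Callable[[str, str], Dict[str, float]],
    base_ckpt: str,
    finetuned_ckpt: str,
    task: str = "gsm8k",
) -> Dict[str, float]:
    """Illustrative sketch: report finetuned-minus-base scores so that small
    implementation differences (e.g. the split QKV dim) cancel out."""
    base = run_eval(base_ckpt, task)    # e.g. {"exact_match,strict-match": 0.771}
    tuned = run_eval(finetuned_ckpt, task)
    return {metric: tuned[metric] - base[metric] for metric in tuned if metric in base}
```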

Changelog

  • Change default seed
  • Add _model_generate function
  • Add tok_batch_encode function, which _model_generate needs in order to handle batched generate requests (a rough sketch follows this list)
  • Add a batch_size param to the config
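
For reference, a rough sketch of what tok_batch_encode has to do for batched generate requests; the left-padding convention mirrors lm-eval's HF model, and the tokenizer.encode signature here is assumed rather than quoted from the recipe:

```python
from typing import List, Tuple

import torch

def tok_batch_encode(
    tokenizer, texts: List[str], pad_id: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Illustrative sketch: encode a batch of prompts and left-pad to one length."""
    encoded = [tokenizer.encode(t, add_bos=True, add_eos=False) for t in texts]
    max_len = max(len(toks) for toks in encoded)
    input_ids = torch.full((len(encoded), max_len), pad_id, dtype=torch.long)
    attn_mask = torch.zeros((len(encoded), max_len), dtype=torch.long)
    for i, toks in enumerate(encoded):
        # Left-pad so every prompt ends at the same position for generation.
        input_ids[i, max_len - len(toks):] = torch.tensor(toks, dtype=torch.long)
        attn_mask[i, max_len - len(toks):] = 1
    return input_ids, attn_mask
```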

Test plan

Eleuther Eval results (in 27m 28s):

(joe-torchtune) [jrcummings@devgpu012.cln5 ~/projects/joe-torchtune (add-free-gen)]$ lm_eval --model hf     --model_args pretrained=microsoft/Phi-3-mini-4k-instruct,trust_remote_code=True     --tasks gsm8k     --device cuda:0     --batch_size 8
2024-05-13:16:43:32,891 INFO     [__main__.py:251] Verbosity set to INFO
2024-05-13:16:43:37,247 INFO     [__main__.py:335] Selected Tasks: ['gsm8k']
2024-05-13:16:43:37,247 INFO     [__main__.py:336] Loading selected tasks...
2024-05-13:16:43:37,249 INFO     [evaluator.py:131] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234
2024-05-13:16:43:39,230 INFO     [huggingface.py:162] Using device 'cuda:0'
/home/jrcummings/.conda/envs/joe-torchtune/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
2024-05-13:16:43:40,018 WARNING  [modeling_phi3.py:62] `flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
2024-05-13:16:43:40,018 WARNING  [modeling_phi3.py:66] Current `flash-attenton` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:12<00:00,  6.28s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2024-05-13:16:43:56,279 INFO     [task.py:395] Building contexts for gsm8k on rank 0...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:07<00:00, 166.52it/s]
2024-05-13:16:44:04,230 INFO     [evaluator.py:362] Running generate_until requests
2024-05-13:16:44:06,714 WARNING  [logging.py:329] You are not running the flash-attention implementation, expect numerical differences.
Running generate_until requests: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [27:28<00:00,  1.25s/it]
hf (pretrained=microsoft/Phi-3-mini-4k-instruct,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 8
|Tasks|Version|     Filter     |n-shot|  Metric   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|-----:|---|-----:|
|gsm8k|      3|strict-match    |     5|exact_match|0.7710|±  |0.0116|
|     |       |flexible-extract|     5|exact_match|0.7748|±  |0.0115|

Our results (in 25m 10s):

(joe-torchtune) [jrcummings@devgpu012.cln5 ~/projects/joe-torchtune (add-free-gen)]$ CUDA_VISIBLE_DEVICES=7 tune run eleuther_eval --config eleuther_evaluation device=cuda batch_size=8
2024-05-17:07:41:20,037 INFO     [_utils.py:34] Running EleutherEvalRecipe with resolved config:

batch_size: 8
checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: ./phi3
  checkpoint_files:
  - model-00001-of-00002.safetensors
  - model-00002-of-00002.safetensors
  model_type: PHI3_MINI
  output_dir: lora-phi3-math
device: cuda
dtype: bf16
limit: null
max_seq_length: 4096
model:
  _component_: torchtune.models.phi3.phi3_mini
quantizer: null
resume_from_checkpoint: false
seed: 1234
tasks:
- gsm8k
tokenizer:
  _component_: torchtune.models.phi3.phi3_mini_tokenizer
  path: ./phi3/tokenizer.model

2024-05-17:07:41:28,693 DEBUG    [seed.py:59] Setting manual seed to local seed 1234. Local seed is seed + rank = 1234 + 0
2024-05-17:07:41:31,719 INFO     [eleuther_eval.py:221] Model is initialized with precision torch.bfloat16.
2024-05-17:07:41:31,734 INFO     [eleuther_eval.py:205] Tokenizer is initialized from file.
2024-05-17:07:41:32,203 INFO     [huggingface.py:162] Using device 'cuda:0'
/home/jrcummings/.conda/envs/joe-torchtune/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
2024-05-17:07:41:43,105 INFO     [eleuther_eval.py:244] Running evaluation on ['gsm8k'] tasks.
2024-05-17:07:41:43,106 INFO     [task.py:395] Building contexts for gsm8k on rank 0...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:07<00:00, 166.27it/s]
2024-05-17:07:41:51,071 INFO     [evaluator.py:362] Running generate_until requests
Running generate_until requests: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [25:20<00:00,  1.15s/it]
2024-05-17:08:07:22,979 INFO     [eleuther_eval.py:251] Eval completed in 1550.94 seconds.
2024-05-17:08:07:22,979 INFO     [eleuther_eval.py:253] gsm8k: {'exact_match,strict-match': 0.7763457164518575, 'exact_match_stderr,strict-match': 0.011477795578836127, 'exact_match,flexible-extract': 0.7824109173616376, 'exact_match_stderr,flexible-extract': 0.011365231761189577, 'alias': 'gsm8k'}

pytorch-bot commented May 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/975

✅ No Failures

As of commit 67b3fcf with merge base f3611e5:
💚 Looks good so far! There are no failures yet. 💚

@facebook-github-bot added the CLA Signed label on May 13, 2024
@codecov-commenter

Codecov Report

Attention: Patch coverage is 0%, with 32 lines in your changes missing coverage. Please review.

Project coverage is 26.74%. Comparing base (cb8e65a) to head (444448e).
Report is 2 commits behind head on main.

|Files                   |Patch %|Lines        |
|------------------------|------:|-------------|
|recipes/eleuther_eval.py|  0.00%|32 Missing ⚠️|
Additional details and impacted files
@@             Coverage Diff             @@
##             main     #975       +/-   ##
===========================================
- Coverage   67.10%   26.74%   -40.36%     
===========================================
  Files         174      174               
  Lines        7423     7451       +28     
===========================================
- Hits         4981     1993     -2988     
- Misses       2442     5458     +3016     

@joecummings joecummings changed the title [WIP] Free generation utils in Eleuther Free generation utils in Eleuther May 14, 2024
@joecummings joecummings marked this pull request as ready for review May 14, 2024 15:54
with context.device:
    self._model.setup_caches(batch_size=curr_batch_size, dtype=self._dtype)

temperature = generation_kwargs.get("temperature", 0.0)
Contributor:

Is this a reasonable default for temperature?

joecummings (Author):

This is the most common default for temperature because 0.0 means greedy decoding: it always takes the highest-probability token.
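
For illustration, here is a minimal sampling helper showing that convention (a hypothetical sketch, not the actual torchtune generation code): temperature 0.0 falls back to greedy argmax, and anything higher scales the logits before sampling.

```python
import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor:
    """Illustrative sketch: greedy when temperature == 0.0, otherwise sample."""
    if temperature == 0.0:
        # Greedy decoding: always take the highest-probability token.
        return logits.argmax(dim=-1)
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)
```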

@@ -53,13 +54,19 @@ def __init__(
         *,
         device: torch.device,
         max_seq_length: int = 4096,
-        batch_size: int = 32,
+        batch_size: int = 8,
Contributor:

Why this change? Is it for memory reasons?

joecummings (Author):

yup

@@ -31,6 +31,7 @@ seed: 217
 tasks: ["truthfulqa_mc2"]
 limit: null
 max_seq_length: 4096
+batch_size: 2
Contributor:

?

joecummings (Author):

Yeah, sorry, this should match the default.

@joecummings merged commit 4203311 into pytorch:main on May 19, 2024
29 checks passed
@joecummings deleted the add-free-gen branch on May 19, 2024 at 21:37