Code and Data for Scaling Relationship on Learning Mathematical Reasoning with Large Language Models

OFA-Sys/gsm8k-ScRel

Scaling Relationship on Learning Mathematical Reasoning with Large Language Models

This repository contains the code and data for reproducing the results of two papers: "Scaling Relationship on Learning Mathematical Reasoning with Large Language Models" and "Query and Response Augmentation Cannot Help Out-of-domain Math Reasoning Generalization".

  • [2023.10] We have a new paper that investigates the scaling of in-domain and out-of-domain generalization on augmented math problems.
  • [2023.9] Paper updated with more details on 65B and 70B models.
| Setting | 7B | 7B-2 | 13B | 13B-2 | 33B | 65B | 70B-2 |
| --- | --- | --- | --- | --- | --- | --- | --- |
| ICL-8shot | 11.0/18.1 | 14.6/- | 17.8/29.3 | 28.7/- | 35.6/53.1 | 50.9/69.7 | 56.8/- |
| SFT | 35.9/48.7 | 41.6/55.4 | 43.0/55.2 | 50.0/61.7 | 54.6/- | 59.3/- | 63.2/- |
| RFT k=100 | 41.7/52.7 | 47.5/58.7 | 49.1/59.9 | 54.8/65.4 | 54.5/- | - | - |
| RFT-U13B | 49.3/61.8 | 50.3/65.6 | 52.1/66.2 | 55.4/69.1 | 56.5/- | 59.0/- | 62.3/- |
| RFT-U33B | 49.1/61.6 | 51.2/64.1 | 51.4/66.3 | 55.3/69.1 | 57.9/- | 59.7/- | 64.8/- |

Metrics are maj1@1 and maj1@100.
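As an illustration of the metric, here is a minimal sketch of majority voting, assuming maj1@K means the top-1 accuracy of the most frequent final answer among K sampled responses per problem. The function names and toy data are hypothetical and are not taken from eval.py.

```python
from collections import Counter


def majority_answer(answers):
    """Return the most frequent non-None answer, or None if there is none."""
    votes = Counter(a for a in answers if a is not None)
    if not votes:
        return None
    return votes.most_common(1)[0][0]


def maj1_at_k(sampled_answers, gold_answers, k):
    """Fraction of problems whose majority answer over the first k samples equals gold."""
    correct = 0
    for answers, gold in zip(sampled_answers, gold_answers):
        if majority_answer(answers[:k]) == gold:
            correct += 1
    return correct / len(gold_answers)


# Toy data (hypothetical): two problems, three sampled final answers each.
samples = [["18", "18", "20"], ["5", "7", "7"]]
gold = ["18", "5"]

print(maj1_at_k(samples, gold, 1))
print(maj1_at_k(samples, gold, 3))
```

With K=1 the vote reduces to checking a single answer per problem, so the two numbers in each table cell can be computed by the same routine with different K.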

Findings from the paper

[Figures: fig1, fig2]

SFT Training

If you cannot reproduce our results, please try Transformers <= 4.29 and run the test with a batch size of 1.

Use train_xb.sh for fine-tuning LLaMA and LLaMA-2.

bash train_xb.sh ./data/train_use.jsonl SAVE_PATH 3

RFT Inference

LLaMA 7B / 13B

bash group_sample_7b_13b.sh SAVE_PATH

LLaMA 30B

bash group_sample_30b.sh SAVE_PATH

Filter reasoning paths

python collect_rejection_sampling.py
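The filtering idea can be sketched as follows. This is not the code of collect_rejection_sampling.py; it is a minimal illustration that assumes GSM8K-style reasoning paths, where equations appear inside `<<...>>` and the final answer follows `####`. Keeping one path per distinct equation list is also an assumption about the deduplication criterion, and all function names are hypothetical.

```python
import re


def final_answer(path):
    """Extract the text after the last '####' marker, or None if it is absent."""
    if "####" not in path:
        return None
    return path.rsplit("####", 1)[1].strip().replace(",", "")


def equation_key(path):
    """Build a hashable key from the <<...>> equations, ignoring whitespace."""
    equations = re.findall(r"<<(.*?)>>", path)
    return tuple(eq.replace(" ", "") for eq in equations)


def filter_paths(sampled_paths, gold_answer):
    """Keep correct paths, one per distinct equation list, in first-seen order."""
    kept, seen = [], set()
    for path in sampled_paths:
        if final_answer(path) != gold_answer:
            continue  # reject paths that reach a wrong answer
        key = equation_key(path)
        if key in seen:
            continue  # reject paths that repeat an already kept calculation
        seen.add(key)
        kept.append(path)
    return kept


# Toy samples (hypothetical) for a question whose gold answer is "7".
paths = [
    "She has 3+4=<<3+4=7>>7 apples. #### 7",
    "Adding gives 3 + 4 = <<3 + 4 = 7>>7. #### 7",  # same equations as above
    "She has 4+3=<<4+3=7>>7 apples. #### 7",        # different equation list
    "She has 3*4=<<3*4=12>>12 apples. #### 12",     # wrong answer
]
kept = filter_paths(paths, "7")
print(len(kept))
```

In this toy run the second path is dropped as a duplicate calculation and the fourth as incorrect, leaving two paths for fine-tuning.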

RFT Training

To run RFT on samples generated by LLaMA-7B/7B-2/13B/13B-2/33B with k=100:

bash train_xb.sh ./data/rft/llama_yb.jsonl SAVE_PATH 3

To run RFT using U13B:

bash train_xb.sh ./data/rft/u13b.jsonl SAVE_PATH 3

To run RFT using U33B:

bash train_xb.sh ./data/rft/u33b.jsonl SAVE_PATH 3

Evaluation

We use greedy decoding on the test set.

To evaluate 7B/13B models:

bash test_7b_13b.sh SAVE_PATH

To evaluate 30B models:

bash single_test_30b.sh SAVE_PATH 0 ./data/test_jsonl.sh

To evaluate 65B / 70B models:

bash single_test_65b.sh SAVE_PATH 0,1 ./data/test_jsonl.sh

Use eval.py to obtain the scores; it also supports maj1@K.

GPU Usage

| | 7B / 13B | 33B | 65B / 70B |
| --- | --- | --- | --- |
| SFT / RFT | 8 | 16 | 32 |
| Minimal Inference | 1 | 1 | 2 |
| Group Inference | 8 | 8 | 8 |

Checkpoints

| | 7B | 7B2 | 13B | 13B2 | 33B |
| --- | --- | --- | --- | --- | --- |
| RFT k = 100 | OFA-Sys/gsm8k-rft-llama7b-sample100 | | | | |
| RFT U13B | OFA-Sys/gsm8k-rft-llama7b-u13b | OFA-Sys/gsm8k-rft-llama7b2-u13b | OFA-Sys/gsm8k-rft-llama13b-u13b | OFA-Sys/gsm8k-rft-llama13b2-u13b | |
| RFT U33B | | | | | OFA-Sys/gsm8k-rft-llama33b-u33b |

Query and Response Augmentation Cannot Help Out-of-domain Math Reasoning Generalization

Model Details

MuggleMATH is based on the LLaMA-2 models and is fully fine-tuned on the AugGSM8K and AugMATH datasets (https://github.com/OFA-Sys/gsm8k-ScRel/tree/main/data/MuggleMATH).

Model Usage

Prompting template:

'''
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
'''

We recommend using vllm to accelerate inference.
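A minimal sketch of filling the template for a single question. The template string is the one given above; the `build_prompt` helper and the sample question are hypothetical names introduced for illustration.

```python
# Template text copied from the README; adjacent string literals are concatenated.
PROMPT_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:"
)


def build_prompt(question):
    """Insert the question into the instruction slot of the template."""
    return PROMPT_TEMPLATE.format(instruction=question)


prompt = build_prompt("What is 12 * 7?")
print(prompt)
```

The resulting string is what would be passed to the model as the generation prompt, whether through vllm or another inference backend.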

Experiment

| Model | GSM8K | MATH |
| --- | --- | --- |
| MuggleMATH-7B | 69.8 | 25.8 |
| MuggleMATH-13B | 74.3 | 30.7 |
| MuggleMATH-70B | 82.5 | 42.1 |

Checkpoints

| Model | Checkpoints |
| --- | --- |
| MuggleMATH-7B | https://huggingface.co/OFA-Sys/MuggleMath_7B |
| MuggleMATH-13B | https://huggingface.co/OFA-Sys/MuggleMath_13B |
| MuggleMATH-70B | https://huggingface.co/OFA-Sys/MuggleMath_70B |

Citation

@misc{yuan2023scaling,
      title={Scaling Relationship on Learning Mathematical Reasoning with Large Language Models}, 
      author={Zheng Yuan and Hongyi Yuan and Chengpeng Li and Guanting Dong and Keming Lu and Chuanqi Tan and Chang Zhou and Jingren Zhou},
      year={2023},
      eprint={2308.01825},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
@misc{li2023query,
      title={Query and Response Augmentation Cannot Help Out-of-domain Math Reasoning Generalization}, 
      author={Chengpeng Li and Zheng Yuan and Guanting Dong and Keming Lu and Jiancan Wu and Chuanqi Tan and Xiang Wang and Chang Zhou},
      year={2023},
      eprint={2310.05506},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}