
[BUG] When initializing model_engine, if an mpu is specified, it can lead to an excessively large checkpoint size, and the checkpoint may not be convertible through the zero_to_fp32.py script. #5514

Open · Kwen-Chen opened this issue May 9, 2024 · 0 comments
Labels: bug (Something isn't working), training

@Kwen-Chen (Contributor)

Describe the bug
When initializing model_engine, if an mpu (model parallelism unit) is specified, it can lead to an excessively large checkpoint size, and the checkpoint may not be convertible through the zero_to_fp32.py script.

To Reproduce
Steps to reproduce the behavior:

  1. The 'mpu' used is the provided 'mpu' module (a minimal sketch of the interface it is assumed to expose is given after these steps).
  2. Initialize the 'mpu', pass it to deepspeed.initialize() when creating 'model_engine', and then save a checkpoint directly, without performing any other operations. The issue can be reproduced with the following code:
import argparse
import deepspeed
import torch
from transformers import AutoModelForCausalLM

import mpu  # the Megatron-style model-parallel utilities referenced in step 1

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="meta-llama/Llama-2-7b-hf")
parser.add_argument("--local_rank", type=int, default=-1,
                    help="Reserved for deepspeed framework")

parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
deepspeed.init_distributed()


def main(args):

    # Sequence parallelism of size 4 across the 4 GPUs.
    mpu.initialize_model_parallel(sequence_parallel_size=4)
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.bfloat16,
        _attn_implementation="flash_attention_2",
    )

    model.gradient_checkpointing_enable()

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    # Passing the mpu here is what triggers the oversized checkpoints.
    model_engine, _, _, _ = deepspeed.initialize(args=args,
                                                 model=model,
                                                 mpu=mpu,
                                                 optimizer=None,
                                                 model_parameters=model_parameters)

    model_engine.save_checkpoint(f"checkpoints", tag=f"checkpoint-{0}", client_state={})

if __name__ == "__main__":
    main(args)
  3. The launch script is as follows:
deepspeed --include localhost:0,1,2,3 \
    --master_port=25640 \
    train_pt.py \
    --model ~/work/Llama-2-7b-hf \
    --deepspeed --deepspeed_config config/deepspeed.json
  4. With the above code, the model's checkpoints are saved as very large files, as shown below for Llama-2-7B:
    [screenshot of the saved checkpoint files]
    The saved size is exactly four times the normal size (i.e. a factor of sequence_parallel_size); a quick numeric check of this factor follows these steps.
  5. When I run python zero_to_fp32.py . model.bin (from inside the checkpoints directory), the following error occurs:
[2024-05-08 14:16:02,836] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Processing zero checkpoint './checkpoint-0'
Detected checkpoint of type zero stage 3, world_size: 4
Parsing checkpoint created by deepspeed==0.14.0
Traceback (most recent call last):
  File "/u01/chenkun/work/ring_dp_train/checkpoints/zero_to_fp32.py", line 601, in <module>
    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
  File "/u01/chenkun/work/ring_dp_train/checkpoints/zero_to_fp32.py", line 536, in convert_zero_checkpoint_to_fp32_state_dict
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
  File "/u01/chenkun/work/ring_dp_train/checkpoints/zero_to_fp32.py", line 521, in get_fp32_state_dict_from_zero_checkpoint
    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
  File "/u01/chenkun/work/ring_dp_train/checkpoints/zero_to_fp32.py", line 217, in _get_fp32_state_dict_from_zero_checkpoint
    return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
  File "/u01/chenkun/work/ring_dp_train/checkpoints/zero_to_fp32.py", line 464, in _get_fp32_state_dict_from_zero3_checkpoint
    _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
  File "/u01/chenkun/work/ring_dp_train/checkpoints/zero_to_fp32.py", line 446, in _zero3_merge_trainable_params
    raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
ValueError: consumed 6738415616 numels out of 26953662464 - something is wrong
  6. Although this checkpoint cannot be converted to model.bin, it can still be loaded by model_engine.load_checkpoint().
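
A quick numeric check ties the ValueError in step 5 to the four-fold size observation in step 4; the numbers below are copied from the traceback and the repro code above, and are consistent with the four ranks together saving four full copies of the parameters instead of one:

# Numbers taken verbatim from the ValueError and the repro above.
llama2_7b_numel = 6_738_415_616         # "consumed ... numels" in the traceback
available_numel = 26_953_662_464        # "out of ... numels" in the traceback
sequence_parallel_size = 4              # from mpu.initialize_model_parallel(...)

assert available_numel == sequence_parallel_size * llama2_7b_numel
print(available_numel / llama2_7b_numel)  # -> 4.0, matching the 4x-too-large files on disk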
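
For reference, below is a minimal, hypothetical sketch of the kind of Megatron-style 'mpu' module assumed in step 1. The issue links its own mpu code, which is not reproduced here, so the group layout and the accessor names below are illustrative assumptions, not the exact interface DeepSpeed requires:

# Hypothetical stand-in for the linked mpu module (illustrative only).
# All process-group bookkeeping uses plain torch.distributed calls.
import torch.distributed as dist

_SEQUENCE_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP = None
_MODEL_PARALLEL_GROUP = None

def initialize_model_parallel(sequence_parallel_size=1):
    """Build sequence-, data- and (trivial) model-parallel process groups."""
    global _SEQUENCE_PARALLEL_GROUP, _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % sequence_parallel_size == 0

    # Sequence-parallel groups: blocks of consecutive ranks.
    for i in range(world_size // sequence_parallel_size):
        ranks = list(range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            _SEQUENCE_PARALLEL_GROUP = group

    # Data-parallel groups: ranks holding the same sequence shard. With 4 GPUs
    # and sequence_parallel_size=4, every data-parallel group has size 1.
    for i in range(sequence_parallel_size):
        ranks = list(range(i, world_size, sequence_parallel_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            _DATA_PARALLEL_GROUP = group

    # Model-parallel groups are trivial (size 1): no tensor parallelism here.
    for r in range(world_size):
        group = dist.new_group([r])
        if r == rank:
            _MODEL_PARALLEL_GROUP = group

def get_sequence_parallel_group():
    return _SEQUENCE_PARALLEL_GROUP

def get_sequence_parallel_world_size():
    return dist.get_world_size(group=_SEQUENCE_PARALLEL_GROUP)

def get_data_parallel_group():
    return _DATA_PARALLEL_GROUP

def get_data_parallel_world_size():
    return dist.get_world_size(group=_DATA_PARALLEL_GROUP)

def get_data_parallel_rank():
    return dist.get_rank(group=_DATA_PARALLEL_GROUP)

def get_model_parallel_group():
    return _MODEL_PARALLEL_GROUP

def get_model_parallel_world_size():
    return 1

def get_model_parallel_rank():
    return 0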

Expected behavior
The checkpoint is saved at the normal size and can be converted to model.bin by the zero_to_fp32.py script.
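
For reference, the expected conversion can also be sketched in Python using the same helper that appears in the traceback above (get_fp32_state_dict_from_zero_checkpoint); the directory and tag names are taken from the repro code:

# Expected conversion path, using the helper visible in the traceback above.
# "checkpoints" / "checkpoint-0" match the save_checkpoint() call in the repro.
import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Merge the ZeRO-3 shards into a single fp32 state dict; with the bug present
# this raises the ValueError shown above instead of returning.
state_dict = get_fp32_state_dict_from_zero_checkpoint("checkpoints", tag="checkpoint-0")

# Save as a regular PyTorch checkpoint, which is what
# "python zero_to_fp32.py . model.bin" does under the hood.
torch.save(state_dict, "model.bin")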

ds_report output
[screenshot of ds_report output]

System info (please complete the following information):

  • OS: Ubuntu 22.04
  • GPU count and types: one machine with 4x A100s
  • Python version: 3.10

Launcher context

deepspeed --include localhost:0,1,2,3 \
    --master_port=25640 \
    train_pt.py \
    --model ~/work/Llama-2-7b-hf \
    --deepspeed --deepspeed_config config/deepspeed.json

Additional context
The DeepSpeed config is as follows:

{
  "bf16": {
    "enabled": true
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 2e-5,
      "warmup_max_lr": 2e-5,
      "warmup_num_steps": 0,
      "warmup_type": "linear"
    }
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 2e-5,
      "betas": [0.9, 0.95],
      "eps": 1e-8,
      "weight_decay": 0.1
    }
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "gradient_accumulation_steps": 1,
  "steps_per_print": 2000,
  "train_micro_batch_size_per_gpu": 1,
  "wall_clock_breakdown": false
}