Resuming from checkpoint runs into OOM #30822

Open
PKlumpp opened this issue May 15, 2024 · 2 comments
PKlumpp commented May 15, 2024

System Info

Using GPU in script: A100 80 GB; Driver Version: 550.54.15; CUDA-Version: 12.4
Using distributed or parallel setup: No

Who can help?

@ArthurZucker @muellerz @pacman100

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

This is the definition of my custom model:

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from transformers import LongformerConfig, LongformerModel

class CustomLongformer(LongformerModel):

    def __init__(self, config):
        super().__init__(config, add_pooling_layer=False)
        # Classification head on the CLS representation; in_features must equal config.hidden_size
        self.linear = nn.Linear(
            in_features=1024,
            out_features=47,  # number of classes
        )
        # Learned "folder" embedding (14 ids) added to the CLS representation
        self.custom_embeddding = nn.Embedding(
            num_embeddings=14,
            embedding_dim=1024,
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        global_attention_mask: Optional[torch.Tensor] = None,  # 1 at time-index 0, 0 elsewhere (CLS-only global attention)
        inputToEmbedding: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict:
        hidden = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        ).last_hidden_state
        embeddings = self.custom_embeddding(inputToEmbedding)
        hidden = hidden[:, 0, :]  # Select first token of each sequence (CLS token)
        hidden = hidden + embeddings  # LM output + folder embedding
        logits = self.linear(hidden)
        if labels is not None:
            loss = F.cross_entropy(
                input=logits,
                target=labels,
            )
            return {
                "loss": loss,
                "logits": logits,
            }
        return {
            "logits": logits,
        }

config = LongformerConfig(
    vocab_size=tokenizer.vocab_size,
    max_position_embeddings=1024 + 1,
    num_hidden_layers=24,
    num_attention_heads=16,
    intermediate_size=4096,
    hidden_size=1024,
    attention_window=256,
    bos_token_id=tokenizer.cls_token_id,
    eos_token_id=tokenizer.sep_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

model = CustomLongformer(config)
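For a sense of scale, the static memory this configuration needs before any activations can be estimated by hand. The sketch below is a rough count under stated assumptions: the standard Longformer layer layout (separate local and global Q/K/V projections, no pooler), `type_vocab_size=2` (assumed to be the `LongformerConfig` default), and about 16 bytes per parameter for fp32 weights, fp32 gradients and the two fp32 AdamW moment buffers (native `fp16=True` mixed precision keeps the weights in fp32). All function names are hypothetical, and the vocabulary size in the usage line is a placeholder because the tokenizer is not shown.

```python
def approx_longformer_params(
    vocab_size: int,
    hidden: int = 1024,
    layers: int = 24,
    intermediate: int = 4096,
    max_pos: int = 1025,
    type_vocab_size: int = 2,  # assumption: LongformerConfig default
) -> int:
    """Rough parameter count of a Longformer encoder without pooler (hypothetical helper)."""
    # Embeddings: word, position and token-type tables plus one LayerNorm (weight + bias)
    embeddings = (vocab_size + max_pos + type_vocab_size) * hidden + 2 * hidden
    # Self-attention: local Q/K/V and global Q/K/V projections, each with a bias
    attention = 6 * (hidden * hidden + hidden)
    # Attention output projection (with bias) plus its LayerNorm
    attention_out = hidden * hidden + hidden + 2 * hidden
    # Feed-forward up- and down-projection (with biases) plus LayerNorm
    ffn = 2 * hidden * intermediate + intermediate + hidden + 2 * hidden
    return embeddings + layers * (attention + attention_out + ffn)


def custom_head_params(hidden: int = 1024, num_labels: int = 47, num_extra: int = 14) -> int:
    # nn.Linear(hidden, num_labels) with bias plus nn.Embedding(num_extra, hidden)
    return hidden * num_labels + num_labels + num_extra * hidden


def static_training_bytes(n_params: int) -> int:
    # fp32 weights (4) + fp32 grads (4) + AdamW exp_avg and exp_avg_sq (4 + 4)
    return 16 * n_params


if __name__ == "__main__":
    vocab_size = 50_265  # placeholder: the real tokenizer is not shown
    n = approx_longformer_params(vocab_size) + custom_head_params()
    print(f"~{n / 1e6:.0f}M parameters, ~{static_training_bytes(n) / 2**30:.1f} GiB static")
```

When resuming, the saved optimizer state (the two AdamW buffers) is loaded from the checkpoint as well, so it is worth keeping this baseline in mind when reading the OOM message.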

I have a dataset with long texts which are chunked to samples of 1024 tokens (padded to said length if required).
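The chunking step is not shown in the issue, so the following is only a hypothetical sketch of what such preprocessing could look like. It assumes each chunk starts with its own CLS token, uses RoBERTa-style placeholder ids (`cls_id=0`, `pad_id=1`), and builds the three masks the model's `forward` expects.

```python
from typing import Dict, List


def chunk_and_pad(
    token_ids: List[int],
    chunk_len: int = 1024,
    cls_id: int = 0,  # placeholder id, depends on the tokenizer
    pad_id: int = 1,  # placeholder id, depends on the tokenizer
) -> List[Dict[str, List[int]]]:
    """Split one long token sequence into fixed-length samples (hypothetical helper)."""
    body_len = chunk_len - 1  # reserve position 0 for the CLS token
    samples = []
    for start in range(0, len(token_ids), body_len):
        chunk = [cls_id] + token_ids[start:start + body_len]
        n_pad = chunk_len - len(chunk)
        samples.append({
            "input_ids": chunk + [pad_id] * n_pad,
            # 1 for real tokens, 0 for padding
            "attention_mask": [1] * len(chunk) + [0] * n_pad,
            # global attention only on position 0 (CLS-only global attention)
            "global_attention_mask": [1] + [0] * (chunk_len - 1),
        })
    return samples
```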
These are my training arguments:

training_args = TrainingArguments(
    output_dir="some/path",
    do_train=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=30,
    per_device_eval_batch_size=30,
    gradient_accumulation_steps=2,
    eval_steps=4_000,
    learning_rate=1e-5,
    logging_steps=20,
    num_train_epochs=3,
    lr_scheduler_type="cosine",
    load_best_model_at_end=True,
    warmup_steps=100,
    save_strategy="steps",
    save_steps=4_000,
    save_total_limit=5,
    fp16=True,
    dataloader_num_workers=6,
    optim="adamw_torch",
    report_to="tensorboard",
)
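With these arguments each optimizer update sees 30 × 2 = 60 samples, and `save_steps`/`eval_steps` count optimizer updates rather than micro-batches. The helper below is a hypothetical, approximate way to estimate the resulting schedule for a given dataset size, assuming a single GPU and no `dataloader_drop_last`; it is not part of `transformers`.

```python
import math


def training_schedule(
    num_samples: int,
    per_device_batch: int = 30,
    grad_accum: int = 2,
    num_epochs: int = 3,
    save_steps: int = 4_000,
) -> dict:
    """Rough schedule for a single-GPU run (hypothetical helper)."""
    batches_per_epoch = math.ceil(num_samples / per_device_batch)  # last partial batch kept
    updates_per_epoch = max(batches_per_epoch // grad_accum, 1)  # one update per grad_accum batches
    total_updates = updates_per_epoch * num_epochs
    return {
        "effective_batch_size": per_device_batch * grad_accum,
        "total_updates": total_updates,
        # save_steps / eval_steps count optimizer updates, not micro-batches
        "checkpoints_written": total_updates // save_steps,
    }
```

For example, a hypothetical dataset of one million chunks gives 50,001 updates and 12 checkpoint saves, of which `save_total_limit=5` keeps the latest five on disk.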

Now all that's left is a trainer to start/continue training:

trainer = Trainer(
    model=model,
    data_collator=CustomDataCollator(),  # This only creates training batches (tried with some default as well)
    args=training_args,
    train_dataset=ds["training"],
    eval_dataset=ds["validation"],
)
trainer.train(resume_from_checkpoint=True)
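With `resume_from_checkpoint=True`, `Trainer` resumes from the most recent `checkpoint-<step>` folder in `output_dir`; the library exposes this lookup as `transformers.trainer_utils.get_last_checkpoint`. The following is a simplified stand-alone re-implementation (not the library's code) that can be used to check which checkpoint a run will pick up.

```python
import os
import re
from typing import Optional

_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_last_checkpoint(output_dir: str) -> Optional[str]:
    """Return the checkpoint-<step> subfolder with the highest step, or None."""
    candidates = []
    for name in os.listdir(output_dir):
        match = _CHECKPOINT_RE.match(name)
        # Only directories count; stray files with a matching name are ignored
        if match and os.path.isdir(os.path.join(output_dir, name)):
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return None
    _, best = max(candidates)  # highest step number wins
    return os.path.join(output_dir, best)
```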

Whenever I start training from scratch (not resuming from a checkpoint), everything works fine and I can train for days. But as soon as I resume from a checkpoint saved during training, I get an out-of-memory (OOM) error. The GPU is not used by any other task, and I am certain that no other process is leaking memory. Moreover, the OOM message says that it failed to allocate 120 MiB of GPU memory, while nvidia-smi reports that more than 7 GiB are still free.
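A failed small allocation while the GPU still reports free memory can point to fragmentation in PyTorch's caching allocator, which would be consistent with the `max_split_size_mb` workaround from the issue helping, though nothing here confirms it. One way to check is to compare what PyTorch has reserved with what live tensors actually occupy just before the failing step. A minimal sketch with hypothetical helper names; it only uses `torch.cuda.memory_allocated`, `memory_reserved` and `mem_get_info`, and `torch.cuda.memory_summary()` gives a fuller report.

```python
def reserved_but_unallocated_mib(allocated_bytes: int, reserved_bytes: int) -> float:
    """Memory held by PyTorch's caching allocator that no live tensor uses (MiB)."""
    return (reserved_bytes - allocated_bytes) / 2**20


def report_cuda_memory(device: int = 0) -> dict:
    import torch  # imported lazily so the helper above works without torch or a GPU

    if not torch.cuda.is_available():
        return {}
    allocated = torch.cuda.memory_allocated(device)
    reserved = torch.cuda.memory_reserved(device)
    free, total = torch.cuda.mem_get_info(device)  # driver-level view, like nvidia-smi
    return {
        "allocated_mib": allocated / 2**20,
        "reserved_mib": reserved / 2**20,
        "reserved_unallocated_mib": reserved_but_unallocated_mib(allocated, reserved),
        "driver_free_mib": free / 2**20,
        "total_mib": total / 2**20,
    }
```

A large `reserved_unallocated_mib` relative to the failed request suggests fragmentation; a small one suggests memory really is exhausted on resume.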

Expected behavior

Resuming from a checkpoint should not run into OOM errors if the model trained successfully before. I can get the expected behavior by setting os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128", but this is (a) only a hacky workaround and (b) results in much longer training times.
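The allocator option is read when PyTorch initializes its CUDA caching allocator, so (as far as I understand) it has to be in the environment before the first CUDA allocation. The simplest way to guarantee that is to set it at the very top of the training script, before importing `torch` or `transformers`. A minimal sketch:

```python
import os

# Must be set before the first CUDA allocation; setting it before importing
# torch is the simplest way to guarantee that.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# Import torch / transformers only after the variable is set, e.g.:
# import torch
# from transformers import Trainer, TrainingArguments
```

Equivalently, the variable can be exported in the shell that launches the (hypothetical) `train.py`.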


PKlumpp commented May 16, 2024

A quick update: If I run a forward/backward pass myself using native torch with the same batch size, everything works fine if I stick to mixed precision (like in the TrainingArguments above):

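import torch
from torch import autocast

with autocast(device_type='cuda', dtype=torch.float16):
    y = model(**batch)
    y["loss"].backward()  # note: PyTorch's AMP docs recommend calling backward() outside the autocast context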

I only run into an OOM if I omit mixed precision. Maybe it is related to that.
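For scale (back-of-envelope arithmetic, not a diagnosis): a single hidden-state-sized tensor for one batch, shape (30, 1024, 1024), takes exactly 120 MiB in fp32 and 60 MiB in fp16, so the 120 MiB request in the OOM message matches one fp32 activation of that shape. The helper name is hypothetical.

```python
def tensor_mib(*shape: int, bytes_per_elem: int = 4) -> float:
    """Size of a dense tensor in MiB (4 bytes/elem for fp32, 2 for fp16)."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_elem / 2**20


batch, seq_len, hidden = 30, 1024, 1024  # per-device batch size, chunk length, hidden size from the issue
print(tensor_mib(batch, seq_len, hidden))                    # fp32 hidden states → 120.0
print(tensor_mib(batch, seq_len, hidden, bytes_per_elem=2))  # fp16 under autocast → 60.0
```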

ambroser53 commented

I have the same issue, but in my case DeepSpeed is also involved. The OOM is not on the GPU but in system memory, because of CPU offloading of activation weights. However, just as in this issue, the OOM only happens when resuming, not when starting training from scratch. I've been digging into this on the DeepSpeed side for a while, so any further ideas for potential fixes would be helpful.

I think it might be related to this dead DeepSpeed issue, but none of the proposed fixes there work.

If I were to use autocast with the HF Trainer, where would I put it? Like this?

with autocast(device_type='cuda', dtype=torch.float16):
    trainer.train(resume_from_checkpoint=ft_checkpoint_dir)
