Improving memory efficiency further 🚀 #30860

Open
Cyrilvallez opened this issue May 16, 2024 · 3 comments
Cyrilvallez commented May 16, 2024

Feature request

Removing the line logits = logits.float() in most ModelForCausalLM. This would save a lot of memory for models with a large vocabulary size: on Llama3, it more than halves the memory peak.

Motivation

This is related to my work in #30536.
I noticed that almost all ModelForCausalLM classes contain the following line in their forward():

logits = logits.float()

Since most models are now used in (b)float16, or even quantized, that line almost always doubles the memory footprint of the logits. As the vocabulary size can be quite large (e.g. Llama3), this results in a lot of memory being used.
I suspect that it was originally introduced so that later manipulations of the logits (processors, warpers...) can be applied without losing too much precision. However, in generate() we only ever use the last token's logits, not the whole logit matrix. So this is a huge waste of memory.
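
To put rough numbers on this (a back-of-the-envelope sketch; batch size 1 and a 2048-token prompt are my assumptions, only the 128256 vocabulary size comes from Llama3):

```python
# Back-of-the-envelope estimate of the logits memory for a single prompt.
# Assumed numbers: batch size 1, 2048-token prompt, Llama3 vocab size 128256.
batch, seq_len, vocab = 1, 2048, 128256

bf16_bytes = batch * seq_len * vocab * 2  # logits as produced in bfloat16
fp32_bytes = batch * seq_len * vocab * 4  # after `logits = logits.float()`

print(f"bf16 logits: {bf16_bytes / 1024**3:.2f} GiB")  # ~0.49 GiB
print(f"fp32 copy:   {fp32_bytes / 1024**3:.2f} GiB")  # ~0.98 GiB
# Both tensors exist while the cast runs, so the peak is ~1.47 GiB
# for the logits of this one prompt alone.
```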

Your contribution

If the casting of the logits to float is indeed only there to avoid losing precision in later manipulations, I propose to cast only the last token's logits to float in each decoding strategy function.

So, instead of:

logits = logits.float()

in forward(), do:

next_token_logits = outputs.logits[:, -1, :].clone().float()

in each decoding strategy function. This casts only the last token's logit vector to float, which is negligible in terms of memory overhead.
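
A minimal sketch of what that change could look like inside a greedy decoding step (illustrative only, not the actual generate() internals; the helper name is made up):

```python
import torch

@torch.no_grad()
def greedy_step(model, input_ids, past_key_values=None):
    """One illustrative greedy-decoding step with the proposed cast."""
    outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)

    # Upcast only the last position for the logits processors, instead of
    # upcasting the whole (batch, seq_len, vocab) tensor inside forward().
    next_token_logits = outputs.logits[:, -1, :].clone().float()

    next_tokens = torch.argmax(next_token_logits, dim=-1)
    return next_tokens, outputs.past_key_values
```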

As an example of the potential memory gains, running this very simple code snippet on Llama3 8B (vocabulary size 128256):

```python
import torch
from transformers import AutoModelForCausalLM

model_name = 'meta-llama/Meta-Llama-3-8B'
dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='flash_attention_2',
                                             torch_dtype=dtype, low_cpu_mem_usage=True).cuda(1)

memory = []
sizes = [100] + [500*i for i in range(1, 14)]
for size in sizes:
    # Random prompt of the given length (the token content does not matter here)
    input_ids = torch.randint(1, 120000, (1, size), device=1)

    # Baseline: memory already allocated (weights, etc.) before the forward pass
    torch.cuda.reset_peak_memory_stats(1)
    actual_peak = torch.cuda.max_memory_allocated(1) / 1024**3

    # Single forward pass (first iteration of `generate()`)
    with torch.no_grad():
        out = model(input_ids)

    # Peak memory attributable to the forward pass itself, in GiB
    memory_used = (torch.cuda.max_memory_allocated(1) / 1024**3) - actual_peak
    memory.append(memory_used)

    del out
```
gives the following results (attached plots):
- llama3_example.pdf
- llama3_ratio_example.pdf

That is, the memory footprint is more than halved. This is because the vocabulary size is so large that the logits computed from the hidden states actually cost more memory than the hidden states themselves. Thus, casting to float() more than doubles the memory requirements (double for the new logits, plus the overhead of actually making the copy).
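
To make the hidden states vs. logits comparison concrete (a rough per-token estimate; the hidden size of 4096 for Llama3 8B is my assumption, the vocabulary size comes from above):

```python
# Per-token memory of the final hidden state vs. its logits row, in bfloat16.
hidden_size, vocab_size, bytes_per_elem = 4096, 128256, 2

hidden_per_token = hidden_size * bytes_per_elem  # 8 KiB per token
logits_per_token = vocab_size * bytes_per_elem   # ~250 KiB per token

print(logits_per_token / hidden_per_token)  # ~31x larger than the hidden state
```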

Of course, other models usually have a smaller vocabulary, so they will not benefit as much, but the memory peak will still decrease by a non-negligible amount for all applicable models (see below for Mistral, ~30% memory gain). And Llama3, which I believe is the most popular open-source model at the moment, will be much more efficient.
mistral_ratio_example.pdf

Of course, if this cast to float is there for some other reason that I overlooked, this may not be applicable. Otherwise, I would be happy to make the change.

@ArthurZucker @gante

Cheers,
Cyril

@ArthurZucker (Collaborator)

That's actually something we should really do, in light of #29943, which has this:

```python
hidden_states = outputs[0]
if num_logits_to_keep is None:
    logits = self.lm_head(hidden_states)
else:
    logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])
logits = logits.float()
```
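
Combined with that, the decoding loop could request only the last position's logits, something like (illustrative sketch, assuming forward() accepts num_logits_to_keep as in the snippet above):

```python
import torch

@torch.no_grad()
def last_token_logits(model, input_ids):
    # Illustrative: ask the model for the last position's logits only,
    # assuming forward() accepts `num_logits_to_keep` as in #29943.
    outputs = model(input_ids, num_logits_to_keep=1, use_cache=True)
    # Shape (batch, 1, vocab): only one position is materialized, so the
    # upcast to float32 is cheap compared to upcasting the full tensor.
    return outputs.logits[:, -1, :].clone().float()
```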

@ArthurZucker (Collaborator)

(clone is missing)

@gante (Member) commented May 21, 2024

@Cyrilvallez

> However, in generate() we only ever use the last token logits, not the whole logit matrix.

This is true except in assisted generation, where we want the logits for all candidate tokens 😛 But we can generalize to "we only ever want as many logits as input tokens".
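
In other words, a generalized helper would upcast only as many positions as are actually being scored (illustrative sketch; the helper is hypothetical):

```python
import torch

def upcast_scored_logits(logits: torch.Tensor, num_scored: int) -> torch.Tensor:
    """Upcast only the last `num_scored` positions to float32.

    `num_scored` would be 1 for regular decoding and the number of candidate
    tokens being checked in assisted generation.
    """
    return logits[:, -num_scored:, :].clone().float()

# Toy usage: batch of 1, 8 positions, tiny vocab of 10.
logits = torch.randn(1, 8, 10, dtype=torch.bfloat16)
print(upcast_scored_logits(logits, num_scored=3).shape)  # torch.Size([1, 3, 10])
```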


👉 regarding keeping all the logits at prefill time: in our generate refactor plans, we will be separating the prefill stage. The prefill stage is meant to compute the KV caches without returning the associated logits, so it should solve this part of the problem.

👉 regarding casting the logits with .float(): I agree we should move this part of the logic to generate. It would be a breaking PR (because it changes the type of an output), but it could save considerable memory at prefill time. Even after separating the prefill stage (see above), the upcast of the large logits tensor would still happen in forward. No upcast = no need to materialize the FP32 prefill logits tensor = memory savings. @ArthurZucker are you okay with this breaking change?
