How to load Transformer model once using FSDP #1179

Open
ToddMorrill opened this issue Aug 1, 2023 · 0 comments

📚 Documentation

@HamidShojanazeri, I'm following your FSDP example and have swapped in a bigger model, google/flan-t5-xxl, but I'm a little unclear about what happens when the script starts up. I'm running on a server with 8 V100s, so I use the launch command listed in the README.md:
torchrun --nnodes 1 --nproc_per_node 8 T5_training.py

Next, I was having trouble downloading the model weights: with 8 processes, I think each one was trying to download the weights and they were removing each other's file locks. So I changed the setup_model function so that only rank 0 downloads the weights and all other processes read from the local cache.

Finally, my big question for you is: as setup_model is currently written, is it fair to say that we're loading a full copy of the model weights in every process (in my case, 8 copies)? If so, how can we load the model once and broadcast the weights to all other processes? I ask because this will become a blocker at bigger model scales, since we'll eventually run out of CPU memory.

Here's my modified setup_model function for reference:

import torch.distributed as dist
from transformers import T5ForConditionalGeneration, T5Tokenizer


def setup_model(model_name, model_max_length=512, cache_dir=None, rank=None):
    # TODO: is this loading the model on all processes?
    # 1) this seems time consuming, and 2) it seems like it would use way too much memory
    # Ensure the weights are only downloaded once: rank 0 downloads first, and the
    # remaining ranks load from the local cache after the barrier.
    if rank == 0:
        model = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
        # set model_max_length to avoid warnings
        tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=model_max_length, cache_dir=cache_dir)
    dist.barrier()
    if rank != 0:
        model = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
        # set model_max_length to avoid warnings
        tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=model_max_length, cache_dir=cache_dir)
    return model, tokenizer

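To make the broadcast question above more concrete, here's the kind of pattern I was imagining, loosely based on the meta-device recipe in the FSDP docs. This is an untested sketch: the function name is mine, and I'm assuming that sync_module_states=True plus a param_init_fn is the right way to have rank 0 load the weights and broadcast them to the other ranks at wrap time (it also needs a reasonably recent PyTorch for the recurse=False keyword on to_empty).

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import T5Config, T5ForConditionalGeneration


def setup_model_rank0_only(model_name, rank, cache_dir=None):
    # Untested sketch: only rank 0 loads the real weights into CPU memory; every
    # other rank builds the same architecture on the meta device, which
    # allocates no parameter storage at all.
    if rank == 0:
        model = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
    else:
        config = T5Config.from_pretrained(model_name, cache_dir=cache_dir)
        with torch.device("meta"):
            model = T5ForConditionalGeneration(config)

    # sync_module_states=True asks FSDP to broadcast rank 0's weights to the
    # other ranks when the model is wrapped; param_init_fn materializes the
    # meta parameters on GPU first so there is somewhere to receive them.
    # (auto_wrap_policy, mixed precision, etc. omitted for brevity)
    model = FSDP(
        model,
        device_id=torch.cuda.current_device(),
        sync_module_states=True,
        param_init_fn=None if rank == 0 else (
            lambda module: module.to_empty(device=torch.cuda.current_device(), recurse=False)
        ),
    )
    return model

If that's roughly the intended approach, it would be great to have it reflected in the example, since it keeps peak CPU memory at a single copy of the model regardless of the number of local processes.
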
I imagine this all gets easier and more memory efficient once we start saving the model in the formats you've specified in the model_checkpointing directory, but we have to get there in the first place.

I should also note, in case it makes a difference, that I set up the distributed process group (within T5_training.py) before calling setup_model, whereas your example calls setup_model before setting up the process group.

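For reference, the start of my T5_training.py looks roughly like this (simplified):

import os

import torch
import torch.distributed as dist

# torchrun sets RANK / LOCAL_RANK / WORLD_SIZE in the environment, so the
# default env:// init method picks them up automatically.
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])

dist.init_process_group("nccl")
torch.cuda.set_device(local_rank)

# Only after the process group exists do I call setup_model, so that the
# rank-0-downloads-then-barrier logic above can work.
model, tokenizer = setup_model("google/flan-t5-xxl", rank=rank)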