@HamidShojanazeri, I'm following your FSDP example and swapped in a bigger model, `google/flan-t5-xxl`, and am a little unclear on what happens when the script starts up. I'm running on a server with 8 V100s, so I run the launch command as listed in the README.md file: `torchrun --nnodes 1 --nproc_per_node 8 T5_training.py`
Next, I was having trouble downloading the model weights: I think that with 8 processes, each one was trying to download the weights and they were removing each other's file locks. So I changed the `setup_model` function so that only rank 0 downloads the weights and all other processes then read from the local cache.
Finally, my big question for you: as the `setup_model` function is currently written, is it fair to say that we're loading a copy of the model weights for every process (in my case, 8 copies)? If so, how can we load the model once and broadcast the weights to all other processes? I ask because this will become a blocker at bigger model scales, where we'll eventually run out of CPU memory.
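To make the question concrete, here's the kind of thing I have in mind (an untested sketch; I'm assuming FSDP's `sync_module_states=True` is the right mechanism for broadcasting rank 0's weights, and that building the model on the meta device avoids the per-process CPU copies):

```python
def load_for_fsdp(model_name, rank, cache_dir=None):
    """Sketch: only rank 0 materializes real weights; other ranks build
    the architecture on the meta device (no parameter storage allocated)."""
    # Local imports so the heavy libraries are only pulled in when called.
    import torch
    from transformers import T5Config, T5ForConditionalGeneration

    if rank == 0:
        # Rank 0 holds the one full copy of the checkpoint in CPU memory.
        model = T5ForConditionalGeneration.from_pretrained(
            model_name, cache_dir=cache_dir
        )
    else:
        # Other ranks construct the module without allocating real storage.
        config = T5Config.from_pretrained(model_name, cache_dir=cache_dir)
        with torch.device("meta"):
            model = T5ForConditionalGeneration(config)
    return model


# Wrapping would then look something like this (again, my guess at the API):
#   model = FSDP(
#       load_for_fsdp(model_name, rank),
#       device_id=torch.cuda.current_device(),
#       sync_module_states=True,  # broadcast rank 0's weights to all ranks
#       param_init_fn=(None if rank == 0
#                      else lambda m: m.to_empty(device="cuda", recurse=False)),
#   )
```

Is something like this what the example intends, or is there a recommended path?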
Here's my modified `setup_model` function for reference:

```python
def setup_model(model_name, model_max_length=512, cache_dir=None, rank=None):
    # TODO: is this loading the model on all processes?
    # 1) this seems time consuming, and 2) it seems like it would use way too much memory
    # ensure weights are only downloaded by one process
    if rank == 0:
        model = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
        # set model_max_length to avoid warnings
        tokenizer = T5Tokenizer.from_pretrained(
            model_name, model_max_length=model_max_length, cache_dir=cache_dir
        )
    dist.barrier()
    if rank != 0:
        model = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
        # set model_max_length to avoid warnings
        tokenizer = T5Tokenizer.from_pretrained(
            model_name, model_max_length=model_max_length, cache_dir=cache_dir
        )
    return model, tokenizer
```
I imagine this all gets easier and more memory efficient once we start saving the model in the formats you've specified in the `model_checkpointing` directory, but we have to get there in the first place.
I should also note, in case it makes a difference, that I'm setting up the distributed process group (within `T5_training.py`) before calling `setup_model`, whereas your example calls `setup_model` before setting up the distributed process group.