PyTorch Distributed Load Updates or Returns state_dict
#125096
Labels
module: distributed_checkpoint
oncall: distributed
triaged
🚀 The feature, motivation and pitch
Torch distributed checkpoint `load_state_dict` (https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_loader.py#L20) updates the passed-in `state_dict` (and returns it). This function is deprecated in torch 2.3 in favor of `load` (https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_loader.py#L48), which neither returns nor updates the passed-in `state_dict`. Instead, it only calls `load_state_dict` on Stateful elements in the specified `state_dict`.
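
For illustration, a minimal sketch of the behavioral difference as described above (the checkpoint path and the `"rng"` entry are hypothetical, not taken from the issue):

```python
import torch
import torch.distributed.checkpoint as dcp

reader = dcp.FileSystemReader("checkpoint/")  # hypothetical checkpoint location
state_dict = {"rng": {"cpu_rng_state": torch.get_rng_state()}}  # plain, non-Stateful entry

# Deprecated API: updates the passed-in dict (and returns it), so the caller
# gets the restored values back in state_dict["rng"].
dcp.load_state_dict(state_dict, storage_reader=reader)

# New API: returns None and only invokes load_state_dict() on Stateful
# elements, so a plain sub-dict like "rng" is not handed back to the caller.
dcp.load(state_dict, storage_reader=reader)
```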
Unfortunately, this new API is greatly limiting. For example, in Composer's state_dict passed for checkpointing, we also store various RNG tensors in a dict for determinism. In order to use the new API, we have to rewrap everything in a Stateful class, which is a somewhat pointless abstraction. Instead, we prefer to receive a loaded `state_dict` and then manually call `load_state_dict` on appropriate sub-items.
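
A rough sketch of the kind of wrapper this forces (class and key names are illustrative, not Composer's actual code):

```python
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful


class RNGWrapper(Stateful):
    """Exists only to satisfy the Stateful protocol for a plain dict of RNG tensors."""

    def __init__(self):
        self.rng = {"cpu_rng_state": torch.get_rng_state()}

    def state_dict(self):
        return self.rng

    def load_state_dict(self, state_dict):
        self.rng = state_dict


state_dict = {"rng": RNGWrapper()}  # instead of just {"rng": {...}}
dcp.load(state_dict, storage_reader=dcp.FileSystemReader("checkpoint/"))
restored_rng = state_dict["rng"].rng  # loaded values are only reachable through the wrapper
```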
Can we modify `load` to update the passed-in `state_dict`? This would entail adding an update step after https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_loader.py#L172-L177.
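
A minimal sketch of what such a write-back could look like, assuming `load` keeps building an internal flat dict of loaded values as the linked region does (the helper and variable names below are hypothetical, not the actual PyTorch internals):

```python
from torch.distributed.checkpoint.stateful import Stateful


def _write_back(state_dict, loaded_sd):
    """Hypothetical post-load step: propagate loaded values into the caller's dict.

    `loaded_sd` stands in for the flat dict that load() internally populates.
    """
    for key, loaded in loaded_sd.items():
        elem = state_dict[key]
        if isinstance(elem, Stateful):
            elem.load_state_dict(loaded)   # existing behavior for Stateful entries
        else:
            state_dict[key] = loaded       # proposed: also update plain entries
    return state_dict                      # optionally return it, matching the old API
```

With something along these lines, callers like Composer could keep plain dicts (e.g. RNG tensors) in the top-level state_dict and still receive the loaded values, without wrapping them in Stateful classes.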
Alternatives
No response
Additional context
No response
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @LucasLLC