(2/n) Support 2D Parallelism - Distributed Checkpoints (#19852)
* distributed checkpoints

* use decorator

* refactor if-strict

* update example

* filter non-persistent buffers (todo, add test)

* simplify checkpoint loading for model
awaelchli committed May 15, 2024
1 parent 90d04b5 commit 9455871
Showing 11 changed files with 682 additions and 127 deletions.
5 changes: 2 additions & 3 deletions examples/fabric/tensor_parallel/parallelism.py
@@ -31,11 +31,10 @@ def parallelize(model: Transformer, device_mesh: DeviceMesh) -> Transformer:
     # 1. Parallelize the first embedding and the last linear proj layer
     # 2. Parallelize the root norm layer over the sequence dim
     # 3. Shard the first transformer block's inputs

     # Parallelize the first embedding and the last linear out projection
     plan = {
-        "tok_embeddings": RowwiseParallel(
-            input_layouts=Replicate(),
-        ),
+        "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
         "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
         "norm": SequenceParallel(),
         "layers.0": PrepareModuleInput(
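For context, a plan like the one above is consumed by `parallelize_module` from `torch.distributed.tensor.parallel`. The following is a minimal, self-contained sketch (not part of the diff) showing how such a plan shards a toy two-layer MLP over a one-dimensional device mesh; the module and mesh names are illustrative, and the script assumes it is launched with `torchrun --nproc_per_node=4`.

import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


class FeedForward(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.w1 = nn.Linear(128, 512)
        self.w2 = nn.Linear(512, 128)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(torch.relu(self.w1(x)))


torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # LOCAL_RANK is set by torchrun

# One mesh dimension -> plain tensor parallelism across all four ranks
mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("tensor_parallel",))

plan = {
    "w1": ColwiseParallel(),  # shard w1's output features across ranks
    "w2": RowwiseParallel(),  # shard w2's input features and all-reduce its output
}
model = parallelize_module(FeedForward().cuda(), mesh, plan)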
10 changes: 5 additions & 5 deletions examples/fabric/tensor_parallel/train.py
@@ -30,14 +30,14 @@ def train():

fabric.print(f"Number of model parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.1f} B")

# Define the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, foreach=True)

# Set up model and optimizer
model, optimizer = fabric.setup(model, optimizer)

model = fabric.setup(model)
model.init_weights()

# Define the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, foreach=True)
optimizer = fabric.setup_optimizers(optimizer)

# Define dataset/dataloader
dataset = RandomTokenDataset(vocab_size=model_args.vocab_size, seq_length=128)
dataloader = DataLoader(dataset, batch_size=8)
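The reordering above matters because the optimizer should be created from parameters that have already been set up (and possibly sharded) by `fabric.setup`. A minimal sketch of the new flow, using a stand-in `nn.Linear` instead of the example's `Transformer` and leaving the strategy configuration out:

import torch
import torch.nn as nn
import lightning as L

fabric = L.Fabric(accelerator="cuda", devices=4)  # strategy configuration omitted in this sketch
fabric.launch()

with fabric.init_module(empty_init=True):
    model = nn.Linear(128, 128)  # stand-in for the example's Transformer(model_args)

model = fabric.setup(model)  # 1) wrap/distribute the model first
model.reset_parameters()     # 2) then materialize/initialize the real weights

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, foreach=True)
optimizer = fabric.setup_optimizers(optimizer)  # 3) optimizer references the set-up parameters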
3 changes: 2 additions & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -15,7 +15,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added support for PyTorch 2.3 ([#19708](https://github.com/Lightning-AI/pytorch-lightning/pull/19708))

--
+- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852))


 ### Changed

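The `ModelParallelStrategy` referenced in the changelog is configured roughly as in the tensor-parallel example: you pass a function that applies the parallelization plan to the model over the strategy's device mesh. The constructor arguments shown below (`parallelize_fn`, `data_parallel_size`, `tensor_parallel_size`) are inferred from the example scripts in this PR series and should be treated as assumptions, not a definitive API reference; the import path is the module added in #19846.

import lightning as L
from lightning.fabric.strategies.model_parallel import ModelParallelStrategy

# `parallelize` is the plan function from examples/fabric/tensor_parallel/parallelism.py
from parallelism import parallelize

# Assumed signature: 2 x 2 = 4 devices -> data parallel over one mesh dimension,
# tensor parallel within the other.
strategy = ModelParallelStrategy(
    parallelize_fn=parallelize,  # callable(module, device_mesh) -> module
    data_parallel_size=2,
    tensor_parallel_size=2,
)
fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy)
fabric.launch()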
38 changes: 8 additions & 30 deletions src/lightning/fabric/strategies/fsdp.py
@@ -393,7 +393,7 @@ def clip_gradients_norm(
             # the root must be wrapped
             raise TypeError(
                 "Gradient clipping with FSDP is only possible if the module passed to"
-                f" `{self.__class__.__name__}.clip_gradients_norm` is wrapped in `FullyShardedDataParallel`."
+                f" `{type(self).__name__}.clip_gradients_norm` is wrapped in `FullyShardedDataParallel`."
                 f" Got: {module.__class__.__name__}."
             )
         self.precision.unscale_gradients(optimizer)
@@ -506,12 +506,7 @@ def load_checkpoint(
         state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None,
         strict: bool = True,
     ) -> Dict[str, Any]:
-        """Load the contents from a checkpoint and restore the state of the given objects.
-
-        The strategy currently only supports saving and loading sharded checkpoints which are stored in form of a
-        directory of multiple files rather than a single file.
-
-        """
+        """Load the contents from a checkpoint and restore the state of the given objects."""
         if not state:
             raise ValueError(
                 f"Got FSDPStrategy.load_checkpoint(..., state={state!r}) but a state with at least "
@@ -522,6 +517,8 @@ def load_checkpoint(
         path = Path(self.broadcast(path))

         if isinstance(state, Module):
+            from lightning.fabric.strategies.model_parallel import _load_raw_module_state_from_path
+
             _load_raw_module_state_from_path(path, module=state, world_size=self.world_size, strict=strict)
             return {}
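The `isinstance(state, Module)` branch above is the code path taken when a bare module is passed to the strategy, for example via `Fabric.load_raw`, which loads a plain `torch.save` checkpoint file straight into an already set-up model. A short usage sketch; `fabric`, `model`, `optimizer`, and the paths are placeholders assumed from an earlier setup:

# Load a single-file, full state dict directly into the set-up (sharded) model.
fabric.load_raw("path/to/full_checkpoint.pt", model)

# For checkpoints written by the strategy itself (a directory of shard files),
# use `fabric.load` with a state dictionary instead:
state = {"model": model, "optimizer": optimizer, "iteration": 0}
fabric.load("path/to/sharded_checkpoint", state)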

@@ -592,6 +589,9 @@ def load_checkpoint(

         if _is_full_checkpoint(path):
             checkpoint = _lazy_load(path)
+
+            from lightning.fabric.strategies.model_parallel import _load_raw_module_state
+
             _load_raw_module_state(checkpoint.pop(module_key), module=module, world_size=self.world_size, strict=strict)

         if isinstance(state, Module):
@@ -755,7 +755,7 @@ def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
             # the root must be wrapped
             raise TypeError(
                 "Blocking backward sync is only possible if the module passed to"
-                f" `{self.__class__.__name__}.no_backward_sync` is wrapped in `FullyShardedDataParallel`."
+                f" `{type(self).__name__}.no_backward_sync` is wrapped in `FullyShardedDataParallel`."
                 f" Got: {module.__class__.__name__}."
             )
         return module.no_sync()
@@ -848,28 +848,6 @@ def _has_fsdp_modules(module: object) -> TypeGuard[Module]:
     return isinstance(module, Module) and any(isinstance(m, FullyShardedDataParallel) for m in module.modules())


-def _load_raw_module_state_from_path(path: Path, module: Module, world_size: int, strict: bool = True) -> None:
-    """Loads the state dict from a file path into the FSDP module."""
-    if not _is_full_checkpoint(path):
-        raise ValueError(
-            "Failed to load checkpoint directly into the model. The given path must be a single file containing the"
-            f" full state dict: {path}"
-        )
-    # Use `lazy_load` instead of `torch.load` here to avoid storing a copy of the full checkpoint per rank
-    _load_raw_module_state(state_dict=_lazy_load(path), module=module, world_size=world_size, strict=strict)
-
-
-def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, world_size: int, strict: bool = True) -> None:
-    """Loads the state dict into the module by gathering all weights first and then and writing back to each shard."""
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
-    if not isinstance(module, FSDP):
-        module.load_state_dict(state_dict, strict=strict)
-    else:
-        with _get_full_state_dict_context(module, world_size=world_size, rank0_only=False):
-            module.load_state_dict(state_dict, strict=strict)
-
-
 def _move_torchmetrics_to_device(module: torch.nn.Module, device: torch.device) -> None:
     # FSDP doesn't move modules without parameters (e.g. Metrics) to the device
     # https://github.com/pytorch/pytorch/issues/113113
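With the raw-state helpers relocated to `lightning.fabric.strategies.model_parallel`, both FSDP and the new strategy follow the sharded-checkpoint format described in the removed docstring: `fabric.save` writes a checkpoint directory of per-rank files rather than a single `.pt` file, and `fabric.load` restores it in place. A hedged end-to-end sketch, assuming `fabric`, `model`, and `optimizer` set up as in the example above and an illustrative checkpoint path:

# Saving: with a sharded strategy this produces a checkpoint *directory*
# containing one file per rank plus metadata, not a single file.
state = {"model": model, "optimizer": optimizer, "step": 100}
fabric.save("checkpoints/step-100", state)

# Loading: pass the same structure; tensors are restored shard-by-shard into the
# already set-up model and optimizer, and the state dict is updated in place.
state = {"model": model, "optimizer": optimizer, "step": 0}
remainder = fabric.load("checkpoints/step-100", state)
print(state["step"], remainder)  # any keys not consumed are returned in `remainder`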
