Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(5/n) Support 2D Parallelism in Lightning Trainer #19878

Merged
merged 9 commits into from
May 17, 2024

Conversation

awaelchli
Copy link
Member

@awaelchli awaelchli commented May 16, 2024

What does this PR do?

Ports the functionality added in #19846 to the Trainer. The same tests are adopted and rewritten for the Trainer semantics. Some tests were also taken from the Trainer FSDP Strategy test files.

To keep the PRs minimal, I'm not including checkpointing here and will submit it in a follow up PR.
A concrete example of how to use the strategy is in #19879. In summary:

# 1. Define your LightningModule normally
class MyParallelModel(LightningModule):

	# 2. Add this hook
    def configure_model(self):
        # 3. Add parallelization here. You can access `self.device_mesh`.
        tp_mesh = self.device_mesh["tensor_parallel"]
        dp_mesh = self.device_mesh["data_parallel"]
		
        plan = {"layer.w1": ColwiseParallel())
        parallelize_module(model, tp_mesh, plan)


# 4. Select the `ModelParallelStrategy` in the Trainer.
from lightning.pytorch.strategies import ModelParallelStrategy

strategy = ModelParallelStrategy()
trainer = Trainer(strategy=strategy) 

# Defaults are:
strategy = ModelParallelStrategy(
	tensor_parallel_size="auto",   # number of GPUs in the machine
    data_parallel_size="auto",  # number of machines in the cluster
)

📚 Documentation preview 📚: https://pytorch-lightning--19878.org.readthedocs.build/en/19878/

cc @Borda @carmocca @justusschock @awaelchli

@awaelchli awaelchli changed the base branch from master to feature/tp-full-optim-load May 16, 2024 20:33
@github-actions github-actions bot added fabric lightning.fabric.Fabric pl Generic label for PyTorch Lightning package labels May 16, 2024
@awaelchli awaelchli added this to the 2.3 milestone May 17, 2024
@awaelchli awaelchli added the feature Is an improvement or enhancement label May 17, 2024
@awaelchli awaelchli marked this pull request as ready for review May 17, 2024 09:53
Copy link
Contributor

github-actions bot commented May 17, 2024

⚡ Required checks status: All passing 🟢

Groups summary

🟢 pytorch_lightning: Tests workflow
Check ID Status
pl-cpu (macOS-11, lightning, 3.8, 2.0, oldest) success
pl-cpu (macOS-11, lightning, 3.10, 2.0) success
pl-cpu (macOS-11, lightning, 3.10, 2.1) success
pl-cpu (macOS-11, lightning, 3.10, 2.2) success
pl-cpu (macOS-14, lightning, 3.10, 2.3) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 2.0, oldest) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.2) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.3) success
pl-cpu (windows-2022, lightning, 3.8, 2.0, oldest) success
pl-cpu (windows-2022, lightning, 3.10, 2.0) success
pl-cpu (windows-2022, lightning, 3.10, 2.1) success
pl-cpu (windows-2022, lightning, 3.10, 2.2) success
pl-cpu (windows-2022, lightning, 3.10, 2.3) success
pl-cpu (macOS-11, pytorch, 3.8, 2.0) success
pl-cpu (ubuntu-20.04, pytorch, 3.8, 2.0) success
pl-cpu (windows-2022, pytorch, 3.8, 2.0) success
pl-cpu (macOS-12, pytorch, 3.11, 2.0) success
pl-cpu (macOS-12, pytorch, 3.11, 2.1) success
pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.0) success
pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.1) success
pl-cpu (windows-2022, pytorch, 3.11, 2.0) success
pl-cpu (windows-2022, pytorch, 3.11, 2.1) success

These checks are required after the changes to src/lightning/fabric/strategies/model_parallel.py, src/lightning/pytorch/core/module.py, src/lightning/pytorch/strategies/__init__.py, src/lightning/pytorch/strategies/model_parallel.py, src/lightning/pytorch/trainer/connectors/accelerator_connector.py, tests/tests_pytorch/strategies/test_model_parallel.py, tests/tests_pytorch/strategies/test_model_parallel_integration.py.

🟢 pytorch_lightning: Azure GPU
Check ID Status
pytorch-lightning (GPUs) (testing Lightning | latest) success
pytorch-lightning (GPUs) (testing PyTorch | latest) success

These checks are required after the changes to src/lightning/pytorch/core/module.py, src/lightning/pytorch/strategies/__init__.py, src/lightning/pytorch/strategies/model_parallel.py, src/lightning/pytorch/trainer/connectors/accelerator_connector.py, tests/tests_pytorch/strategies/test_model_parallel.py, tests/tests_pytorch/strategies/test_model_parallel_integration.py, src/lightning/fabric/strategies/model_parallel.py.

🟢 pytorch_lightning: Benchmarks
Check ID Status
lightning.Benchmarks success

These checks are required after the changes to src/lightning/fabric/strategies/model_parallel.py, src/lightning/pytorch/core/module.py, src/lightning/pytorch/strategies/__init__.py, src/lightning/pytorch/strategies/model_parallel.py, src/lightning/pytorch/trainer/connectors/accelerator_connector.py.

🟢 fabric: Docs
Check ID Status
docs-make (fabric, doctest) success
docs-make (fabric, html) success

These checks are required after the changes to src/lightning/fabric/strategies/model_parallel.py.

🟢 pytorch_lightning: Docs
Check ID Status
docs-make (pytorch, doctest) success
docs-make (pytorch, html) success

These checks are required after the changes to src/lightning/pytorch/core/module.py, src/lightning/pytorch/strategies/__init__.py, src/lightning/pytorch/strategies/model_parallel.py, src/lightning/pytorch/trainer/connectors/accelerator_connector.py.

🟢 lightning_fabric: CPU workflow
Check ID Status
fabric-cpu (macOS-11, lightning, 3.8, 2.0, oldest) success
fabric-cpu (macOS-11, lightning, 3.10, 2.0) success
fabric-cpu (macOS-11, lightning, 3.11, 2.1) success
fabric-cpu (macOS-11, lightning, 3.11, 2.2) success
fabric-cpu (macOS-14, lightning, 3.10, 2.3) success
fabric-cpu (ubuntu-20.04, lightning, 3.8, 2.0, oldest) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.1) success
fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2) success
fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.3) success
fabric-cpu (windows-2022, lightning, 3.8, 2.0, oldest) success
fabric-cpu (windows-2022, lightning, 3.10, 2.0) success
fabric-cpu (windows-2022, lightning, 3.11, 2.1) success
fabric-cpu (windows-2022, lightning, 3.11, 2.2) success
fabric-cpu (windows-2022, lightning, 3.11, 2.3) success
fabric-cpu (macOS-11, fabric, 3.8, 2.0) success
fabric-cpu (ubuntu-20.04, fabric, 3.8, 2.0) success
fabric-cpu (windows-2022, fabric, 3.8, 2.0) success
fabric-cpu (macOS-12, fabric, 3.11, 2.0) success
fabric-cpu (macOS-12, fabric, 3.11, 2.1) success
fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.0) success
fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.1) success
fabric-cpu (windows-2022, fabric, 3.11, 2.0) success
fabric-cpu (windows-2022, fabric, 3.11, 2.1) success

These checks are required after the changes to src/lightning/fabric/strategies/model_parallel.py, tests/tests_fabric/strategies/test_model_parallel.py, tests/tests_fabric/strategies/test_model_parallel_integration.py.

🟢 lightning_fabric: Azure GPU
Check ID Status
lightning-fabric (GPUs) (testing Fabric | latest) success
lightning-fabric (GPUs) (testing Lightning | latest) success

These checks are required after the changes to src/lightning/fabric/strategies/model_parallel.py, tests/tests_fabric/strategies/test_model_parallel.py, tests/tests_fabric/strategies/test_model_parallel_integration.py.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/lightning/fabric/strategies/model_parallel.py, src/lightning/pytorch/core/module.py, src/lightning/pytorch/strategies/__init__.py, src/lightning/pytorch/strategies/model_parallel.py, src/lightning/pytorch/trainer/connectors/accelerator_connector.py.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, app, 3.8) success
install-pkg (ubuntu-22.04, app, 3.11) success
install-pkg (ubuntu-22.04, fabric, 3.8) success
install-pkg (ubuntu-22.04, fabric, 3.11) success
install-pkg (ubuntu-22.04, pytorch, 3.8) success
install-pkg (ubuntu-22.04, pytorch, 3.11) success
install-pkg (ubuntu-22.04, lightning, 3.8) success
install-pkg (ubuntu-22.04, lightning, 3.11) success
install-pkg (ubuntu-22.04, notset, 3.8) success
install-pkg (ubuntu-22.04, notset, 3.11) success
install-pkg (macOS-12, app, 3.8) success
install-pkg (macOS-12, app, 3.11) success
install-pkg (macOS-12, fabric, 3.8) success
install-pkg (macOS-12, fabric, 3.11) success
install-pkg (macOS-12, pytorch, 3.8) success
install-pkg (macOS-12, pytorch, 3.11) success
install-pkg (macOS-12, lightning, 3.8) success
install-pkg (macOS-12, lightning, 3.11) success
install-pkg (macOS-12, notset, 3.8) success
install-pkg (macOS-12, notset, 3.11) success
install-pkg (windows-2022, app, 3.8) success
install-pkg (windows-2022, app, 3.11) success
install-pkg (windows-2022, fabric, 3.8) success
install-pkg (windows-2022, fabric, 3.11) success
install-pkg (windows-2022, pytorch, 3.8) success
install-pkg (windows-2022, pytorch, 3.11) success
install-pkg (windows-2022, lightning, 3.8) success
install-pkg (windows-2022, lightning, 3.11) success
install-pkg (windows-2022, notset, 3.8) success
install-pkg (windows-2022, notset, 3.11) success

These checks are required after the changes to src/lightning/fabric/strategies/model_parallel.py, src/lightning/pytorch/core/module.py, src/lightning/pytorch/strategies/__init__.py, src/lightning/pytorch/strategies/model_parallel.py, src/lightning/pytorch/trainer/connectors/accelerator_connector.py.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates for 60 minutes every 180 seconds. If you have any other questions, contact carmocca for help.

@awaelchli awaelchli changed the base branch from feature/tp-full-optim-load to master May 17, 2024 13:46
Copy link

codecov bot commented May 17, 2024

Codecov Report

Attention: Patch coverage is 88.12785% with 26 lines in your changes are missing coverage. Please review.

Project coverage is 59%. Comparing base (1d0c6aa) to head (8eebff9).

Additional details and impacted files
@@            Coverage Diff            @@
##           master   #19878     +/-   ##
=========================================
- Coverage      84%      59%    -25%     
=========================================
  Files         425      421      -4     
  Lines       35028    35139    +111     
=========================================
- Hits        29319    20714   -8605     
- Misses       5709    14425   +8716     

Copy link
Collaborator

@lantiga lantiga left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great

src/lightning/pytorch/strategies/model_parallel.py Outdated Show resolved Hide resolved
@mergify mergify bot added the ready PRs ready to be merged label May 17, 2024
@lantiga lantiga merged commit 32e2418 into master May 17, 2024
116 of 117 checks passed
@lantiga lantiga deleted the feature/tp-pl-strategy branch May 17, 2024 23:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
fabric lightning.fabric.Fabric feature Is an improvement or enhancement pl Generic label for PyTorch Lightning package ready PRs ready to be merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants