[pipelining] add back support for multi-use parameters/buffers #126653
Conversation
Resolves #126626 [ghstack-poisoned]
Which titan issue was this addressing? Something with freqs_cis?

See #126626. I filed it against pytorch rather than titan.
```python
self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.lin1 = torch.nn.Linear(d_hid, d_hid)
self.lin2 = torch.nn.Linear(d_hid, d_hid)

def forward(self, x, y):
    x = torch.mm(x, self.mm_param0)
    x = torch.mm(x, self.mm_param1)  # mutli-use param
```
typo. (again, typo below)
```python
logger.info(
    f"Parameter {node.target} used in multiple stages: {node.users}."  # noqa: G004
)
for user in node.users:
    assert user.op == "call_module"
    # Move parameter into submodule
    move_param_to_callee(
```
does this affect the fqn of the shared parameter?
No. This PR targets parameters (single FQN) used by multiple stages once the original model is split.
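(For illustration — a minimal `torch.fx` example, separate from the PR's actual test model, of what "single FQN, multiple uses" looks like at the graph level: the parameter gets one `get_attr` node whose `users` list has several entries.)

```python
import torch
import torch.fx as fx

class TwoUse(torch.nn.Module):
    def __init__(self, d=4):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(d, d))

    def forward(self, x):
        x = torch.mm(x, self.w)      # first use
        return torch.mm(x, self.w)   # second use of the same param

gm = fx.symbolic_trace(TwoUse())
for node in gm.graph.nodes:
    if node.op == "get_attr":
        # one FQN ("w"), several users -- the case this PR handles
        print(node.target, "->", [u.name for u in node.users])
```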
```python
skip_connection = x
x = x + y
x = torch.relu(x)
pipe_split()
x = torch.mm(x, self.mm_param1)  # mutli-use param
```
do we have tests that verify FQN sanity (perhaps you added them along with the unflattener)?

it'd be nice to confirm that when using a multi-use param, the model's state_dict is clean and only has the original copy, so checkpoint save/load will work as expected.
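(A sketch of the kind of check being asked for; `split_model` is a hypothetical handle to the split module, and under REPLICATE semantics the same original key may legitimately appear under more than one stage:)

```python
# Hypothetical FQN-sanity check: every key in the split module's
# state_dict should map back to a key of the original model, so a
# checkpoint saved from the original can be loaded after splitting.
orig_keys = set(model.state_dict().keys())
for key in split_model.state_dict():
    assert key in orig_keys, f"unexpected FQN after split: {key}"
```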
tbh, we don't have support for multi-use params in training yet, because that would require an all-reduce between the multiple copies of that param before the next batch's forward happens. So it would be kind of early to talk about how to save them before we can train them :)
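(Roughly the synchronization that training support would need — a sketch only, assuming each pipeline rank holds its own replica of the shared param and `group` spans the ranks that share it:)

```python
import torch
import torch.distributed as dist

def sync_shared_param_grads(param: torch.nn.Parameter, group=None) -> None:
    # Average this replica's gradient with the other stages' replicas so
    # all copies see the same update before the next batch's forward.
    if param.grad is not None:
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)
        param.grad /= dist.get_world_size(group=group)
```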
But multi-use buffers (as in the titan case) and multi-use params in inference are different stories; they can be supported today.
```python
callee = root.get_submodule(callee_name)
assert not hasattr(
    callee, param_fqn
), f"Module {callee_name} already has a parameter named {param_fqn}"

# Assign the parameter to the submodule
if is_buffer:
    _assign_attr(
```
I'm kinda confused though, how come we can assign the attr to a submodule and not cause FQN duplication?
We are moving the attr to the submodule.
The original attr will be removed IIRC.
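(In plain `nn.Module` terms, the move is roughly the following; `move_attr` is an illustrative stand-in, not the actual helper in the PR:)

```python
import torch

def move_attr(root: torch.nn.Module, callee: torch.nn.Module, name: str) -> None:
    # Re-register the tensor on the callee, then drop the root's copy,
    # so only one FQN for it survives in the module tree.
    param = root.get_parameter(name)
    delattr(root, name)
    callee.register_parameter(name, param)
```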
changes make sense, and stacked tests seem to work well
@pytorchbot merge
Merge failed. Reason: This PR needs a `release notes:` label. If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`. If not, please add the `topic: not user facing` label. To add a label, you can comment to pytorchbot, for example `@pytorchbot label "topic: not user facing"`. For more information, see the PyTorch AutoLabel Bot wiki. Details for Dev Infra team: raised by workflow job.
…dules" This PR fixes the issue mentioned [here](pytorch/pytorch#126653 (comment)): "Module object has no attributed items." The reason is, a split `ModuleDict` is no longer a `ModuleDict`. (Future support is not guaranteed.) It would be more generally applicable if we use `named_children()` and `register_module()` to access and update submodules. [ghstack-poisoned]
This PR fixes the issue mentioned [here](pytorch/pytorch#126653 (comment)): "Module object has no attributed items." The reason is, a split `ModuleDict` is no longer a `ModuleDict`. (Future support is not guaranteed.) It would be more generally applicable if we use `named_children()` and `register_module()` to access and update submodules. [ghstack-poisoned]
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #362
* __->__ #371

This PR fixes the issue mentioned [here](pytorch/pytorch#126653 (comment)): "Module object has no attribute `items`." The reason is that a split `ModuleDict` is no longer a `ModuleDict` (future support is not guaranteed). It would be more generally applicable to use `named_children()` and `register_module()` to access and update submodules.
…dules" This PR fixes the issue mentioned [here](pytorch/pytorch#126653 (comment)): "Module object has no attributed items." The reason is, a split `ModuleDict` is no longer a `ModuleDict`. It would be more generally applicable if we use `named_children()` and `register_module()` to access and update submodules. [ghstack-poisoned]
This PR fixes the issue mentioned [here](pytorch/pytorch#126653 (comment)): "Module object has no attributed items." The reason is, a split `ModuleDict` is no longer a `ModuleDict`. It would be more generally applicable if we use `named_children()` and `register_module()` to access and update submodules. [ghstack-poisoned]
Stack from ghstack (oldest at bottom):

## Motivation

Resolves #126626 to support TorchTitan.

With this PR, we add back support for cases where a parameter or buffer is used in multiple stages. An example of such usage is in LLaMA (torchtitan); code snippet:

```python
for layer in self.layers.values():
    h = layer(h, self.freqs_cis)
```
## Solution

Step 1: Remove the previous guards of `if len(node.users) == 1`.

Step 2: Call `move_param_to_callee` multiple times, once for each stage ("callee").

Step 3: Delay deletion of the `get_attr` node (for getting the param) from the root until the param has been sunk into each stage that uses it.

The PR also cleans up the old code around this (dropping the TRANSMIT mode and supporting REPLICATE mode only). A toy walk-through of the three steps is sketched below.
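(For readers following along, here is a self-contained toy version of the three steps on a plain `torch.fx` graph. It sketches the REPLICATE idea only; the real `move_param_to_callee` in `torch.distributed.pipelining` handles many more cases, and every name here — `Stage`, `Root`, `KeepStages`, `sink_param_into_callee` — is illustrative:)

```python
import torch
import torch.fx as fx

class Stage(torch.nn.Module):
    def forward(self, x, w):
        return torch.relu(torch.mm(x, w))

class Root(torch.nn.Module):
    def __init__(self, d=4):
        super().__init__()
        self.shared = torch.nn.Parameter(torch.randn(d, d))
        self.stage0 = Stage()
        self.stage1 = Stage()

    def forward(self, x):
        x = self.stage0(x, self.shared)     # first stage uses the shared param
        return self.stage1(x, self.shared)  # second stage uses it again

class KeepStages(fx.Tracer):
    # Trace stages as opaque call_module nodes instead of inlining them.
    def is_leaf_module(self, m, qualname):
        return isinstance(m, Stage) or super().is_leaf_module(m, qualname)

root = Root()
gm = fx.GraphModule(root, KeepStages().trace(root))

def sink_param_into_callee(gm, call_node, attr_node):
    # Step 2: register the shared tensor on the callee and rewrite the
    # callee's graph to read it via get_attr instead of an extra argument.
    callee = fx.symbolic_trace(gm.get_submodule(call_node.target))
    callee.register_parameter("shared", gm.get_parameter(attr_node.target))
    for ph in list(callee.graph.nodes):
        if ph.op == "placeholder" and ph.name == "w":
            with callee.graph.inserting_after(ph):
                ga = callee.graph.get_attr("shared")
            ph.replace_all_uses_with(ga)
            callee.graph.erase_node(ph)
    callee.recompile()
    setattr(gm, call_node.target, callee)  # swap the rewritten stage in
    call_node.args = tuple(a for a in call_node.args if a is not attr_node)

for node in list(gm.graph.nodes):
    if node.op == "get_attr" and node.target == "shared":
        for user in list(node.users):      # Step 1: no single-user guard
            assert user.op == "call_module"
            sink_param_into_callee(gm, user, node)
        gm.graph.erase_node(node)          # Step 3: erase only after sinking
        delattr(gm, "shared")              # the root's copy is gone
gm.recompile()

print(gm(torch.randn(2, 4)).shape)  # forward still works: torch.Size([2, 4])
print([n for n, _ in gm.named_parameters(remove_duplicate=False)])
# ['stage0.shared', 'stage1.shared']
```

Both stage copies alias the same tensor object in this toy; as discussed in the review thread above, keeping genuine replicas in sync during training is what would require the extra all-reduce.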
## Test

Changed the `ExampleCode` model to use `mm_param1` in multiple stages.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k