
Register weights as a non-persistent buffer of SinusoidalPositionalEmbedding #5213

Merged · 1 commit into facebookresearch:main on Jun 23, 2023

Conversation

@MaigoAkisame (Contributor) commented on Jun 22, 2023

Before submitting

  • Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
    Discussed this in a GitHub issue of the pytorch repo.

  • Did you read the contributor guideline?

  • Did you make sure to update the docs?
    Not applicable

  • Did you write any new necessary tests?
    No. But I've tested deeplearning/projects/fairseq-py:test_cpu in Meta's fbcode repo, and this diff does not introduce any new test failures.

What does this PR do?

The SinusoidalPositionalEmbedding module has the problem that its weights attribute is a plain tensor, so it is not moved to CPU or CUDA when the module itself is moved.

Registering weights as a buffer solves this problem. It also eliminates the need for the _float_tensor buffer, which existed only to keep track of whether the module is on CPU or CUDA.

Making weights a non-persistent buffer means it won't be saved to or loaded from a state_dict.
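
For illustration, here is a minimal sketch of that registration (the class and variable names below are made up for this example, not the actual fairseq code):

```python
# Minimal sketch: register the precomputed table as a *non-persistent* buffer so
# it follows the module across devices but never appears in the state_dict.
import torch
import torch.nn as nn


class SinusoidalEmbeddingSketch(nn.Module):
    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__()
        table = torch.zeros(num_positions, embedding_dim)  # placeholder for the sinusoidal table
        # persistent=False keeps "weights" out of state_dict() / load_state_dict()
        self.register_buffer("weights", table, persistent=False)

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        return self.weights[positions]


m = SinusoidalEmbeddingSketch(num_positions=16, embedding_dim=8)
assert "weights" not in m.state_dict()  # excluded because persistent=False
if torch.cuda.is_available():
    m = m.cuda()  # m.weights moves along with the module; no _float_tensor needed
```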

With the changes in this diff, the state_dict of a SinusoidalPositionalEmbedding module should contain neither weights nor _float_tensor.
To stay compatible with older checkpoints, this diff ignores those keys by overriding the _load_from_state_dict method of the SinusoidalPositionalEmbedding module, instead of duplicating the same logic in many upgrade_state_dict functions.
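
A rough, self-contained sketch of that override (the class name and scaffolding are assumptions; only the key-dropping logic mirrors what is described above):

```python
# Sketch of the compatibility shim: drop the legacy "weights" and "_float_tensor"
# entries from an incoming state_dict before delegating to nn.Module's loader.
import torch
import torch.nn as nn


class LegacyKeyIgnoringModule(nn.Module):
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        for legacy_key in ("weights", "_float_tensor"):
            key = prefix + legacy_key
            if key in state_dict:
                del state_dict[key]  # silently ignore buffers saved by old checkpoints
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)


# An old-style checkpoint that still contains "_float_tensor" loads without
# unexpected-key errors, because the key is removed before the strict check.
m = LegacyKeyIgnoringModule()
m.load_state_dict({"_float_tensor": torch.zeros(1)}, strict=True)
```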

TO DISCUSS: Is it OK for me to override _load_from_state_dict? It's a private method, but I see people have overridden it in many places, including in fairseq:
https://github.com/search?q=super()._load_from_state_dict&type=code
https://github.com/search?q=repo%3Afacebookresearch%2Ffairseq%20super()._load_from_state_dict&type=code

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@@ -441,7 +441,6 @@ def convert_to_pipeline_parallel_state_dict(self, state_dict):
     # fmt: off
     if isinstance(module, TransformerEncoderEmbedding):
         new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight']
-        new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['encoder.embed_positions._float_tensor']
Contributor commented on the removed line:
How come this is removed?

Contributor Author (@MaigoAkisame) replied:
The _float_tensor buffer was just a dummy tensor used to track whether the SinusoidalPositionalEmbedding object is on CPU or CUDA. Now that the weights buffer moves between CPU and CUDA along with the module, we no longer need _float_tensor.
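
For context, this is roughly what the old pattern looked like (a sketch with assumed names, not the exact fairseq code): a dummy persistent buffer whose only job is to tell forward() which device the module is on.

```python
# Pre-PR pattern (sketch): "weights" is a plain attribute, so module.to(device)
# ignores it; the dummy "_float_tensor" buffer exists only so forward() can
# discover the module's device/dtype and move "weights" there manually.
import torch
import torch.nn as nn


class OldStyleSketch(nn.Module):
    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__()
        self.weights = torch.zeros(num_positions, embedding_dim)  # not a buffer
        self.register_buffer("_float_tensor", torch.FloatTensor(1))  # device tracker only

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        # Manual device sync on every call, which the buffer-based approach removes.
        self.weights = self.weights.to(self._float_tensor)
        return self.weights[positions]
```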

@dianaml0 (Contributor) commented:

Does this keep backwards compatibility?

@MaigoAkisame (Contributor, Author) replied:

> Does this keep backwards compatibility?

Yes. The weights buffer is created when a SinusoidalPositionalEmbedding object is constructed; it doesn't need to be loaded from a state_dict.
With the changes in this PR, whether or not a state_dict contains the keys weights and _float_tensor, those keys will be ignored.

@MaigoAkisame (Contributor, Author) added:

By the way, in the discussion on the pytorch issue, they said it's OK to override the _load_from_state_dict method.

@dianaml0 (Contributor) left a comment:

LGTM!

@dianaml0 merged commit 31fba01 into facebookresearch:main on Jun 23, 2023
1 of 5 checks passed