Register weights as a non-persistent buffer of SinusoidalPositionalEmbedding #5213
+15 −49
Before submitting
Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
Discussed this in a GitHub issue of the pytorch repo.
Did you read the contributor guideline?
Did you make sure to update the docs?
Not applicable
Did you write any new necessary tests?
No, but I've tested `deeplearning/projects/fairseq-py:test_cpu` in Meta's fbcode repo, and this diff does not introduce any new test failures.

What does this PR do?
The module `SinusoidalPositionalEmbedding` has the problem that its `weights` attribute is not moved to CPU or CUDA when the module is moved. Registering `weights` as a buffer solves the problem. This also eliminates the need for the buffer `_float_tensor`, which is used only to keep track of whether the module is on CPU or CUDA.
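For illustration, here is a minimal sketch of the change, assuming a stripped-down module (the `_get_embedding` placeholder stands in for fairseq's actual sinusoidal table computation; it is not the PR's code):

```python
import torch
import torch.nn as nn


class SinusoidalPositionalEmbedding(nn.Module):
    """Stripped-down sketch; the real fairseq module builds a sinusoidal table."""

    def __init__(self, embedding_dim: int, num_embeddings: int):
        super().__init__()
        weights = self._get_embedding(num_embeddings, embedding_dim)
        # Before this PR, `self.weights = weights` was a plain attribute,
        # so `.to()`, `.cuda()`, etc. silently skipped it, and a dummy
        # `_float_tensor` buffer tracked the module's device instead.
        # A non-persistent buffer moves with the module and is excluded
        # from the state_dict.
        self.register_buffer("weights", weights, persistent=False)

    @staticmethod
    def _get_embedding(num_embeddings: int, embedding_dim: int) -> torch.Tensor:
        # Placeholder for the actual sinusoidal computation.
        return torch.zeros(num_embeddings, embedding_dim)
```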
Making `weights` a non-persistent buffer means it won't be saved to or loaded from a `state_dict`. With the changes in this diff, the `state_dict` of a `SinusoidalPositionalEmbedding` module should contain neither `weights` nor `_float_tensor`.
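Assuming the sketch above, a quick check that the buffer follows device moves but stays out of the `state_dict`:

```python
emb = SinusoidalPositionalEmbedding(embedding_dim=4, num_embeddings=8)

# The buffer follows device moves, unlike a plain attribute.
if torch.cuda.is_available():
    emb = emb.cuda()
    assert emb.weights.is_cuda

# Non-persistent buffers never appear in the state_dict,
# and _float_tensor is gone entirely.
assert "weights" not in emb.state_dict()
assert "_float_tensor" not in emb.state_dict()
```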
Checkpoints saved before this change may still contain those tensors. This diff ignores them by overriding the `_load_from_state_dict` method of the `SinusoidalPositionalEmbedding` module, instead of duplicating the code in many `upgrade_state_dict` functions.

TO DISCUSS: Is it OK for me to override `_load_from_state_dict`? It's a private method, but I see that people have overridden it in many places, including in fairseq:
https://github.com/search?q=super()._load_from_state_dict&type=code
https://github.com/search?q=repo%3Afacebookresearch%2Ffairseq%20super()._load_from_state_dict&type=code
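A sketch of what such an override could look like, extending the simplified module above rather than quoting the PR itself (the signature matches `nn.Module._load_from_state_dict`; the popped keys are the two this diff removes):

```python
class SinusoidalPositionalEmbedding(nn.Module):
    # ... __init__ and _get_embedding as in the sketch above ...

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              strict, missing_keys, unexpected_keys,
                              error_msgs):
        # Checkpoints written before this change may still carry these
        # keys; drop them so they are not reported as unexpected.
        for name in ("weights", "_float_tensor"):
            state_dict.pop(prefix + name, None)
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)
```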
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃