Register weights as a non-persistent buffer of `SinusoidalPositionalEmbedding` (#5213)
MaigoAkisame committed Jun 23, 2023
1 parent a29952c commit 31fba01
Showing 6 changed files with 15 additions and 49 deletions.
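
At a high level, the change swaps the old pattern (a plain `self.weights` attribute plus a dummy `_float_tensor` buffer used only to track device and dtype) for a single buffer registered with `persistent=False`, so the sinusoidal table moves with the module on `.to()` / `.cuda()` but is never written to checkpoints, and the per-model `upgrade_state_dict_named` workarounds become unnecessary. A minimal sketch of that PyTorch behaviour, using a made-up `Demo` module rather than any fairseq class:

```python
import torch
from torch import nn


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # Persistent buffer: saved by state_dict() and restored on load.
        self.register_buffer("persistent_buf", torch.zeros(3))
        # Non-persistent buffer: skipped by state_dict(), but still moved and
        # cast together with the module, so no dummy device-tracking tensor is needed.
        self.register_buffer("non_persistent_buf", torch.zeros(3), persistent=False)


m = Demo()
print(list(m.state_dict().keys()))  # ['persistent_buf'] -- the table is not checkpointed
m = m.to(torch.float64)
print(m.non_persistent_buf.dtype)   # torch.float64 -- the buffer followed the cast
```

The diffs below remove the code that previously maintained the `_float_tensor` entry and scrubbed stale `embed_positions.weights` keys by hand.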
@@ -441,7 +441,6 @@ def convert_to_pipeline_parallel_state_dict(self, state_dict):
                 # fmt: off
                 if isinstance(module, TransformerEncoderEmbedding):
                     new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight']
-                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['encoder.embed_positions._float_tensor']
                 if isinstance(module, TransformerEncoderLayer):
                     for suffix in encoder_key_suffixes:
                         new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'encoder.layers.{encoder_layer_idx}.{suffix}']
@@ -456,7 +455,6 @@ def convert_to_pipeline_parallel_state_dict(self, state_dict):
                         new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.bias'] = state_dict['encoder.layer_norm.bias']
                 if isinstance(module, TransformerDecoderEmbedding):
                     new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight']
-                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['decoder.embed_positions._float_tensor']
                 if isinstance(module, TransformerDecoderOutputLayer):
                     new_state_dict[f'model.partitions.{pid}.{mid}.output_projection.weight'] = state_dict['decoder.output_projection.weight']
                 # fmt: on
@@ -741,14 +739,6 @@ def buffered_future_mask(self, tensor):
 
     def upgrade_state_dict_named(self, state_dict, name):
         """Upgrade a (possibly old) state dict for new versions of fairseq."""
-        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
-            weights_key = "{}.embed_positions.weights".format(name)
-            if weights_key in state_dict:
-                del state_dict[weights_key]
-            state_dict[
-                "{}.embed_positions._float_tensor".format(name)
-            ] = torch.FloatTensor(1)
-
         for i in range(len(self.layers)):
             # update layer norms
             layer_norm_map = {
6 changes: 0 additions & 6 deletions fairseq/models/masked_lm.py
@@ -294,12 +294,6 @@ def max_positions(self):
         return self.max_positions
 
     def upgrade_state_dict_named(self, state_dict, name):
-        if isinstance(
-            self.sentence_encoder.embed_positions, SinusoidalPositionalEmbedding
-        ):
-            state_dict[
-                name + ".sentence_encoder.embed_positions._float_tensor"
-            ] = torch.FloatTensor(1)
         if not self.load_softmax:
             for k in list(state_dict.keys()):
                 if (
8 changes: 0 additions & 8 deletions fairseq/models/transformer/transformer_decoder.py
@@ -399,14 +399,6 @@ def buffered_future_mask(self, tensor):
 
     def upgrade_state_dict_named(self, state_dict, name):
         """Upgrade a (possibly old) state dict for new versions of fairseq."""
-        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
-            weights_key = "{}.embed_positions.weights".format(name)
-            if weights_key in state_dict:
-                del state_dict[weights_key]
-            state_dict[
-                "{}.embed_positions._float_tensor".format(name)
-            ] = torch.FloatTensor(1)
-
         if f"{name}.output_projection.weight" not in state_dict:
             if self.share_input_output_embed:
                 embed_out_key = f"{name}.embed_tokens.weight"
8 changes: 0 additions & 8 deletions fairseq/models/transformer/transformer_decoder_aug.py
@@ -305,14 +305,6 @@ def extract_features_scriptable(
 
     def upgrade_state_dict_named(self, state_dict, name):
         """Upgrade a (possibly old) state dict for new versions of fairseq."""
-        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
-            weights_key = "{}.embed_positions.weights".format(name)
-            if weights_key in state_dict:
-                del state_dict[weights_key]
-            state_dict[
-                "{}.embed_positions._float_tensor".format(name)
-            ] = torch.FloatTensor(1)
-
         if f"{name}.output_projection.weight" not in state_dict:
             if self.share_input_output_embed:
                 embed_out_key = f"{name}.embed_tokens.weight"
8 changes: 0 additions & 8 deletions fairseq/models/transformer/transformer_encoder.py
@@ -331,14 +331,6 @@ def max_positions(self):
 
     def upgrade_state_dict_named(self, state_dict, name):
         """Upgrade a (possibly old) state dict for new versions of fairseq."""
-        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
-            weights_key = "{}.embed_positions.weights".format(name)
-            if weights_key in state_dict:
-                print("deleting {0}".format(weights_key))
-                del state_dict[weights_key]
-            state_dict[
-                "{}.embed_positions._float_tensor".format(name)
-            ] = torch.FloatTensor(1)
         for i in range(self.num_layers):
             # update layer norms
             self.layers[i].upgrade_state_dict_named(
24 changes: 15 additions & 9 deletions fairseq/modules/sinusoidal_positional_embedding.py
@@ -9,7 +9,7 @@
 import torch
 import torch.onnx.operators
 from fairseq import utils
-from torch import Tensor, nn
+from torch import nn, Tensor
 
 
 class SinusoidalPositionalEmbedding(nn.Module):
@@ -22,16 +22,23 @@ def __init__(self, embedding_dim, padding_idx, init_size=1024):
         super().__init__()
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx if padding_idx is not None else 0
-        self.weights = SinusoidalPositionalEmbedding.get_embedding(
+        self.register_buffer("weights", SinusoidalPositionalEmbedding.get_embedding(
             init_size, embedding_dim, padding_idx
-        )
-        self.onnx_trace = False
-        self.register_buffer("_float_tensor", torch.FloatTensor(1))
+        ), persistent=False)
         self.max_positions = int(1e5)
+        self.onnx_trace = False
 
     def prepare_for_onnx_export_(self):
         self.onnx_trace = True
 
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        # Ignore some deprecated keys that were used in older versions
+        deprecated_keys = ["weights", "_float_tensor"]
+        for key in deprecated_keys:
+            if prefix + key in state_dict:
+                del state_dict[prefix + key]
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
     @staticmethod
     def get_embedding(
         num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
@@ -68,12 +75,11 @@ def forward(
         bspair = torch.onnx.operators.shape_as_tensor(input)
         bsz, seq_len = bspair[0], bspair[1]
         max_pos = self.padding_idx + 1 + seq_len
-        if self.weights is None or max_pos > self.weights.size(0):
-            # recompute/expand embeddings if needed
+        if max_pos > self.weights.size(0):
+            # expand embeddings if needed
             self.weights = SinusoidalPositionalEmbedding.get_embedding(
                 max_pos, self.embedding_dim, self.padding_idx
-            )
-            self.weights = self.weights.to(self._float_tensor)
+            ).to(self.weights)
 
         if incremental_state is not None:
             # positions is the same for every token when decoding a single step
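The `_load_from_state_dict` override added above is what keeps older checkpoints loadable: the deprecated `weights` and `_float_tensor` entries are dropped from the incoming state dict before the default loading logic runs, so strict loading does not report them as unexpected keys. A rough standalone sketch of the same pattern (the `CompatModule` class and its buffer/key names are illustrative, not fairseq's):

```python
import torch
from torch import nn


class CompatModule(nn.Module):
    """Recomputes its lookup table instead of loading it from checkpoints."""

    def __init__(self):
        super().__init__()
        # Deterministic table: kept out of state_dict() entirely.
        self.register_buffer("table", torch.arange(4.0), persistent=False)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Older checkpoints stored these keys; drop them so strict loading passes.
        for key in ("table", "_float_tensor"):
            if prefix + key in state_dict:
                del state_dict[prefix + key]
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


m = CompatModule()
old_checkpoint = {"table": torch.zeros(4), "_float_tensor": torch.FloatTensor(1)}
m.load_state_dict(old_checkpoint, strict=True)  # loads cleanly; stale keys are ignored
```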
