
Commit

fixed RoPE
Varun Gumma committed Apr 23, 2024
1 parent 09b44f6 commit bb64f47
Showing 4 changed files with 11 additions and 20 deletions.
10 changes: 5 additions & 5 deletions fairseq/criterions/label_smoothed_cross_entropy.py
@@ -150,11 +150,11 @@ def reduce_metrics(cls, logging_outputs) -> None:
             metrics.log_scalar("n_correct", n_correct)
             metrics.log_derived(
                 "accuracy",
-                lambda meters: round(
-                    meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
-                )
-                if meters["total"].sum > 0
-                else float("nan"),
+                lambda meters: (
+                    round(meters["n_correct"].sum * 100.0 / meters["total"].sum, 3)
+                    if meters["total"].sum > 0
+                    else float("nan")
+                ),
             )
 
     @staticmethod
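The change above is purely cosmetic: the conditional lambda is wrapped in parentheses (black-style) without changing what it computes. As a minimal, self-contained sketch (not part of the commit) of what the derived metric evaluates to:

    # Percentage accuracy rounded to 3 decimals, or NaN when nothing was counted;
    # the two arguments stand in for fairseq's meter sums.
    def derived_accuracy(n_correct_sum: float, total_sum: float) -> float:
        return (
            round(n_correct_sum * 100.0 / total_sum, 3)
            if total_sum > 0
            else float("nan")
        )

    print(derived_accuracy(423, 512))  # 82.617
    print(derived_accuracy(0, 0))      # nan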
8 changes: 0 additions & 8 deletions fairseq/models/transformer/transformer_config.py
@@ -265,14 +265,6 @@ class TransformerConfig(FairseqDataclass):
             "help": "use learned frequencies for RoPE instead of fixed frequencies"
         },
     )
-    rope_use_xpos: Optional[bool] = field(
-        default=False,
-        metadata={"help": "decay RoPE similar to ALiBi"},
-    )
-    rope_xpos_scale_base: Optional[int] = field(
-        default=512,
-        metadata={"help": "base for scaling the positional encoding"},
-    )
     rope_interpolate_factor: Optional[float] = field(
         default=1,
         metadata={"help": "interpolation factor for RoPE"},
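rope_interpolate_factor, which is kept above, implements position interpolation: positions are scaled down before the rotary angles are computed, so a model trained at one context length can be evaluated at a longer one. A hedged sketch of the idea, mirroring how the rotary-embedding-torch package applies its interpolate_factor (not code from this repository):

    import torch

    def rope_angles(seq_len, dim, interpolate_factor=1.0, theta=10000.0):
        # Standard RoPE inverse frequencies ...
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        # ... applied to positions shrunk by the interpolation factor, so longer
        # sequences stay inside the angle range seen during training.
        t = torch.arange(seq_len).float() / interpolate_factor
        return torch.outer(t, inv_freq)  # (seq_len, dim // 2)

    # With factor 2, position 1024 gets the same angles that position 512 had
    # without interpolation.
    assert torch.allclose(
        rope_angles(2048, 64, interpolate_factor=2.0)[1024],
        rope_angles(1024, 64)[512],
    )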
12 changes: 5 additions & 7 deletions fairseq/modules/native_multihead_attention.py
@@ -10,7 +10,6 @@
 from torch.nn import Parameter
 
 from fairseq import utils
-from einops import rearrange
 from fairseq.modules.fairseq_dropout import FairseqDropout
 from fairseq.modules.quant_noise import quant_noise
 from fairseq.modules.multihead_attention import MultiheadAttention
@@ -39,8 +38,6 @@ def __init__(
         qn_block_size=8,
         rope=False,
         rope_interpolate_factor=1,
-        rope_use_xpos=False,
-        rope_xpos_scale_base=512,
         rope_learned_freq=False,
     ):
         super().__init__(embed_dim, num_heads, dictionary=dictionary)
@@ -70,11 +67,9 @@ def __init__(
         self.rotary_pos_embed = (
             RotaryEmbedding(
                 dim=self.head_dim,
-                use_xpos=rope_use_xpos,
-                seq_before_head_dim=False,
                 learned_freq=rope_learned_freq,
-                xpos_scale_base=rope_xpos_scale_base,
                 interpolate_factor=rope_interpolate_factor,
+                seq_before_head_dim=False,
             )
             if self.rope
             else None
@@ -273,7 +268,10 @@ def forward(
         if self.rope:
             q_ = q.view(kv_bsz, self.num_heads, -1, self.head_dim)
             k_ = k.view(kv_bsz, self.num_heads, -1, self.head_dim)
-            q_, k_ = self.rotary_pos_embed.rotate_queries_and_keys(q_, k_)
+
+            q_ = self.rotary_pos_embed.rotate_queries_or_keys(q_)
+            k_ = self.rotary_pos_embed.rotate_queries_or_keys(k_)
+
             q = q_.view(kv_bsz * self.num_heads, -1, self.head_dim)
             k = k_.view(kv_bsz * self.num_heads, -1, self.head_dim)
 
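This hunk is the actual RoPE fix. In the rotary-embedding-torch package, rotate_queries_and_keys is the xpos code path (it requires use_xpos=True), so with the xpos options removed above, queries and keys are now rotated independently via rotate_queries_or_keys. A minimal sketch of the new pattern, assuming that package provides the RotaryEmbedding imported here (not code from this repository):

    import torch
    from rotary_embedding_torch import RotaryEmbedding

    bsz, num_heads, seq_len, head_dim = 2, 8, 16, 64

    rope = RotaryEmbedding(
        dim=head_dim,
        learned_freq=False,         # mirrors rope_learned_freq
        interpolate_factor=1.0,     # mirrors rope_interpolate_factor
        seq_before_head_dim=False,  # tensors are (batch, heads, seq, head_dim)
    )

    q = torch.randn(bsz, num_heads, seq_len, head_dim)
    k = torch.randn(bsz, num_heads, seq_len, head_dim)

    # Plain RoPE applies the same position-dependent rotation to both sides,
    # so each tensor can be rotated on its own.
    q = rope.rotate_queries_or_keys(q)
    k = rope.rotate_queries_or_keys(k)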
1 change: 1 addition & 0 deletions fairseq/utils.py
@@ -11,6 +11,7 @@
 import logging
 import os
 import sys
+import math
 import warnings
 from itertools import accumulate
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional
