Fix DiverseBeamSearch so that no diversity groups will be dropped. (#…
shuminghu committed Apr 12, 2023
1 parent 176cd93 commit 3f6ba43
Showing 1 changed file with 91 additions and 13 deletions.
104 changes: 91 additions & 13 deletions fairseq/search.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import math

from typing import List, Optional

import torch
@@ -113,6 +114,7 @@ def step(
scores: Optional[Tensor],
prev_output_tokens: Optional[Tensor] = None,
original_batch_idxs: Optional[Tensor] = None,
+candidate_multiple: int = 2,
):
bsz, beam_size, vocab_size = lprobs.size()

@@ -128,9 +130,9 @@
top_prediction = torch.topk(
lprobs.view(bsz, -1),
k=min(
-# Take the best 2 x beam_size predictions. We'll choose the first
+# Take the best `candidate_multiple` (default 2) x beam_size predictions. We'll choose the first
# beam_size of these which don't predict eos to continue with.
-beam_size * 2,
+candidate_multiple * beam_size,
lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad
),
)
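[Editorial note] A minimal sketch of why the search keeps `candidate_multiple * beam_size` candidates (the dimensions below are hypothetical toy values, not from the commit): each of the `beam_size` beams contains exactly one EOS entry, so at most `beam_size` of the selected candidates can be EOS, leaving at least `beam_size` non-EOS continuations.

```python
import torch

# Toy dimensions (hypothetical, for illustration only).
bsz, beam_size, vocab_size = 1, 2, 6
candidate_multiple, eos = 2, 5

# Flattened log-probs over (beam x vocab), as in Search.step.
lprobs = torch.randn(bsz, beam_size * vocab_size)

# Keep candidate_multiple * beam_size candidates so the generator can drop
# any that predicted EOS and still fill beam_size continuations.
k = min(candidate_multiple * beam_size, lprobs.size(1) - 1)
scores_buf, indices_buf = torch.topk(lprobs, k=k)
tokens = indices_buf % vocab_size  # token ids within the vocab (the real
# code splits flat indices into beam and token similarly)

# At most beam_size EOS entries exist in the flattened row, so this holds:
assert (tokens != eos).sum() >= beam_size
```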
@@ -554,15 +556,57 @@ class DiverseBeamSearch(Search):
See "Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence
Models" for details.
-We only implement the Hamming Diversity penalty here, which performed best
-in the original paper.
+We implement the cumulative diversity penalty here as the default, optionally provide the
+Hamming diversity described in the original paper, and a way to interpolate between the two
+through diversity_discount.
+Take the example below to illustrate the cumulative diversity penalty:
+A) I like dogs.
+B) I like ____.
+C) There are ___.
+We are at step=2, trying to fill in the blank:
+Hamming diversity:
+    Penalty for B from A is 1 for "dogs" and 0 for any other word like "cats".
+    Penalty for C from A is 1 for "dogs" and 0 for any other word like "cats".
+Cumulative diversity (default):
+    Penalty for B from A is 3 for "dogs" and 0 for any other word like "cats".
+    Penalty for C from A is 1 for "dogs" and 0 for any other word like "cats".
+    B and C differ because B also matches A on "I" and "like" at their respective steps,
+    incurring a cumulative penalty of 2.
+Using diversity_discount to interpolate between the two:
+    If diversity_discount = 0.5, then
+    Penalty for B from A is 1.75 (1 + 0.5 + 0.25) for "dogs" and 0 for any other word like "cats".
+    Penalty for C from A is 1 for "dogs" and 0 for any other word like "cats".
+    "I" and "like" matched between B and A at steps 0 and 1 respectively. Since "I" is two
+    steps away and "like" is one step away, they are discounted by (0.5)^2 and 0.5 respectively.
+When diversity_discount = 0, we recover Hamming diversity; when diversity_discount = 1, we
+recover cumulative diversity.
+NB: During beam search for each diversity group, `candidate_multiple` is set to 1 rather than
+the BeamSearch default (2). This ensures we end up with exactly `beam_size` final candidates,
+so that no diversity group is dropped during final token selection in sequence generation.
+For full backwards compatibility, use diversity_discount=0 and candidate_multiple=2.
"""

-def __init__(self, tgt_dict, num_groups, diversity_strength):
+def __init__(
+    self,
+    tgt_dict,
+    num_groups,
+    diversity_strength,
+    diversity_discount=1.0,
+    candidate_multiple=1,
+):
super().__init__(tgt_dict)
self.num_groups = num_groups
self.diversity_strength = -diversity_strength
self.beam = BeamSearch(tgt_dict)
+self.diversity_discount = diversity_discount
+self.candidate_multiple = candidate_multiple
+
+# Float tensor to keep track of the overlap between groups.
+# Each token shared at the same step between two groups is counted as one.
+# Token counts are then discounted by `diversity_discount` at every subsequent timestep.
+# Once initialized, its shape is (batch_size, num_groups, num_groups).
+self.group_overlap = torch.empty(0)
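[Editorial note] A minimal sketch of the recurrence this buffer implements (toy shapes; the loop below is illustrative, not from the commit): new overlap counts arrive pre-scaled by the discount, and the stored history decays by the discount again each step, so a match k steps in the past contributes discount**k to the penalty `1 + group_overlap`.

```python
import torch

bsz, num_groups = 2, 3
diversity_discount = 0.5
group_overlap = torch.zeros(bsz, num_groups, num_groups)

for step in range(3):
    # Pretend each group pair shared exactly one token at this step
    # (really it is computed from indices_G_stacked, see step() below).
    raw_matches = torch.ones(bsz, num_groups, num_groups)
    overlap = diversity_discount * raw_matches
    group_overlap = group_overlap * diversity_discount + overlap

# After 3 steps with discount 0.5: 0.5 + 0.25 + 0.125 per pair.
print(group_overlap[0, 0, 1])  # tensor(0.8750)
```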

@torch.jit.export
def step(
@@ -582,13 +626,38 @@ def step(
# initialize diversity penalty
diversity_buf = torch.zeros(lprobs[:, 0, :].size()).to(lprobs)

-scores_G, indices_G, beams_G = [], [], []
+scores_G, beams_G = [], []

+# pre-allocating tensor for indices for all groups
+indices_G_stacked = torch.empty(
+    bsz,
+    int(beam_size / self.num_groups) * self.candidate_multiple,
+    self.num_groups,
+    dtype=torch.long,
+    device=lprobs.device,
+)

for g in range(self.num_groups):
lprobs_g = lprobs[:, g :: self.num_groups, :]
scores_g = scores[:, g :: self.num_groups, :] if step > 0 else None

+diversity_buf.zero_()
# apply diversity penalty
if g > 0:
+indices_ = indices_G_stacked[:, :, :g]
+if step > 0:
+    penalty_val = 1 + self.group_overlap[original_batch_idxs, g, :g]
+    penalty_val = penalty_val.unsqueeze(1)
+else:
+    penalty_val = torch.ones(bsz, 1, 1)
+diversity_buf.scatter_add_(
+    1,
+    indices_.reshape(bsz, -1),
+    penalty_val.expand(indices_.size())
+    .reshape(bsz, -1)
+    .to(diversity_buf),
+)

lprobs_g = torch.add(
lprobs_g,
other=diversity_buf.unsqueeze(1),
@@ -598,23 +667,32 @@
lprobs_g = lprobs_g.contiguous()

scores_buf, indices_buf, beams_buf = self.beam.step(
-step, lprobs_g, scores_g
+step, lprobs_g, scores_g, candidate_multiple=self.candidate_multiple
)
beams_buf.mul_(self.num_groups).add_(g)

scores_G.append(scores_buf.clone())
-indices_G.append(indices_buf.clone())
beams_G.append(beams_buf.clone())

-# update diversity penalty
-diversity_buf.scatter_add_(
-    1, indices_buf, torch.ones(indices_buf.size()).to(diversity_buf)
-)
+indices_G_stacked[:, :, g] = indices_buf

# interleave results from different groups
scores_buf = torch.stack(scores_G, dim=2).view(bsz, -1)
-indices_buf = torch.stack(indices_G, dim=2).view(bsz, -1)
+indices_buf = indices_G_stacked.view(bsz, -1)
beams_buf = torch.stack(beams_G, dim=2).view(bsz, -1)
+# find the number of overlapping tokens for each group pair,
+# then discount it for the next timestep
+overlap = self.diversity_discount * torch.sum(
+    indices_G_stacked.unsqueeze(2).eq(indices_G_stacked.unsqueeze(3)), dim=1
+)
+if step == 0:
+    self.group_overlap = overlap
+else:
+    self.group_overlap[original_batch_idxs] = (
+        self.group_overlap[original_batch_idxs] * self.diversity_discount
+        + overlap
+    )

return scores_buf, indices_buf, beams_buf
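[Editorial note] Finally, a sketch (toy token ids, not from the commit) of the pairwise overlap count used in the update above: broadcasting `indices_G_stacked` against itself compares every pair of groups slot by slot, and summing over the candidate dimension counts how many slots picked the same token.

```python
import torch

# (bsz=1, candidates_per_group=2, num_groups=3); toy token ids.
indices_G_stacked = torch.tensor([[[7, 7, 3],
                                   [2, 5, 5]]])

# eq() broadcasts to (bsz, cand, num_groups, num_groups); summing over the
# candidate dimension counts same-slot matches for every pair of groups.
overlap = torch.sum(
    indices_G_stacked.unsqueeze(2).eq(indices_G_stacked.unsqueeze(3)), dim=1
)
print(overlap[0])
# tensor([[2, 1, 0],
#         [1, 2, 1],
#         [0, 1, 2]])
# Diagonal = 2 (each group trivially matches itself in both slots);
# groups 0 and 1 share token 7 in slot 0; groups 1 and 2 share token 5 in slot 1.
```

Because each group's beam step now returns `candidate_multiple=1` times its per-group beam, the interleaved `indices_buf` always has exactly `beam_size` entries, which is what prevents any diversity group from being dropped downstream, as the docstring's NB describes.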

