Online backtranslation module
Co-authored-by: liezl200 <lie@fb.com>
myleott and liezl200 committed Sep 25, 2018
1 parent a4fe8c9 commit 864b89d
Showing 4 changed files with 276 additions and 72 deletions.
131 changes: 131 additions & 0 deletions fairseq/data/backtranslation_dataset.py
@@ -0,0 +1,131 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.


from fairseq import sequence_generator

from . import FairseqDataset, language_pair_dataset


class BacktranslationDataset(FairseqDataset):
def __init__(self, args, tgt_dataset, tgt_dict, backtranslation_model):
"""
Sets up a backtranslation dataset which takes a tgt batch, generates
a src using a tgt-src backtranslation_model, and returns the
        corresponding {generated src, input tgt} batch.

        Args:
            args: generation args for the backtranslation SequenceGenerator.
Note that there is no equivalent argparse code for these args
anywhere in our top level train scripts yet. Integration is
still in progress. You can still, however, test out this dataset
functionality with the appropriate args as in the corresponding
unittest: test_backtranslation_dataset.
            tgt_dataset: dataset which will be used to build self.tgt_dataset --
                a LanguagePairDataset with the tgt dataset as the source dataset
                and None as the target dataset.
We use language_pair_dataset here to encapsulate the tgt_dataset
so we can re-use the LanguagePairDataset collater to format the
batches in the structure that SequenceGenerator expects.
tgt_dict: tgt dictionary (typically a joint src/tgt BPE dictionary)
backtranslation_model: tgt-src model to use in the SequenceGenerator
to generate backtranslations from tgt batches
"""
self.tgt_dataset = language_pair_dataset.LanguagePairDataset(
src=tgt_dataset,
src_sizes=None,
src_dict=tgt_dict,
tgt=None,
tgt_sizes=None,
tgt_dict=None,
)
self.backtranslation_generator = sequence_generator.SequenceGenerator(
[backtranslation_model],
tgt_dict,
unk_penalty=args.backtranslation_unkpen,
sampling=args.backtranslation_sampling,
beam_size=args.backtranslation_beam,
)
self.backtranslation_max_len_a = args.backtranslation_max_len_a
self.backtranslation_max_len_b = args.backtranslation_max_len_b
self.backtranslation_beam = args.backtranslation_beam

def __getitem__(self, index):
"""
        Returns a single sample. Multiple samples are fed to the collater to
        create a backtranslation batch. Note that you should always use
        BacktranslationDataset.collater() below as the collate_fn when you
        have the option to specify one (e.g. in a DataLoader that wraps this
        BacktranslationDataset -- see the corresponding unittest for an example).
"""
return self.tgt_dataset[index]

def __len__(self):
"""
The length of the backtranslation dataset is the length of tgt.
"""
return len(self.tgt_dataset)

def collater(self, samples):
"""
        Using the samples from the tgt dataset, build a collated tgt batch to
        feed to the backtranslation model. Then take the generated translation
        with the best score as the source and the original net input as the
        target.
"""
collated_tgt_only_sample = self.tgt_dataset.collater(samples)
backtranslation_hypos = self._generate_hypotheses(collated_tgt_only_sample)

# Go through each tgt sentence in batch and its corresponding best
# generated hypothesis and create a backtranslation data pair
# {id: id, source: generated backtranslation, target: original tgt}
generated_samples = []
for input_sample, hypos in zip(samples, backtranslation_hypos):
generated_samples.append(
{
"id": input_sample["id"],
"source": hypos[0]["tokens"], # first hypo is best hypo
"target": input_sample["source"],
}
)

return language_pair_dataset.collate(
samples=generated_samples,
pad_idx=self.tgt_dataset.src_dict.pad(),
eos_idx=self.tgt_dataset.src_dict.eos(),
)

    def get_dummy_batch(self, num_tokens, max_positions):
        """Just use the tgt dataset's get_dummy_batch."""
        return self.tgt_dataset.get_dummy_batch(num_tokens, max_positions)

    def num_tokens(self, index):
        """Just use the tgt dataset's num_tokens."""
        return self.tgt_dataset.num_tokens(index)

    def ordered_indices(self):
        """Just use the tgt dataset's ordered_indices."""
        return self.tgt_dataset.ordered_indices()

    def valid_size(self, index, max_positions):
        """Just use the tgt dataset's valid_size."""
        return self.tgt_dataset.valid_size(index, max_positions)

def _generate_hypotheses(self, sample):
"""
Generates hypotheses from a LanguagePairDataset collated / batched
        sample. Note that in this case, sample["target"] is None and
        sample["net_input"]["src_tokens"] is actually in the tgt language.
"""
self.backtranslation_generator.cuda()
        net_input = sample["net_input"]
        srclen = net_input["src_tokens"].size(1)
        hypos = self.backtranslation_generator.generate(
            net_input,
            maxlen=int(
                self.backtranslation_max_len_a * srclen + self.backtranslation_max_len_b
            ),
        )
return hypos
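
A minimal usage sketch (not part of the diff), mirroring what the unit test below does. The tgt_dataset, tgt_dict, and model objects are assumptions standing in for your own monolingual data and a trained tgt-src model:

import argparse

import torch
from fairseq.data.backtranslation_dataset import BacktranslationDataset

# Generation args for the backtranslation SequenceGenerator, using the same
# values as the unit test below.
args = argparse.Namespace()
args.backtranslation_unkpen = 0
args.backtranslation_sampling = False
args.backtranslation_max_len_a = 0
args.backtranslation_max_len_b = 200
args.backtranslation_beam = 2

dataset = BacktranslationDataset(
    args=args,
    tgt_dataset=tgt_dataset,        # assumed: monolingual tgt-side dataset
    tgt_dict=tgt_dict,              # assumed: tgt dictionary
    backtranslation_model=model,    # assumed: trained tgt-src model
)

# Always pass the dataset's own collater so that backtranslation runs once
# per batch rather than once per sample.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    collate_fn=dataset.collater,
)
batch = next(iter(dataloader))
# batch["net_input"]["src_tokens"] holds the generated (backtranslated)
# source tokens; batch["target"] holds the original tgt tokens.
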
70 changes: 70 additions & 0 deletions tests/test_backtranslation_dataset.py
@@ -0,0 +1,70 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import argparse
import unittest

import tests.utils as test_utils
import torch
from fairseq.data.backtranslation_dataset import BacktranslationDataset


class TestBacktranslationDataset(unittest.TestCase):
def setUp(self):
self.tgt_dict, self.w1, self.w2, self.src_tokens, self.src_lengths, self.model = (
test_utils.sequence_generator_setup()
)
backtranslation_args = argparse.Namespace()

"""
Same as defaults from fairseq/options.py
"""
backtranslation_args.backtranslation_unkpen = 0
backtranslation_args.backtranslation_sampling = False
backtranslation_args.backtranslation_max_len_a = 0
backtranslation_args.backtranslation_max_len_b = 200
backtranslation_args.backtranslation_beam = 2

self.backtranslation_args = backtranslation_args

dummy_src_samples = self.src_tokens

self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples)

def test_backtranslation_dataset(self):
backtranslation_dataset = BacktranslationDataset(
args=self.backtranslation_args,
tgt_dataset=self.tgt_dataset,
tgt_dict=self.tgt_dict,
backtranslation_model=self.model,
)
dataloader = torch.utils.data.DataLoader(
backtranslation_dataset,
batch_size=2,
collate_fn=backtranslation_dataset.collater,
)
backtranslation_batch_result = next(iter(dataloader))

eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2

        # Note that we sort by src_lengths and add left padding, so the
        # ids will come out in the order [1, 0]
expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
tgt_tokens = backtranslation_batch_result["target"]

self.assertTensorEqual(expected_src, generated_src)
self.assertTensorEqual(expected_tgt, tgt_tokens)

def assertTensorEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertEqual(t1.ne(t2).long().sum(), 0)


if __name__ == "__main__":
unittest.main()
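
The sequence_generator_setup helper used above presumably lives in tests/utils.py (the fourth changed file, not shown in this view). A sketch of the contract these tests appear to assume, inferred from how its return value is unpacked and from the old inline setup removed in test_sequence_generator.py below:

# Inferred contract of test_utils.sequence_generator_setup() -- illustrative
# comments only, not the actual helper body:
#
#   tgt_dict, w1, w2, src_tokens, src_lengths, model = (
#       test_utils.sequence_generator_setup()
#   )
#
#   tgt_dict     -- dummy dictionary with pad=1, eos=2, unk=3
#   w1, w2       -- two ordinary word ids (4 and 5)
#   src_tokens   -- LongTensor([[w1, w2, eos], [w1, w2, eos]])
#   src_lengths  -- LongTensor([2, 2])
#   model        -- toy model with fixed per-step output distributions,
#                   so beam search results are deterministic
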
82 changes: 10 additions & 72 deletions tests/test_sequence_generator.py
@@ -18,79 +18,17 @@
class TestSequenceGenerator(unittest.TestCase):

def setUp(self):
-        # construct dummy dictionary
-        d = test_utils.dummy_dictionary(vocab_size=2)
-        self.assertEqual(d.pad(), 1)
-        self.assertEqual(d.eos(), 2)
-        self.assertEqual(d.unk(), 3)
-        self.eos = d.eos()
-        self.w1 = 4
-        self.w2 = 5
-
-        # construct source data
-        self.src_tokens = torch.LongTensor([
-            [self.w1, self.w2, self.eos],
-            [self.w1, self.w2, self.eos],
-        ])
-        self.src_lengths = torch.LongTensor([2, 2])
+        self.tgt_dict, self.w1, self.w2, src_tokens, src_lengths, self.model = (
+            test_utils.sequence_generator_setup()
+        )
        self.encoder_input = {
-            'src_tokens': self.src_tokens,
-            'src_lengths': self.src_lengths,
+            'src_tokens': src_tokens, 'src_lengths': src_lengths,
}

-        args = argparse.Namespace()
-        unk = 0.
-        args.beam_probs = [
-            # step 0:
-            torch.FloatTensor([
-                # eos          w1   w2
-                # sentence 1:
-                [0.0, unk, 0.9, 0.1],  # beam 1
-                [0.0, unk, 0.9, 0.1],  # beam 2
-                # sentence 2:
-                [0.0, unk, 0.7, 0.3],
-                [0.0, unk, 0.7, 0.3],
-            ]),
-            # step 1:
-            torch.FloatTensor([
-                # eos          w1   w2   prefix
-                # sentence 1:
-                [1.0, unk, 0.0, 0.0],  # w1: 0.9  (emit: w1 <eos>: 0.9*1.0)
-                [0.0, unk, 0.9, 0.1],  # w2: 0.1
-                # sentence 2:
-                [0.25, unk, 0.35, 0.4],  # w1: 0.7  (don't emit: w1 <eos>: 0.7*0.25)
-                [0.00, unk, 0.10, 0.9],  # w2: 0.3
-            ]),
-            # step 2:
-            torch.FloatTensor([
-                # eos          w1   w2   prefix
-                # sentence 1:
-                [0.0, unk, 0.1, 0.9],  # w2 w1: 0.1*0.9
-                [0.6, unk, 0.2, 0.2],  # w2 w2: 0.1*0.1  (emit: w2 w2 <eos>: 0.1*0.1*0.6)
-                # sentence 2:
-                [0.60, unk, 0.4, 0.00],  # w1 w2: 0.7*0.4  (emit: w1 w2 <eos>: 0.7*0.4*0.6)
-                [0.01, unk, 0.0, 0.99],  # w2 w2: 0.3*0.9
-            ]),
-            # step 3:
-            torch.FloatTensor([
-                # eos          w1   w2   prefix
-                # sentence 1:
-                [1.0, unk, 0.0, 0.0],  # w2 w1 w2: 0.1*0.9*0.9  (emit: w2 w1 w2 <eos>: 0.1*0.9*0.9*1.0)
-                [1.0, unk, 0.0, 0.0],  # w2 w1 w1: 0.1*0.9*0.1  (emit: w2 w1 w1 <eos>: 0.1*0.9*0.1*1.0)
-                # sentence 2:
-                [0.1, unk, 0.5, 0.4],  # w2 w2 w2: 0.3*0.9*0.99  (emit: w2 w2 w2 <eos>: 0.3*0.9*0.99*0.1)
-                [1.0, unk, 0.0, 0.0],  # w1 w2 w1: 0.7*0.4*0.4  (emit: w1 w2 w1 <eos>: 0.7*0.4*0.4*1.0)
-            ]),
-        ]
-
-        task = test_utils.TestTranslationTask.setup_task(args, d, d)
-        self.model = task.build_model(args)
-        self.tgt_dict = task.target_dictionary

def test_with_normalization(self):
generator = SequenceGenerator([self.model], self.tgt_dict)
hypos = generator.generate(self.encoder_input, beam_size=2)
-        eos, w1, w2 = self.eos, self.w1, self.w2
+        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0])
@@ -109,7 +47,7 @@ def test_without_normalization(self):
# Sentence 2: beams swap order
generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False)
hypos = generator.generate(self.encoder_input, beam_size=2)
-        eos, w1, w2 = self.eos, self.w1, self.w2
+        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0], normalized=False)
@@ -127,7 +65,7 @@ def test_with_lenpen_favoring_short_hypos(self):
lenpen = 0.6
generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
hypos = generator.generate(self.encoder_input, beam_size=2)
-        eos, w1, w2 = self.eos, self.w1, self.w2
+        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0], lenpen=lenpen)
@@ -145,7 +83,7 @@ def test_with_lenpen_favoring_long_hypos(self):
lenpen = 5.0
generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
hypos = generator.generate(self.encoder_input, beam_size=2)
-        eos, w1, w2 = self.eos, self.w1, self.w2
+        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])
self.assertHypoScore(hypos[0][0], [0.1, 0.9, 0.9, 1.0], lenpen=lenpen)
@@ -162,7 +100,7 @@ def test_with_lenpen_favoring_long_hypos(self):
def test_maxlen(self):
generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2)
hypos = generator.generate(self.encoder_input, beam_size=2)
-        eos, w1, w2 = self.eos, self.w1, self.w2
+        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0])
@@ -179,7 +117,7 @@ def test_maxlen(self):
def test_no_stop_early(self):
generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False)
hypos = generator.generate(self.encoder_input, beam_size=2)
-        eos, w1, w2 = self.eos, self.w1, self.w2
+        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0])
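
As a hand-check on the expected scores in these tests, here is a small arithmetic sketch of the scoring that the assertions encode, under the assumption (matching assertHypoScore as used above) that a hypothesis score is the sum of per-step log-probs, divided by length ** lenpen when normalized:

import math

# sentence 1, beam 1 emits [w1, eos] with per-step probs [0.9, 1.0]
pos_scores = [math.log(0.9), math.log(1.0)]
raw = sum(pos_scores)                        # unnormalized: log(0.9) ~ -0.105
normalized = raw / len(pos_scores)           # default: divide by hypo length
lenpen = 0.6
penalized = raw / len(pos_scores) ** lenpen  # lenpen < 1 favors short hypos;
                                             # lenpen > 1 favors long ones
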
