0.6.1 -> 0.6.2 (#577)
Summary:
Changelog:
- 998ba4f: Add language models from Baevski & Auli (2018)
- 4294c4f: Add mixture of experts code from Shen et al. (2019)
- 0049349: Add example for multilingual training
- 48d9afb: Speed improvements, including fused operators from apex
- 44d27e6: Add Tensorboard support
- d17fa85: Add Adadelta optimizer
- 9e1c880: Add `FairseqEncoderModel`
- b65c579: Add `FairseqTask.inference_step` to modularize generate.py (see the sketch below)
- 2ad1178: Add back `--curriculum`
- Misc bug fixes and other features
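
A minimal sketch of how the new `FairseqTask.inference_step` hook can be used from a generation script. The helper name `generate_all` and its argument list are ours for illustration; only the `task.inference_step(generator, models, sample)` call itself comes from this release.

```python
from fairseq import utils

def generate_all(task, generator, models, itr, use_cuda=True):
    """Illustrative only: route decoding through FairseqTask.inference_step
    so the calling script does not invoke the sequence generator directly."""
    for sample in itr:
        if 'net_input' not in sample:
            continue
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        # Tasks (translation, language modeling, multilingual, ...) can override
        # inference_step to customize decoding without touching generate.py.
        hypos = task.inference_step(generator, models, sample)
        yield hypos
```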

Pull Request resolved: #577

Differential Revision: D14481233

Pulled By: myleott

fbshipit-source-id: 4ff8625ef1c0b24273fc65df7c5658e3c932e8b7
myleott authored and facebook-github-bot committed Mar 15, 2019
1 parent 48d9afb commit e642252
Showing 20 changed files with 300 additions and 90 deletions.
7 changes: 4 additions & 3 deletions README.md
@@ -5,7 +5,7 @@ developers to train custom models for translation, summarization, language
modeling and other text generation tasks. It provides reference implementations
of various sequence-to-sequence models, including:
- **Convolutional Neural Networks (CNN)**
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/conv_lm/README.md)
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/language_model/conv_lm/README.md)
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md)
- [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
@@ -18,7 +18,8 @@ of various sequence-to-sequence models, including:
- [Vaswani et al. (2017): Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- [Ott et al. (2018): Scaling Neural Machine Translation](examples/scaling_nmt/README.md)
- [Edunov et al. (2018): Understanding Back-Translation at Scale](examples/backtranslation/README.md)
- **_New_** [Shen et al. (2019) Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
- **_New_** [Baevski and Auli (2018): Adaptive Input Representations for Neural Language Modeling](examples/language_model/transformer_lm/README.md)
- **_New_** [Shen et al. (2019): Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)

Fairseq features:
- multi-GPU (distributed) training on one machine or across multiple machines
@@ -88,7 +89,7 @@ We also have more detailed READMEs to reproduce results from specific papers:
- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
- [Ott et al. (2018): Scaling Neural Machine Translation](examples/scaling_nmt/README.md)
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md)
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/conv_lm/README.md)
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/language_model/conv_lm/README.md)

# Join the fairseq community

4 changes: 2 additions & 2 deletions docs/conf.py
@@ -60,9 +60,9 @@
# built documents.
#
# The short X.Y version.
version = '0.6.1'
version = '0.6.2'
# The full version, including alpha/beta/rc tags.
release = '0.6.1'
release = '0.6.2'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
47 changes: 34 additions & 13 deletions eval_lm.py
@@ -14,6 +14,7 @@
import torch

from fairseq import options, progress_bar, tasks, utils
from fairseq.data import LMContextWindowDataset
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
from fairseq.utils import import_user_module
@@ -65,11 +66,22 @@ def main(parsed_args):
for arg in vars(parsed_args).keys():
if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
setattr(args, arg, getattr(parsed_args, arg))

# reduce tokens per sample by the required context window size
args.tokens_per_sample -= args.context_window
task = tasks.setup_task(args)

# Load dataset splits
task.load_dataset(args.gen_subset)
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))
dataset = task.dataset(args.gen_subset)
if args.context_window > 0:
dataset = LMContextWindowDataset(
dataset=dataset,
tokens_per_sample=args.tokens_per_sample,
context_window=args.context_window,
pad_idx=task.source_dictionary.pad(),
)
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

# Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
for model in models:
@@ -84,7 +96,7 @@ def main(parsed_args):
print('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

itr = task.get_batch_iterator(
dataset=task.dataset(args.gen_subset),
dataset=dataset,
max_tokens=args.max_tokens or 36000,
max_sentences=args.max_sentences,
max_positions=utils.resolve_max_positions(*[
Expand All @@ -97,7 +109,7 @@ def main(parsed_args):
).next_epoch_itr(shuffle=False)

gen_timer = StopwatchMeter()
scorer = SequenceScorer(task.target_dictionary)
scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

score_sum = 0.
count = 0
@@ -107,7 +119,11 @@ def main(parsed_args):
raise NotImplementedError
else:
bpe_cont = args.remove_bpe.rstrip()
bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
bpe_toks = set(
i
for i in range(len(task.source_dictionary))
if task.source_dictionary[i].endswith(bpe_cont)
)
bpe_len = len(bpe_cont)
else:
bpe_toks = None
@@ -117,31 +133,36 @@ def main(parsed_args):

with progress_bar.build_progress_bar(args, itr) as t:
wps_meter = TimeMeter()

for sample in t:
sample = utils.move_to_cuda(sample) if use_cuda else sample
if 'net_input' not in sample:
continue

sample = utils.move_to_cuda(sample) if use_cuda else sample

gen_timer.start()
hypos = scorer.generate(models, sample)
gen_timer.stop(sample['ntokens'])

for hypos_i in hypos:
hypo = hypos_i[0]
pos_scores = hypo['positional_scores']

tokens = hypo['tokens']
tgt_len = tokens.numel()
pos_scores = hypo['positional_scores'].float()

skipped_toks = 0
if bpe_toks is not None:
for i in range(len(hypo['tokens']) - 1):
if hypo['tokens'][i].item() in bpe_toks:
for i in range(tgt_len - 1):
if tokens[i].item() in bpe_toks:
skipped_toks += 1
pos_scores[i + 1] += pos_scores[i]
pos_scores[i] = 0

inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
if inf_scores.any():
print('| Skipping tokens with inf scores:',
task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
task.target_dictionary.string(tokens[inf_scores.nonzero()]))
pos_scores = pos_scores[(~inf_scores).nonzero()]
score_sum += pos_scores.sum().cpu()
count += pos_scores.numel() - skipped_toks
Expand All @@ -150,9 +171,9 @@ def main(parsed_args):
w = ''
word_prob = []
is_bpe = False
for i in range(len(hypo['tokens'])):
w_ind = hypo['tokens'][i].item()
w += task.dictionary[w_ind]
for i in range(len(tokens)):
w_ind = tokens[i].item()
w += task.source_dictionary[w_ind]
if bpe_toks is not None and w_ind in bpe_toks:
w = w[:-bpe_len]
is_bpe = True
Expand All @@ -161,7 +182,7 @@ def main(parsed_args):

next_prob = None
ind = i + 1
while ind < len(hypo['tokens']):
while ind < len(tokens):
if pos_scores[ind].item() != 0:
next_prob = pos_scores[ind]
break
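An aside on the refactored scoring loop above: with `--remove-bpe`, eval_lm.py folds the score of each BPE continuation sub-token into the following sub-token so that word-level perplexity can be reported. A standalone sketch of that step (the function name and the `clone()` are ours; the merging logic mirrors the loop in the diff):

```python
def merge_bpe_scores(tokens, pos_scores, bpe_toks):
    """Fold positional scores of BPE continuation sub-tokens into the next
    sub-token, as in the eval_lm.py loop above. `tokens` is a 1D LongTensor,
    `pos_scores` a 1D FloatTensor of per-token log-probabilities, and
    `bpe_toks` the set of dictionary indices ending with the BPE
    continuation marker."""
    pos_scores = pos_scores.clone()  # keep the caller's tensor intact
    skipped_toks = 0
    for i in range(len(tokens) - 1):
        if tokens[i].item() in bpe_toks:
            skipped_toks += 1
            pos_scores[i + 1] += pos_scores[i]  # carry the score forward
            pos_scores[i] = 0                   # zero out the continuation token
    return pos_scores, skipped_toks
```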
1 change: 0 additions & 1 deletion examples/.gitignore
@@ -1,3 +1,2 @@
*/*
!*/*.sh
!*/*.md
26 changes: 0 additions & 26 deletions examples/conv_lm/README.md

This file was deleted.

40 changes: 32 additions & 8 deletions examples/language_model/README.md
@@ -2,10 +2,10 @@

## Pre-trained models

Description | Dataset | Model | Test set(s)
---|---|---|---
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/gbw_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/gbw_test_lm.tar.bz2)
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wiki103_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wiki103_test_lm.tar.bz2)
Description | Parameters | Dataset | Model and Test set(s)
---|---:|---|---
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.bz2)
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.bz2)

## Example usage

@@ -16,6 +16,8 @@ These scripts provide an example of pre-processing data for the Language Modeling
Provides an example of pre-processing for [WikiText-103 language modeling task](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/):

Example usage:

Prepare data:
```
$ cd examples/language_model/
$ bash prepare-wikitext-103.sh
@@ -27,17 +29,39 @@ $ TEXT=examples/language_model/wikitext-103
$ fairseq-preprocess --only-source \
--trainpref $TEXT/wiki.train.tokens --validpref $TEXT/wiki.valid.tokens --testpref $TEXT/wiki.test.tokens \
--destdir data-bin/wikitext-103
```

Train a transformer language model with adaptive inputs ([Baevski and Auli (2018): Adaptive Input Representations for Neural Language Modeling](transformer_lm/README.md)):
```
# If it runs out of memory, try to reduce max-tokens and tokens-per-sample
$ mkdir -p checkpoints/transformer_wikitext-103
$ fairseq-train --task language_modeling data-bin/wikitext-103 \
--save-dir checkpoints/transformer_wikitext-103 --arch transformer_lm_wiki103 \
--max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
--criterion adaptive_loss --max-tokens 3072 --update-freq 4 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d
# Train the model:
# If it runs out of memory, try to reduce max-tokens and max-target-positions
$ mkdir -p checkpoints/wikitext-103
# Evaluate:
$ fairseq-eval-lm data-bin/wikitext-103 --path 'checkpoints/transformer_wiki103/checkpoint_best.pt' \
--sample-break-mode complete --max-tokens 3072 --context-window 2560 --softmax-batch 1024
```


Train a convolutional language model ([Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](conv_lm/README.md)):
```
# If it runs out of memory, try to reduce max-tokens and tokens-per-sample
$ mkdir -p checkpoints/fconv_wikitext-103
$ fairseq-train --task language_modeling data-bin/wikitext-103 \
--save-dir checkpoints/fconv_wikitext-103 \
--max-epoch 35 --arch fconv_lm_dauphin_wikitext103 --optimizer nag \
--lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \
--clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion adaptive_loss \
--adaptive-softmax-cutoff 10000,20000,200000 --max-tokens 1024 --tokens-per-sample 1024 \
--ddp-backend=no_c10d
# Evaluate:
$ fairseq-eval-lm data-bin/wikitext-103 --path 'checkpoints/wiki103/checkpoint_best.pt'
$ fairseq-eval-lm data-bin/wikitext-103 --path 'checkpoints/fconv_wiki103/checkpoint_best.pt'
```
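
The `--context-window` flag in the new evaluation command above maps to the `LMContextWindowDataset` wrapper added in this release (see the eval_lm.py diff). A hedged sketch of the wrapping pattern; the constructor arguments are taken from that diff, while the helper itself and the described behaviour are our paraphrase:

```python
from fairseq.data import LMContextWindowDataset

def add_eval_context(dataset, tokens_per_sample, context_window, pad_idx):
    """Paraphrased from eval_lm.py: each evaluation block keeps
    `tokens_per_sample` scored positions and is prefixed with up to
    `context_window` tokens of preceding text, which condition the model
    but are not counted towards perplexity. This helper is illustrative,
    not part of fairseq."""
    if context_window <= 0:
        return dataset
    return LMContextWindowDataset(
        dataset=dataset,
        tokens_per_sample=tokens_per_sample,
        context_window=context_window,
        pad_idx=pad_idx,
    )
```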
19 changes: 19 additions & 0 deletions examples/language_model/conv_lm/README.md
@@ -0,0 +1,19 @@
# Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)

## Example usage

See the [language modeling README](../README.md) for instructions on reproducing results for WikiText-103
using the `fconv_lm_dauphin_wikitext103` model architecture.

## Citation

```bibtex
@inproceedings{dauphin2017language,
title={Language Modeling with Gated Convolutional Networks},
author={Dauphin, Yann N and Fan, Angela and Auli, Michael and Grangier, David},
booktitle={Proceedings of the 34th International Conference on Machine Learning-Volume 70},
pages={933--941},
year={2017},
organization={JMLR}
}
```
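
For context on the `fconv_lm_dauphin_wikitext103` architecture referenced above: the model stacks gated convolutional blocks. A conceptual sketch of the gated linear unit, written from the paper's description rather than fairseq's fconv code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedConvBlock(nn.Module):
    """Conceptual GLU block from Dauphin et al. (2017): a causal 1D
    convolution produces 2*dim channels and half of them gate the other
    half through a sigmoid. Illustrative sketch, not fairseq's fconv_lm."""

    def __init__(self, dim, kernel_size=3):
        super().__init__()
        self.pad = kernel_size - 1  # left-pad so no future tokens are seen
        self.conv = nn.Conv1d(dim, 2 * dim, kernel_size)

    def forward(self, x):  # x: (batch, dim, time)
        a, b = self.conv(F.pad(x, (self.pad, 0))).chunk(2, dim=1)
        return a * torch.sigmoid(b)  # gated linear unit
```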
14 changes: 7 additions & 7 deletions examples/language_model/prepare-wikitext-103.sh
100755 → 100644
@@ -21,13 +21,13 @@ for ((i=0;i<${#URLS[@]};++i)); do
echo "$url not successfully downloaded."
exit -1
fi
fi
if [ ${file: -4} == ".tgz" ]; then
tar zxvf $file
elif [ ${file: -4} == ".tar" ]; then
tar xvf $file
elif [ ${file: -4} == ".zip" ]; then
unzip $file
if [ ${file: -4} == ".tgz" ]; then
tar zxvf $file
elif [ ${file: -4} == ".tar" ]; then
tar xvf $file
elif [ ${file: -4} == ".zip" ]; then
unzip $file
fi
fi
done
cd ..
26 changes: 26 additions & 0 deletions examples/language_model/transformer_lm/README.md
@@ -0,0 +1,26 @@
# Adaptive Input Representations for Neural Language Modeling (Baevski and Auli; 2018)

## Pre-trained models

Description | Parameters | Dataset | Model and Test set(s)
---|---:|---|---
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.bz2)
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.bz2)

## Example usage

See the [language modeling README](../README.md) for instructions on reproducing results for WikiText-103
using the `transformer_lm_wiki103` model architecture.

## Citation

```bibtex
@inproceedings{
baevski2018adaptive,
title={Adaptive Input Representations for Neural Language Modeling},
author={Alexei Baevski and Michael Auli},
booktitle={International Conference on Learning Representations},
year={2019},
url={https://openreview.net/forum?id=ByxZX20qFQ},
}
```
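
To make the "adaptive inputs" idea behind `transformer_lm_wiki103` concrete: the frequency-sorted vocabulary is split into bands, rarer bands get smaller embeddings, and every band is projected up to the shared model dimension. The sketch below is a conceptual illustration written from the paper; it is not fairseq's `AdaptiveInput` module, and the cutoffs and reduction factor are made-up values:

```python
import torch
import torch.nn as nn

class AdaptiveInputSketch(nn.Module):
    """Conceptual sketch of adaptive input embeddings (Baevski & Auli, 2018).
    Illustrative only -- not fairseq's AdaptiveInput implementation."""

    def __init__(self, vocab_size, model_dim, cutoffs=(20000, 60000), factor=4):
        super().__init__()
        self.model_dim = model_dim
        self.edges = [0] + list(cutoffs) + [vocab_size]
        self.bands = nn.ModuleList([
            nn.Sequential(
                # rarer bands (larger i) get smaller embeddings ...
                nn.Embedding(self.edges[i + 1] - self.edges[i],
                             max(1, model_dim // factor ** i)),
                # ... projected back up to the shared model dimension
                nn.Linear(max(1, model_dim // factor ** i), model_dim, bias=False),
            )
            for i in range(len(self.edges) - 1)
        ])

    def forward(self, tokens):  # tokens: LongTensor of frequency-sorted indices
        out = torch.zeros(*tokens.shape, self.model_dim, device=tokens.device)
        for i, band in enumerate(self.bands):
            lo, hi = self.edges[i], self.edges[i + 1]
            mask = (tokens >= lo) & (tokens < hi)
            if mask.any():
                out[mask] = band(tokens[mask] - lo)
        return out
```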
2 changes: 1 addition & 1 deletion fairseq/__init__.py
@@ -6,7 +6,7 @@
# can be found in the PATENTS file in the same directory.

__all__ = ['pdb']
__version__ = '0.6.1'
__version__ = '0.6.2'

import fairseq.criterions
import fairseq.models
2 changes: 2 additions & 0 deletions fairseq/data/__init__.py
@@ -11,6 +11,7 @@
from .concat_dataset import ConcatDataset
from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset
from .language_pair_dataset import LanguagePairDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .monolingual_dataset import MonolingualDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .token_block_dataset import TokenBlockDataset
@@ -35,6 +36,7 @@
'IndexedDataset',
'IndexedRawTextDataset',
'LanguagePairDataset',
'LMContextWindowDataset',
'MonolingualDataset',
'RoundRobinZipDatasets',
'ShardedIterator',
