0.4.0 -> 0.5.0

Changelog:
- 97b58b4: add Transformer model from Vaswani et al. (2017)
- b2374e5: faster Transformer inference with improved caching
- 2d27ae0: simulate large mini-batch training with delayed updates (`--update-freq`)
- 7ee1d28: add FP16 training support (`--fp16`)
- 2a84f46: faster inference by removing completed sentences from the batch
- 663fd80: batched interactive generation
- 4c2ef2d: add language modeling / gated convolutional model from Dauphin et al. (2017)
- b59815b: add Hierarchical Neural Story Generation model from Fan et al. (2018)
- ff68a9e: add FairseqTask to modularize task definitions (e.g., translation, language modeling)
myleott committed Jun 15, 2018
2 parents ec0031d + 5383b5d commit 388c520
Showing 74 changed files with 5,297 additions and 1,692 deletions.
118 changes: 85 additions & 33 deletions README.md
@@ -1,14 +1,25 @@
# Introduction

Fairseq(-py) is a sequence modeling toolkit that allows researchers and developers to train custom models for translation, summarization, language modeling and other text generation tasks. It provides reference implementations of various sequence-to-sequence models, including:
- **Convolutional Neural Networks (CNN)**
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](https://arxiv.org/abs/1612.08083)
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](https://arxiv.org/abs/1705.03122)
  - **_New_** [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://arxiv.org/abs/1711.04956)
- **_New_** [Fan et al. (2018): Hierarchical Neural Story Generation](https://arxiv.org/abs/1805.04833)
- **Long Short-Term Memory (LSTM) networks**
- [Luong et al. (2015): Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/abs/1508.04025)
- [Wiseman and Rush (2016): Sequence-to-Sequence Learning as Beam-Search Optimization](https://arxiv.org/abs/1606.02960)

- **Transformer (self-attention) networks**
- [Vaswani et al. (2017): Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- **_New_** [Ott et al. (2018): Scaling Neural Machine Translation](https://arxiv.org/abs/1806.00187)

Fairseq features:
- multi-GPU (distributed) training on one machine or across multiple machines
- fast beam search generation on both CPU and GPU
- large mini-batch training (even on a single GPU) via delayed updates
- fast half-precision floating point (FP16) training

We also provide [pre-trained models](#pre-trained-models) for several benchmark translation datasets.

![Model](fairseq.gif)

@@ -38,6 +49,7 @@ The following command-line tools are provided:
* `python generate.py`: Translate pre-processed data with a trained model
* `python interactive.py`: Translate raw text with a trained model
* `python score.py`: BLEU scoring of generated translations against reference translations
* `python eval_lm.py`: Language model evaluation

## Evaluating Pre-trained Models
First, download a pre-trained model along with its vocabularies:
@@ -71,16 +83,19 @@ This generation script produces four types of outputs: a line prefixed with *S*

Check [below](#pre-trained-models) for a full list of pre-trained models available.


## Training a New Model

The following tutorial is for machine translation.
For an example of how to use Fairseq for other tasks, such as [language modeling](examples/language_model/README.md), please see the `examples/` directory.

### Data Pre-processing

Fairseq contains example pre-processing scripts for several translation datasets: IWSLT 2014 (German-English), WMT 2014 (English-French) and WMT 2014 (English-German).
To pre-process and binarize the IWSLT dataset:
```
$ cd examples/translation/
$ bash prepare-iwslt14.sh
$ cd ../..
$ TEXT=data/iwslt14.tokenized.de-en
$ python preprocess.py --source-lang de --target-lang en \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
@@ -125,15 +140,25 @@ BPE continuation markers can be removed with the `--remove-bpe` flag.
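For intuition, removing BPE continuation markers simply splices subword units back together by deleting the `@@ ` separators. A minimal sketch of the idea (assuming the common `@@ `-style markers; this is not fairseq's exact post-processing code):

```
def remove_bpe(line, bpe_symbol='@@ '):
    # 'a bea@@ utiful trans@@ lation' -> 'a beautiful translation'
    return (line + ' ').replace(bpe_symbol, '').rstrip()

print(remove_bpe('a bea@@ utiful trans@@ lation'))
```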

# Pre-trained Models

We provide the following pre-trained models and pre-processed, binarized test sets:

### Translation

Description | Dataset | Model | Test set(s)
---|---|---|---
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt14.v2.en-fr.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt14.v2.en-fr.newstest2014.tar.bz2) <br> newstest2012/2013: <br> [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt14.v2.en-fr.ntst1213.tar.bz2)
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](https://nlp.stanford.edu/projects/nmt) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt14.v2.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt14.v2.en-de.newstest2014.tar.bz2)
Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt14.en-fr.joined-dict.transformer.tar.bz2)
Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt16.en-de.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt16.en-de.joined-dict.newstest2014.tar.bz2)

### Language models

Description | Dataset | Model | Test set(s)
---|---|---|---
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/gbw_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/gbw_test_lm.tar.bz2)
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wiki103_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wiki103_test_lm.tar.bz2)

### Usage

Generation with the binarized test sets can be run in batch mode as follows, e.g. for WMT 2014 English-French on a GTX-1080ti:
```
@@ -153,39 +178,66 @@ $ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
```
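For reference, the four numbers after the BLEU score are the modified 1- to 4-gram precisions and BP is the brevity penalty; BLEU4 is their geometric mean scaled by BP. Recomputing it from the rounded values above (a sanity check, so it only matches up to rounding):

```
import math

precisions = [0.675, 0.469, 0.344, 0.255]  # 1- to 4-gram precisions from the output above
bp = 1.000                                 # brevity penalty
bleu4 = bp * math.exp(sum(math.log(p) for p in precisions) / 4)
print(round(100 * bleu4, 2))               # ~40.8, i.e. BLEU4 = 40.83 up to rounding
```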

# Large mini-batch training with delayed updates

The `--update-freq` option can be used to accumulate gradients from multiple mini-batches and delay updating,
creating a larger effective batch size.
Delayed updates can also improve training speed by reducing inter-GPU communication costs and by saving idle time caused by variance in workload across GPUs.
See [Ott et al. (2018)](https://arxiv.org/abs/1806.00187) for more details.

To train on a single GPU with an effective batch size that is equivalent to training on 8 GPUs:
```
CUDA_VISIBLE_DEVICES=0 python train.py --update-freq 8 (...)
```
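Conceptually, delayed updates are plain gradient accumulation: run several forward/backward passes, let gradients sum in `.grad`, and only then take an optimizer step. A minimal PyTorch sketch of the idea on toy data (not fairseq's trainer; the model and data are placeholders):

```
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
update_freq = 8  # mirrors --update-freq 8: accumulate 8 mini-batches per update

# toy data standing in for a real data loader
batches = [(torch.randn(32, 10), torch.randn(32, 2)) for _ in range(16)]

optimizer.zero_grad()
for i, (x, y) in enumerate(batches):
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / update_freq).backward()   # gradients accumulate across mini-batches
    if (i + 1) % update_freq == 0:
        optimizer.step()              # one update with an 8x larger effective batch
        optimizer.zero_grad()
```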

# Training with half precision floating point (FP16)

> Note: FP16 training requires a Volta GPU and CUDA 9.1 or greater.

Recent GPUs enable efficient half precision floating point computation, e.g., using [Nvidia Tensor Cores](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html).

Fairseq supports FP16 training with the `--fp16` flag:
```
python train.py --fp16 (...)
```
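The usual recipe behind FP16 training keeps FP16 weights and activations, scales the loss to avoid gradient underflow, and applies updates to an FP32 master copy of the parameters. A rough, simplified sketch of one such step (assumes a CUDA GPU; the toy model and static loss scale are illustrative, and fairseq's actual implementation differs):

```
import torch

model = torch.nn.Linear(512, 512).cuda().half()  # FP16 weights (toy model)
master = [p.detach().clone().float().requires_grad_(True) for p in model.parameters()]
optimizer = torch.optim.SGD(master, lr=0.1)      # optimizer steps on the FP32 master copy
loss_scale = 128.0                               # static scale; real setups often adapt it

x = torch.randn(8, 512).cuda().half()
loss = model(x).float().pow(2).mean()
(loss * loss_scale).backward()                   # scaled loss -> scaled FP16 gradients

for p32, p16 in zip(master, model.parameters()):
    p32.grad = p16.grad.float() / loss_scale     # unscale into FP32 gradients
optimizer.step()                                 # update FP32 master weights

with torch.no_grad():
    for p32, p16 in zip(master, model.parameters()):
        p16.copy_(p32.half())                    # copy back into the FP16 model
```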

# Distributed training

Distributed training in fairseq is implemented on top of [torch.distributed](http://pytorch.org/docs/master/distributed.html).
Training begins by launching one worker process per GPU.
These workers discover each other via a unique host and port (required) that can be used to establish an initial connection.
Additionally, each worker has a rank, which is a unique number from 0 to n-1, where n is the total number of GPUs.
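
In practice every worker runs the same training script and differs only in its rank; conceptually, each one does something like the following at startup (a minimal `torch.distributed` sketch, not the actual fairseq entry point; the host/port and environment variables are illustrative):

```
import os

import torch
import torch.distributed as dist

rank = int(os.environ['RANK'])              # unique id in [0, world_size)
world_size = int(os.environ['WORLD_SIZE'])  # total number of GPUs/workers

dist.init_process_group(
    backend='nccl',                                 # use 'gloo' for CPU-only setups
    init_method='tcp://master.devserver.com:9218',  # shared host:port (placeholder)
    world_size=world_size,
    rank=rank,
)
torch.cuda.set_device(rank % torch.cuda.device_count())

# After initialization, gradients can be averaged across workers each step, e.g.:
#   dist.all_reduce(grad_tensor)
#   grad_tensor /= world_size
```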

If you run on a cluster managed by [SLURM](https://slurm.schedmd.com/), you can train a large English-French model on the WMT 2014 dataset on 16 nodes with 8 GPUs each (128 GPUs in total) using this command:

```
$ DATA=... # path to the preprocessed dataset, must be visible from all nodes
$ PORT=9218 # any available TCP port that can be used by the trainer to establish initial connection
$ sbatch --job-name fairseq-py --gres gpu:8 --cpus-per-task 10 \
--nodes 16 --ntasks-per-node 8 \
--wrap 'srun --output train.log.node%t --error train.stderr.node%t.%j \
python train.py $DATA \
--distributed-world-size 128 \
--distributed-port $PORT \
--force-anneal 50 --lr-scheduler fixed --max-epoch 55 \
--arch fconv_wmt_en_fr --optimizer nag --lr 0.1,4 --max-tokens 3000 \
--clip-norm 0.1 --dropout 0.1 --criterion label_smoothed_cross_entropy \
--label-smoothing 0.1 --wd 0.0001'
```

Alternatively you can manually start one process per GPU:
```
$ DATA=... # path to the preprocessed dataset, must be visible from all nodes
$ HOST_PORT=master.devserver.com:9218 # one of the hosts used by the job
$ RANK=... # the rank of this process, from 0 to 127 in case of 128 GPUs
$ python train.py $DATA \
--distributed-world-size 128 \
  --distributed-init-method "tcp://$HOST_PORT" \
--distributed-rank $RANK \
--force-anneal 50 --lr-scheduler fixed --max-epoch 55 \
--arch fconv_wmt_en_fr --optimizer nag --lr 0.1,4 --max-tokens 3000 \
--clip-norm 0.1 --dropout 0.1 --criterion label_smoothed_cross_entropy \
--label-smoothing 0.1 --wd 0.0001
```

# Join the fairseq community
2 changes: 1 addition & 1 deletion distributed_train.py
@@ -10,7 +10,7 @@
import socket
import subprocess

from singleprocess_train import main as single_process_main
from train import main as single_process_main
from fairseq import distributed_utils, options


78 changes: 78 additions & 0 deletions eval_lm.py
@@ -0,0 +1,78 @@
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import numpy as np
import torch

from fairseq import data, options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer


def main(args):
    assert args.path is not None, '--path required for evaluation!'

    if args.tokens_per_sample is None:
        args.tokens_per_sample = 1024
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task)

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()

    itr = data.EpochBatchIterator(
        dataset=task.dataset(args.gen_subset),
        max_sentences=args.max_sentences or 4,
        max_positions=model.max_positions(),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']
                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum()
                count += pos_scores.numel()
            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))


if __name__ == '__main__':
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    main(args)
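
For reference, the two numbers printed at the end are directly related: the reported perplexity is the exponential of the average per-token negative log-likelihood accumulated in the loop above,

```
\mathrm{ppl} = \exp\Big(-\frac{1}{N}\sum_{i=1}^{N} \log p(w_i \mid w_{<i})\Big)
```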
34 changes: 34 additions & 0 deletions examples/language_model/README.md
@@ -0,0 +1,34 @@
Sample data processing scripts for the FAIR Sequence-to-Sequence Toolkit

These scripts provide an example of pre-processing data for the Language Modeling task.

# prepare-wikitext-103.sh

Provides an example of pre-processing for the [WikiText-103 language modeling task](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset):

Example usage:
```
$ cd examples/language_model/
$ bash prepare-wikitext-103.sh
$ cd ../..
# Binarize the dataset:
$ TEXT=examples/language_model/wikitext-103
$ python preprocess.py --only-source \
--trainpref $TEXT/wiki.train.tokens --validpref $TEXT/wiki.valid.tokens --testpref $TEXT/wiki.test.tokens \
--destdir data-bin/wikitext-103
# Train the model:
# If it runs out of memory, try to reduce max-tokens and max-target-positions
$ mkdir -p checkpoints/wikitext-103
$ python train.py --task language_modeling data-bin/wikitext-103 \
--save-dir checkpoints/wikitext-103 \
--max-epoch 35 --arch fconv_lm_dauphin_wikitext103 --optimizer nag \
--lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \
--clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion adaptive_loss \
--adaptive-softmax-cutoff 10000,20000,200000 --max-tokens 1024 --tokens-per-sample 1024
# Evaluate:
$ python eval_lm.py data-bin/wikitext-103 --path 'checkpoints/wikitext-103/checkpoint_best.pt'
```
33 changes: 33 additions & 0 deletions examples/language_model/prepare-wikitext-103.sh
@@ -0,0 +1,33 @@
#!/bin/bash
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh

URLS=(
"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip"
)
FILES=(
"wikitext-103-v1.zip"
)

for ((i=0;i<${#URLS[@]};++i)); do
    file=${FILES[i]}
    if [ -f $file ]; then
        echo "$file already exists, skipping download"
    else
        url=${URLS[i]}
        wget "$url"
        if [ -f $file ]; then
            echo "$url successfully downloaded."
        else
            echo "$url not successfully downloaded."
            exit -1
        fi
        if [ ${file: -4} == ".tgz" ]; then
            tar zxvf $file
        elif [ ${file: -4} == ".tar" ]; then
            tar xvf $file
        elif [ ${file: -4} == ".zip" ]; then
            unzip $file
        fi
    fi
done
cd ..
30 changes: 30 additions & 0 deletions examples/stories/README.md
@@ -0,0 +1,30 @@
FAIR Sequence-to-Sequence Toolkit for Story Generation

The following commands provide an example of pre-processing data, training a model, and generating text for story generation with the WritingPrompts dataset.

The dataset can be downloaded like this:

```
curl https://s3.amazonaws.com/fairseq-py/data/writingPrompts.tar.gz | tar xvzf -
```

and contains train, test, and valid splits. The dataset is described in https://arxiv.org/abs/1805.04833, where only the first 1,000 words of each story are modeled.


Example usage:
```
# Binarize the dataset:
$ TEXT=examples/stories/writingPrompts
$ python preprocess.py --source-lang wp_source --target-lang wp_target \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/writingPrompts --thresholdtgt 10 --thresholdsrc 10
# Train the model:
$ python train.py data-bin/writingPrompts -a fconv_self_att_wp \
  --lr 0.25 --clip-norm 0.1 --max-tokens 1500 --lr-scheduler reduce_lr_on_plateau \
  --decoder-attention True --encoder-attention False \
  --criterion label_smoothed_cross_entropy --weight-decay .0000001 --label-smoothing 0 \
  --source-lang wp_source --target-lang wp_target \
  --gated-attention True --self-attention True --project-input True --pretrained False
# Train a fusion model:
# add the arguments: --pretrained True --pretrained-checkpoint path/to/checkpoint
# Generate:
$ python generate.py data-bin/writingPrompts --path /path/to/trained/model/checkpoint_best.pt --batch-size 32 --beam 1 --sampling --sampling-topk 10 --sampling-temperature 0.8 --nbest 1
```
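
For intuition, `--sampling --sampling-topk 10 --sampling-temperature 0.8` draws each next token from the 10 most likely candidates after temperature-scaling the model's scores, instead of taking the beam-search argmax. A rough sketch of that sampling step (illustrative only, not fairseq's sequence generator):

```
import torch

def sample_top_k(logits, k=10, temperature=0.8):
    """Sample one token id from the top-k, temperature-scaled distribution."""
    topk_logits, topk_ids = torch.topk(logits / temperature, k)
    probs = torch.softmax(topk_logits, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return topk_ids[choice].item()

next_token = sample_top_k(torch.randn(1000))  # 1000 = toy vocabulary size
print(next_token)
```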
