[Non-autoregressive Transformer] Add GLAT, CTC, DS #4431

Open · wants to merge 1 commit into main

Conversation

@SirRob1997 commented May 21, 2022

This PR adds the code for the following methods to the Non-Autoregressive Transformer: glancing training (GLAT), CTC-based decoding (CTC), and deep supervision (DS).

Note that this code still has one remaining dependency on the C++ code from torch_imputer, which is not integrated in this PR. We leave it up to the fairseq team to decide how they want to include the best_alignment method used in fairseq/models/nat/nonautoregressive_transformer.py (l. 171). Either building a pip package and importing it, or directly copying the code from the respective repository, would work. The method is used to obtain Viterbi-aligned target tokens when CTC and GLAT are used jointly.
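For illustration only, here is a minimal sketch of how that call might be wired in, assuming the pip-package route and that torch_imputer exposes best_alignment with a signature mirroring F.ctc_loss; all tensor names below are placeholders, not the PR's actual variables:

import torch
import torch.nn.functional as F
from torch_imputer import best_alignment  # assumes a pip install; alternatively copy the code over

# Dummy shapes for illustration: T = upsampled decoder length, B = batch, V = vocab, S = target length
T, B, V, S, blank_idx = 12, 2, 20, 5, 0
log_probs = F.log_softmax(torch.randn(T, B, V), dim=-1)
targets = torch.randint(1, V, (B, S))
input_lengths = torch.full((B,), T, dtype=torch.long)
target_lengths = torch.full((B,), S, dtype=torch.long)

# Viterbi alignment through the CTC lattice (assumed signature, mirroring F.ctc_loss)
states = best_alignment(log_probs, targets, input_lengths, target_lengths,
                        blank=blank_idx, zero_infinity=True)

# Even states are blanks, odd states map back to target positions, which yields the
# Viterbi-aligned target tokens that the joint CTC + GLAT setup needs.
aligned = [
    [blank_idx if s % 2 == 0 else targets[b][s // 2].item() for s in sent]
    for b, sent in enumerate(states)
]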

Main flags for training using any of the above methods are:

  • GLAT: --use-glat
  • CTC: --use-ctc-decoder --ctc-src-upsample-scale 2
  • DS: --use-deep-supervision

These can also be combined. Once this PR has been integrated, we'll work on a follow-up PR for the required inference speed improvements, i.e. shortlists and average attention (see the paper below). As these are not specific to non-autoregressive models, we decided to keep them separate.

If you use this code and find it helpful, please consider citing our associated paper:

Abstract: Non-autoregressive approaches aim to improve the inference speed of translation models by only requiring a single forward pass to generate the output sequence instead of iteratively producing each predicted token. Consequently, their translation quality still tends to be inferior to their autoregressive counterparts due to several issues involving output token interdependence. In this work, we take a step back and revisit several techniques that have been proposed for improving non-autoregressive translation models and compare their combined translation quality and speed implications under third-party testing environments. We provide novel insights for establishing strong baselines using length prediction or CTC-based architecture variants and contribute standardized BLEU, chrF++, and TER scores using sacreBLEU on four translation tasks, which crucially have been missing as inconsistencies in the use of tokenized BLEU lead to deviations of up to 1.7 BLEU points. Our open-sourced code is integrated into fairseq for reproducibility.

@misc{schmidt2022nat,
  url = {https://arxiv.org/abs/2205.10577}, 
  author = {Schmidt, Robin M. and Pires, Telmo and Peitz, Stephan and Lööf, Jonas},
  title = {Non-Autoregressive Neural Machine Translation: A Call for Clarity},
  publisher = {arXiv},
  year = {2022}
}

@xcfcode commented Jun 4, 2022

Hi, thanks for this great integration of the NAT code. Could you please provide an example showing how to train a GLAT model?

@SirRob1997 (Author) commented Jun 4, 2022

Sure, as written above, the main flag for that is --use-glat, which enables glancing sampling. Given that you have run fairseq-preprocess and your data is correctly set up in a folder data-bin, you should be able to start a GLAT training run with:

fairseq-train data-bin \
    --log-format simple --log-interval 100 --max-tokens 8192 \
    --activation-fn gelu --adam-betas '(0.9, 0.98)' --apply-bert-init \
    --arch nonautoregressive_transformer --clip-norm 5.0 --criterion nat_loss \
    --decoder-learned-pos --dropout 0.1 --encoder-learned-pos \
    --eval-bleu --eval-bleu-args '{"iter_decode_max_iter": 0, "iter_decode_eos_penalty": 0, "iter_decode_with_beam": 1}' \
    --eval-bleu-detok moses --eval-bleu-print-samples --eval-bleu-remove-bpe \
    --fp16 --label-smoothing 0 --length-loss-factor 0.1 \
    --lr 0.001 --lr-scheduler inverse_sqrt --max-update 200000 --min-lr 1e-09 \
    --noise full_mask --optimizer adam --pred-length-offset --share-all-embeddings \
    --task translation_lev --use-glat \
    --warmup-init-lr 1e-07 --warmup-updates 10000 --weight-decay 0.01
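As a side note (not needed for training), the glancing sampling that --use-glat switches on can be sketched roughly like this in PyTorch; the helper name and signature are illustrative only and not the PR's actual code:

import torch

def glancing_sample(pred_tokens, tgt_tokens, tgt_mask, ratio=0.5):
    # pred_tokens: first-pass argmax predictions, shape (batch, tgt_len)
    # tgt_tokens:  gold target tokens,            shape (batch, tgt_len)
    # tgt_mask:    True at non-padding positions, shape (batch, tgt_len)
    # Returns a boolean mask of positions whose decoder inputs get replaced
    # ("glanced") with the gold target embedding before the second pass.
    wrong = (pred_tokens.ne(tgt_tokens) & tgt_mask).sum(dim=-1, keepdim=True)
    n_reveal = (wrong.float() * ratio).long()
    # Randomly score every non-pad position and reveal the n_reveal highest-scored ones
    scores = torch.rand_like(tgt_tokens, dtype=torch.float).masked_fill(~tgt_mask, -1.0)
    ranks = scores.argsort(dim=-1, descending=True).argsort(dim=-1)
    return ranks < n_reveal

The positions flagged by the returned mask would have their decoder inputs replaced with the gold target embeddings before the second decoding pass, so the model only has to "glance" at part of the reference.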

Similarly, for vanilla CTC:

fairseq-train data-bin \
    --log-format simple --log-interval 100 --max-tokens 8192 \
    --adam-betas '(0.9, 0.98)' --arch nonautoregressive_transformer \
    --clip-norm 5.0 --criterion nat_loss --ctc-src-upsample-scale 2 \
    --decoder-learned-pos --dropout 0.1 --encoder-learned-pos \
    --eval-bleu --eval-bleu-args '{"iter_decode_max_iter": 0, "iter_decode_eos_penalty": 0, "iter_decode_with_beam": 1}' \
    --eval-bleu-detok moses --eval-bleu-print-samples --eval-bleu-remove-bpe \
    --fp16 --label-smoothing 0 --lr 0.001 --lr-scheduler inverse_sqrt \
    --max-update 200000 --min-lr 1e-09 --noise full_mask --optimizer adam \
    --share-all-embeddings --task translation_lev --use-ctc-decoder \
    --warmup-init-lr 1e-07 --warmup-updates 10000 --weight-decay 0.01
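For context, the objective behind --use-ctc-decoder is essentially the standard CTC loss computed over an upsampled decoder length (roughly src_len times --ctc-src-upsample-scale). A minimal sketch with made-up tensor names, not the PR's internals:

import torch.nn.functional as F

def ctc_nat_loss(log_probs, targets, input_lengths, target_lengths, blank_idx=0):
    # log_probs: (upsampled_len, batch, vocab), already log-softmaxed decoder outputs
    # targets:   (batch, tgt_len) gold token ids; lengths are per-sentence lengths
    return F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                      blank=blank_idx, reduction="mean", zero_infinity=True)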

As you can see, the main flags enabling each method are simply passed on the command line and can also be combined for CTC + GLAT (given that the C++ code is added as stated above):

fairseq-train data-bin \
    --log-format simple --log-interval 100 --max-tokens 8192 \
    --adam-betas '(0.9, 0.98)' --arch nonautoregressive_transformer \
    --clip-norm 5.0 --criterion nat_loss --ctc-src-upsample-scale 2 \
    --decoder-learned-pos --dropout 0.1 --encoder-learned-pos \
    --eval-bleu --eval-bleu-args '{"iter_decode_max_iter": 0, "iter_decode_eos_penalty": 0, "iter_decode_with_beam": 1}' \
    --eval-bleu-detok moses --eval-bleu-print-samples --eval-bleu-remove-bpe \
    --fp16 --label-smoothing 0 --lr 0.001 --lr-scheduler inverse_sqrt \
    --max-update 200000 --min-lr 1e-09 --noise full_mask --optimizer adam \
    --share-all-embeddings --task translation_lev --use-ctc-decoder --use-glat \
    --warmup-init-lr 1e-07 --warmup-updates 10000 --weight-decay 0.01

For some of the hyperparameter choices, please see the paper (the above is for WMT'14 EN-DE)!

Of course, --max-tokens and --lr are somewhat specific to our setup (number of GPUs, batch size) and might need some tuning to make the most of your available GPU resources. My guess is that you will need to reduce both, since we train on multiple A100s and our batch size is therefore quite large.

Let me know in case you run into any issues; I needed to strip a few internal flags, so hopefully I didn't miss anything!

@xcfcode commented Jun 29, 2022

Dear Robin, thanks for your help. I have successfully finished the training process. Could you kindly provide the test script?

@SirRob1997 (Author) commented Jun 29, 2022

Sure. Given that you have averaged your checkpoints and saved the result in a file, e.g. ckpts_last_5.pt, running inference on the test set works with the following command:

fairseq-generate data-bin --path ckpts_last_5.pt \
    --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --iter-decode-with-beam 1 \
    --batch-size 128 --beam 1 --remove-bpe --task translation_lev --gen-subset test
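(For reference, the averaged checkpoint itself can be produced with fairseq's bundled averaging script, along these lines; the checkpoint directory path is a placeholder:)

python scripts/average_checkpoints.py \
    --inputs /path/to/checkpoints \
    --num-update-checkpoints 5 \
    --output ckpts_last_5.pt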

This will generate hypotheses for the test set, which you then need to score with sacreBLEU. Something like this should work to compute BLEU, chrF++, and case-sensitive TER with sacrebleu==2.0.0:

sacrebleu -i test.hyp -t wmt14/full -l en-de -m bleu chrf ter --chrf-word-order 2 --ter-case-sensitive

Note that EN-DE and DE-EN use wmt14/full while EN-RO and RO-EN use wmt16.
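The test.hyp file above is assumed to contain only the hypothesis strings. One common way to extract them, assuming the fairseq-generate output was redirected to a file gen.out, is the following; depending on your preprocessing you may still need to detokenize the hypotheses before scoring:

grep ^H gen.out | LC_ALL=C sort -V | cut -f3- > test.hyp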

@xcfcode commented Jun 30, 2022

Sincere thanks!

@SirRob1997 (Author)

No problem at all, please let me know in case you run into any issues!

@PPPNut commented Apr 5, 2023

How to implement --nbest?
