[Non-autoregressive Transformer] Add GLAT, CTC, DS #4431
Conversation
Hi, thanks for this great integration of NAT codes. Could you please provide an example to show how to train a model with these methods?
Sure! As written above, the main flag for GLAT is `--use-glat`; a full training command looks like this:

```bash
fairseq-train data-bin \
    --log-format simple --log-interval 100 --max-tokens 8192 \
    --activation-fn gelu --adam-betas '(0.9, 0.98)' --apply-bert-init \
    --arch nonautoregressive_transformer --clip-norm 5.0 --criterion nat_loss \
    --decoder-learned-pos --dropout 0.1 --encoder-learned-pos \
    --eval-bleu --eval-bleu-args '{"iter_decode_max_iter": 0, "iter_decode_eos_penalty": 0, "iter_decode_with_beam": 1}' \
    --eval-bleu-detok moses --eval-bleu-print-samples --eval-bleu-remove-bpe \
    --fp16 --label-smoothing 0 --length-loss-factor 0.1 \
    --lr 0.001 --lr-scheduler inverse_sqrt --max-update 200000 --min-lr 1e-09 \
    --noise full_mask --optimizer adam --pred-length-offset \
    --share-all-embeddings --task translation_lev --use-glat \
    --warmup-init-lr 1e-07 --warmup-updates 10000 --weight-decay 0.01
```

Similarly, for vanilla CTC:

```bash
fairseq-train data-bin \
    --log-format simple --log-interval 100 --max-tokens 8192 \
    --adam-betas '(0.9, 0.98)' --arch nonautoregressive_transformer \
    --clip-norm 5.0 --criterion nat_loss --ctc-src-upsample-scale 2 \
    --decoder-learned-pos --dropout 0.1 --encoder-learned-pos \
    --eval-bleu --eval-bleu-args '{"iter_decode_max_iter": 0, "iter_decode_eos_penalty": 0, "iter_decode_with_beam": 1}' \
    --eval-bleu-detok moses --eval-bleu-print-samples --eval-bleu-remove-bpe \
    --fp16 --label-smoothing 0 \
    --lr 0.001 --lr-scheduler inverse_sqrt --max-update 200000 --min-lr 1e-09 \
    --noise full_mask --optimizer adam \
    --share-all-embeddings --task translation_lev --use-ctc-decoder \
    --warmup-init-lr 1e-07 --warmup-updates 10000 --weight-decay 0.01
```

As you can see, the main flags to enable the methods are simply passed, and they can also be combined for CTC + GLAT (given that the C++ code is added as stated above):

```bash
fairseq-train data-bin \
    --log-format simple --log-interval 100 --max-tokens 8192 \
    --adam-betas '(0.9, 0.98)' --arch nonautoregressive_transformer \
    --clip-norm 5.0 --criterion nat_loss --ctc-src-upsample-scale 2 \
    --decoder-learned-pos --dropout 0.1 --encoder-learned-pos \
    --eval-bleu --eval-bleu-args '{"iter_decode_max_iter": 0, "iter_decode_eos_penalty": 0, "iter_decode_with_beam": 1}' \
    --eval-bleu-detok moses --eval-bleu-print-samples --eval-bleu-remove-bpe \
    --fp16 --label-smoothing 0 \
    --lr 0.001 --lr-scheduler inverse_sqrt --max-update 200000 --min-lr 1e-09 \
    --noise full_mask --optimizer adam \
    --share-all-embeddings --task translation_lev --use-ctc-decoder --use-glat \
    --warmup-init-lr 1e-07 --warmup-updates 10000 --weight-decay 0.01
```

For some of the hyperparameter choices, please see the paper (the above is for WMT'14 EN-DE)! Of course, let me know in case you run into any issues; I needed to strip a few internal flags, so hopefully I didn't miss anything!
Dear Robin, thanks for your help! I have successfully finished the training process. Could you kindly provide the test script?
Sure! Given that you have averaged your checkpoints and saved the result to a single file, you can run `fairseq-generate` on the test set, e.g. as sketched below.
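Checkpoint averaging can be done with fairseq's `scripts/average_checkpoints.py`; a minimal sketch, assuming update-based checkpoints in a `checkpoints/` directory (the paths and the output file name are placeholders):

```bash
python scripts/average_checkpoints.py \
    --inputs checkpoints \
    --num-update-checkpoints 5 \
    --output checkpoint.avg.pt
```

For generation itself, here is a minimal sketch following fairseq's standard NAT evaluation flags, mirroring the `--eval-bleu-args` used during training (the checkpoint path and output file are placeholders, not the exact command from this thread):

```bash
fairseq-generate data-bin \
    --gen-subset test \
    --task translation_lev \
    --path checkpoint.avg.pt \
    --iter-decode-max-iter 0 \
    --iter-decode-eos-penalty 0 \
    --beam 1 --remove-bpe \
    --print-step \
    --batch-size 128 > generate.out
```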
This will generate hypotheses for the test set that you will then need to score, for example as sketched below.
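One common way to do that, assuming the generation output was redirected to `generate.out` as above and scoring with sacrebleu (file names are placeholders):

```bash
# Extract hypotheses and references from the fairseq-generate output.
grep ^H generate.out | LC_ALL=C sort -V | cut -f3- > hyp.txt
grep ^T generate.out | LC_ALL=C sort -V | cut -f2- > ref.txt
# Score the hypotheses against the references.
sacrebleu ref.txt -i hyp.txt -m bleu
```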
Note that …
Sincere thanks!
No problem at all, please let me know in case you run into any issues! |
How to implement `--nbest`?
This PR adds the code for the following methods to the Non-Autoregressive Transformer:

- GLAT (glancing at target tokens during training)
- CTC (a CTC decoder with source upsampling)
- DS (deep supervision)
Important to note is that this code still has one remaining dependency on the C++ code from torch_imputer that is currently not integrated in this PR. We leave it up to the fairseq team to decide how they want to include the `best_alignment` method for `fairseq/models/nat/nonautoregressive_transformer.py` (l. 171). Both building a pip package and importing it, or directly copying over the code from the respective repository, would work. It is used for getting Viterbi-aligned target tokens when using CTC + GLAT jointly (see the usage sketch below the flag list).

The main flags for training with any of the above methods are:
- `--use-glat`
- `--use-ctc-decoder --ctc-src-upsample-scale 2`
- `--use-deep-supervision`
These are also supported jointly. Once this PR has been integrated, we'll work on getting a follow-up PR up for the required inference speed improvements, i.e. Shortlists and Average Attention (see the paper below). As these are not specific to non-autoregressive models, we decided to keep them separate.
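For reference, here is a hedged sketch of how `best_alignment` can be used for the CTC + GLAT case. The import path, tensor shapes, and helper name are assumptions for illustration, not the exact code in this PR:

```python
from torch_imputer import best_alignment  # assumed import path once torch_imputer is installed/copied

def viterbi_aligned_targets(log_probs, targets, input_lengths, target_lengths, blank_index):
    """Sketch (hypothetical helper): per-sample Viterbi CTC alignments for the glancing step.

    log_probs:      (T, B, V) CTC log-probabilities over the upsampled source length T
    targets:        (B, S) padded target token ids
    input_lengths:  (B,) valid source lengths per sample
    target_lengths: (B,) valid target lengths per sample
    """
    # best_alignment returns the most probable CTC state sequence per sample,
    # from which Viterbi-aligned target tokens can be read off for GLAT.
    return best_alignment(
        log_probs.float(), targets, input_lengths, target_lengths,
        blank=blank_index, zero_infinity=True,
    )
```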
If anyone using this code finds it helpful, please consider citing our associated paper: