This is the code repository for the ACL 2022 paper "ODE Transformer: An Ordinary Differential Equation-Inspired Model for Sequence Generation", which redesigns the Transformer architecture from an ODE perspective, using high-order ODE solvers to enhance the residual connections.

ODE Transformer: An Ordinary Differential Equation-Inspired Model for Sequence Generation

This code is based on Fairseq v0.6.2. Note that the abstractive summarization task requires a newer version (e.g., Fairseq v0.10.2); we will release that code soon.
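
As a rough sketch of the idea behind the architecture (following the paper's RK2 formulation; the exact parameterization in the code may differ): a standard residual block updates the hidden state as $y_{t+1} = y_t + F(y_t, \theta_t)$, while an RK2-block applies the layer function twice, $F_1 = F(y_t, \theta_t)$ and $F_2 = F(y_t + F_1, \theta_t)$, and combines them as $y_{t+1} = y_t + \gamma_1 F_1 + \gamma_2 F_2$. In the "learnable $\gamma_i$" configuration, the coefficients $\gamma_i$ are learned instead of being fixed to the classical Runge-Kutta weight of $1/2$.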

Requirements and Installation

  • PyTorch version >= 1.2.0
  • Python version >= 3.6
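
The repository does not list explicit install commands; a minimal sketch, assuming the standard editable-install workflow for a Fairseq fork, is:

# Assumed setup: clone the repository and install it as an editable Fairseq package.
# Install a matching PyTorch/CUDA build for your environment first.
git clone https://github.com/libeineu/ODE-Transformer.git
cd ODE-Transformer
pip install --editable .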

Prepare Data

For Machine Translation

1. Download the WMT'14 En-De and WMT'14 En-Fr datasets

2. Preprocess and binarize the dataset (a rough sketch is shown below)
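
The exact preprocessing script is data-specific; as a rough sketch, binarizing a BPE-tokenized WMT'14 En-De corpus with stock Fairseq v0.6.2 could look like the following (file prefixes and the joined dictionary are assumptions):

# Hypothetical binarization step: train.bpe/valid.bpe/test.bpe are assumed
# BPE-tokenized file prefixes; a joined dictionary matches --share-all-embeddings.
python3 preprocess.py \
  --source-lang en --target-lang de \
  --trainpref train.bpe --validpref valid.bpe --testpref test.bpe \
  --destdir data-bin/wmt-en2de \
  --joined-dictionary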

For Abstractive Summarization Task

2. Generate the binary dataset data-bin/cnndm

bash preprocess_cnndaily_bin.sh path/to/cnndm_raw_data

For Grammatical Error Correction Task

2. Get the CoNLL-2014 test set

bash prepare_conll14_test_data.sh

3. Preprocess the dataset

bash preprocess_gec.sh

4. Generate the binary dataset data-bin/BEA

bash preprocess_gec_bin.sh

Train

For WMT'14 En-De Task

Train an RK2-block (learnable $\gamma_i$) model (6-layer big model)

bash train_wmt_en_de.sh

python3 -u train.py data-bin/$data_dir \
  --distributed-world-size 8 -s src -t tgt \
  --arch transformer_ode_t2t_wmt_en_de_big \
  --share-all-embeddings \
  --optimizer adam --clip-norm 0.0 \
  --adam-betas '(0.9, 0.997)' \
  --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 16000 \
  --lr 0.002 --min-lr 1e-09 \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --max-tokens 4096 \
  --update-freq 4 \
  --max-epoch 20 \
  --dropout 0.3 --attention-dropout 0.1 --relu-dropout 0.1 \
  --no-progress-bar \
  --log-interval 100 \
  --ddp-backend no_c10d \
  --seed 1 \
  --save-dir $save_dir \
  --keep-last-epochs 10

For WMT'14 En-Fr Task

Train an RK2-block (learnable $\gamma_i$) model

bash train_wmt_en_fr.sh

python3 -u train.py data-bin/$data_dir \
  --distributed-world-size 8 -s src -t tgt \
  --arch transformer_ode_t2t_wmt_en_de_big \
  --share-all-embeddings \
  --optimizer adam --clip-norm 0.0 \
  --adam-betas '(0.9, 0.997)' \
  --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 16000 \
  --lr 0.002 --min-lr 1e-09 \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --max-tokens 4096 \
  --update-freq 8 \
  --max-epoch 20 \
  --dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \
  --no-progress-bar \
  --log-interval 100 \
  --ddp-backend no_c10d \
  --seed 1 \
  --save-dir $save_dir \
  --keep-last-epochs 10

For Abstractive Summarization Task

Train an RK2-block (learnable $\gamma_i$) model

bash train_cnn_daily.sh

python3 -u train.py data-bin/$data_dir \
  --distributed-world-size 8 -s src -t tgt \
  --arch transformer_ode_t2t_wmt_en_de \
  --share-all-embeddings \
  --optimizer adam --clip-norm 0.0 \
  --adam-betas '(0.9, 0.997)' \
  --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 8000 \
  --lr 0.002 --min-lr 1e-09 \
  --weight-decay 0.0001 \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --max-tokens 4096 \
  --update-freq 4 \
  --max-epoch 20 \
  --dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \
  --truncate-source --skip-invalid-size-inputs-valid-test --max-source-positions 500 \
  --no-progress-bar \
  --log-interval 100 \
  --ddp-backend no_c10d \
  --seed 1 \
  --save-dir $save_dir \
  --keep-last-epochs 10

For Grammatical Error Correction Task

Train an RK2-block (learnable $\gamma_i$) model

bash train_gec.sh

python3 -u train.py data-bin/$data_dir \
  --distributed-world-size 8 -s src -t tgt \
  --arch transformer_ode_t2t_wmt_en_de \
  --share-all-embeddings \
  --optimizer adam --clip-norm 0.0 \
  --adam-betas '(0.9, 0.98)' \
  --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
  --lr 0.0015 --min-lr 1e-09 \
  --weight-decay 0.0001 \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --max-tokens 4096 \
  --update-freq 2 \
  --max-epoch 55 \
  --dropout 0.2 --attention-dropout 0.1 --relu-dropout 0.1 \
  --no-progress-bar \
  --log-interval 100 \
  --ddp-backend no_c10d \
  --seed 1 \
  --save-dir $save_dir \
  --keep-last-epochs 10 \
  --tensorboard-logdir $save_dir

Evaluation

For WMT'14 En-De Task

We measure performance with multi-bleu and sacrebleu; a scoring sketch follows the generation command below.

python3 generate.py \
data-bin/wmt-en2de \
--path $model_dir/$checkpoint \
--gen-subset test \
--batch-size 64 \
--beam 4 \
--lenpen 0.6 \
--output hypo.txt \
--quiet \
--remove-bpe
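
The scoring call itself is not shown above; a rough sketch (reference file paths and the detokenization step are assumptions) is:

# Hypothetical scoring step: tokenized BLEU with the Moses multi-bleu script,
# and detokenized BLEU with sacrebleu; hypo.detok.txt is an assumed detokenized
# copy of hypo.txt, and the newstest2014 reference paths are placeholders.
perl multi-bleu.perl path/to/newstest2014.tok.de < hypo.txt
cat hypo.detok.txt | sacrebleu path/to/newstest2014.de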

For WMT'14 En-Fr Task

We again measure performance with multi-bleu and sacrebleu, following the same scoring procedure as for En-De.

python3 generate.py \
data-bin/wmt-en2fr \
--path $model_dir/$checkpoint \
--gen-subset test \
--batch-size 64 \
--beam 4 \
--lenpen 0.6 \
--output hypo.txt \
--quiet \
--remove-bpe

For Abstractive Summarization Task

We use pyrouge as the scoring script.

python3 generate.py \
data-bin/$data_dir \
--path $model_dir/$checkpoint \
--gen-subset test \
--truncate-source \
--batch-size 32 \
--lenpen 2.0 \
--min-len 55 \
--max-len-b 140 \
--max-source-positions 500 \
--beam 4 \
--no-repeat-ngram-size 3 \
--remove-bpe

python3 get_rouge.py --decodes_filename $model_dir/hypo.sorted.tok --targets_filename cnndm.test.target.tok

For Grammatical Error Correction Task

We use m2scorer as the scoring script.

python3 generate.py \
data-bin/$data_dir \
--path $model_dir/$checkpoint \
--gen-subset test \
--batch-size 64 \
--beam 4 \
--lenpen 2.0 \
--output hypo.txt \
--quiet \
--remove-bpe

path/to/m2scorer path/to/model_output path/to/conll14st-test.m2

Results

Machine Translation

| Model | Layers (enc-dec) | En-De (BLEU) | En-Fr (BLEU) |
|---|---|---|---|
| Residual-block (baseline) | 6-6 | 29.21 | 42.89 |
| RK2-block (learnable $\gamma_i$) | 6-6 | 30.53 | 43.59 |
| Residual-block (baseline) | 12-6 | 29.91 | 43.22 |
| RK2-block (learnable $\gamma_i$) | 12-6 | 30.76 | 44.11 |

Abstractive Summarization Task

| Model | RG-1 | RG-2 | RG-L |
|---|---|---|---|
| Residual-block | 40.47 | 17.73 | 37.29 |
| RK2-block (learnable $\gamma_i$) | 41.58 | 18.57 | 38.41 |
| RK4-block | 41.83 | 18.84 | 38.68 |

Grammatical Error Correction Task

| Model | Prec. | Recall | $F_{0.5}$ |
|---|---|---|---|
| Residual-block | 67.97 | 32.17 | 55.61 |
| RK2-block (learnable $\gamma_i$) | 68.21 | 35.30 | 57.49 |
| RK4-block | 66.20 | 38.13 | 57.71 |
