# Correctness of Conformer implementation (ACL 2024)

This README contains the instructions to replicate the training and evaluation of the models in the paper *When Good and Reproducible Results are a Giant with Feet of Clay: The Importance of Software Quality in NLP*, published at ACL 2024. In addition, we release the pre-trained models used in the paper.

## Setup

Clone this repository and install it as explained in the original Fairseq(-py). The experiments use MuST-C, so make sure to download the corpus and follow the preprocessing steps of Speechformer to preprocess the MuST-C data.
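For reference, a minimal setup sketch (the repository URL and folder name below are placeholders):

```sh
# Placeholders: substitute the actual URL of this repository
git clone <this-repository-url> FBK-fairseq
cd FBK-fairseq
pip install -e .
```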

## Training

The bug-free version of the Conformer-based model can be trained by passing the target language (`LANG`), the folder containing the preprocessed MuST-C data (`MUSTC_ROOT`), the task (`TASK`, either `asr` or `st`), and the directory in which the checkpoints and training log will be saved (`SAVE_DIR`):

```sh
LANG=$1        # target language (e.g. de)
MUSTC_ROOT=$2  # folder containing the preprocessed MuST-C data
TASK=$3        # asr or st
SAVE_DIR=$4    # directory for checkpoints and logs

mkdir -p $SAVE_DIR

# Train the bug-free Conformer (with CTC compression)
python ${FBK_fairseq}/train.py ${MUSTC_ROOT} \
        --train-subset train_${TASK}_src --valid-subset dev_${TASK}_src \
        --user-dir examples/speech_to_text --seed 1 \
        --num-workers 1 --max-update 100000 --patience 10 --keep-last-epochs 12 \
        --max-tokens 40000 --update-freq 4 \
        --task speech_to_text_ctc --config-yaml config.yaml \
        --criterion ctc_multi_loss \
        --underlying-criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
        --arch conformer \
        --ctc-encoder-layer 8 --ctc-weight 0.5 --ctc-compress-strategy avg \
        --optimizer adam --adam-betas '(0.9, 0.98)' \
        --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 25000 \
        --clip-norm 10.0 \
        --skip-invalid-size-inputs-valid-test \
        --save-dir ${SAVE_DIR} \
        --log-format simple > $SAVE_DIR/train.log 2> $SAVE_DIR/train.err

# Average 5 consecutive epoch checkpoints
# (the best one, the 2 preceding, and the 2 following)
python ${FBK_fairseq}/scripts/average_checkpoints.py \
        --input $SAVE_DIR --num-epoch-checkpoints 5 \
        --checkpoint-upper-bound $(ls $SAVE_DIR | head -n 5 | tail -n 1 | grep -o "[0-9]*") \
        --output $SAVE_DIR/avg5.pt

# Remove the epoch checkpoints once the average has been created
if [ -f $SAVE_DIR/avg5.pt ]; then
  rm $SAVE_DIR/checkpoint??.pt
fi
```

The script trains the model and averages 5 checkpoints (the best one, the 2 preceding, and the 2 following ones).
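For instance, assuming the snippet above has been saved as `train_conformer.sh` (the script name and all paths below are placeholders), an en-de ST model can be trained with:

```sh
# FBK_fairseq must point to the root of this repository
export FBK_fairseq=/path/to/FBK-fairseq
bash train_conformer.sh de /path/to/mustc_preprocessed st /path/to/save_dir
```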

To remove the CTC Compression from the model, remove `--ctc-compress-strategy avg` from the script.

Since it increases the training time, the bug-fix for the padding problem in the relative positional encodings (🪲3) can be disabled by adding `--batch-unsafe-relative-shift` to the script.

To remove the bug-fix related to the Convolution Module (🪲1), revert the commit: [!63][CONFORMER] Correction to the Convolutional Layer of Conformer for missing padding.

To remove the bug-fix related to the Initial Subsampling (🪲2), revert the commit: [!69][TRANSFORMER][CONFORMER][BUG] Fix padding in initial convolutional layers.
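As a sketch, a fix can be reverted by locating the corresponding commit in the history and reverting it (the commit hash below is a placeholder; use the one printed by `git log`):

```sh
# Example for 🪲1: locate the bug-fix commit by its message and revert it
git log --oneline | grep "Correction to the Convolutional Layer of Conformer"
git revert <commit-hash>
```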

To enable or disable TF32, set the flags `torch.backends.cuda.matmul.allow_tf32` and `torch.backends.cudnn.allow_tf32` to True or False, respectively. Please note that their default value depends on the version of PyTorch you are using (our experiments have been carried out with PyTorch 1.11, where they are enabled by default), and that they only have an effect on Ampere GPUs (if you are using a V100 GPU, for instance, TF32 cannot be enabled). For more information, refer to the official PyTorch documentation.
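For example, the defaults of the installed PyTorch version can be checked with:

```sh
# Prints the current values of the two TF32 flags
# (True True on PyTorch 1.11, where TF32 is enabled by default)
python -c "import torch; print(torch.backends.cuda.matmul.allow_tf32, torch.backends.cudnn.allow_tf32)"
```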

## Evaluation

Generate the output varying the batch size (`SENTENCES` = 1, 10, 100) and check that it does not depend on it: the generated files must be identical.

```sh
python ${FBK_fairseq}/fairseq_cli/generate.py ${MUSTC_ROOT} \
        --user-dir examples/speech_to_text \
        --config-yaml config.yaml --gen-subset tst-COMMON_st_src \
        --max-sentences ${SENTENCES} \
        --max-source-positions 10000 --max-target-positions 1000 \
        --task speech_to_text_ctc \
        --criterion ctc_multi_loss --underlying-criterion label_smoothed_cross_entropy \
        --beam 5 --no-repeat-ngram-size 5 --path ${SAVE_DIR}/avg5.pt > ${SAVE_DIR}/tst-COMMON.${SENTENCES}.out
```
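For example, once the command above has been run for each value of `SENTENCES`, the comparison can be done by extracting the detokenized hypotheses (the `D-` lines of the standard fairseq output), sorting them by sentence id, and diffing them. A sketch, with paths to be adapted:

```sh
# Extract the hypotheses sorted by sentence id and check that they are identical
for SENTENCES in 1 10 100; do
  grep "^D-" ${SAVE_DIR}/tst-COMMON.${SENTENCES}.out | sort -t- -k2 -n | cut -f3 \
    > ${SAVE_DIR}/tst-COMMON.${SENTENCES}.hyp
done
diff ${SAVE_DIR}/tst-COMMON.1.hyp ${SAVE_DIR}/tst-COMMON.10.hyp
diff ${SAVE_DIR}/tst-COMMON.10.hyp ${SAVE_DIR}/tst-COMMON.100.hyp
```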

## Pretrained models

Common files:

```yaml
bpe_tokenizer:
  bpe: sentencepiece
  sentencepiece_model: tgtdict.model
bpe_tokenizer_src:
  bpe: sentencepiece
  sentencepiece_model: srcdict.model
input_channels: 1
input_feat_per_channel: 80
sampling_alpha: 1.0
specaugment:
  freq_mask_F: 27
  freq_mask_N: 1
  time_mask_N: 1
  time_mask_T: 100
  time_mask_p: 1.0
  time_wrap_W: 0
transforms:
  '*':
  - utterance_cmvn
  _train:
  - utterance_cmvn
  - specaugment
vocab_filename: tgtdict.txt
vocab_filename_src: srcdict.txt
```

### Checkpoints

| Code | Model | en (ASR) | en-de | en-es | en-fr | en-it | en-nl | en-pt | en-ro | en-ru |
|:---:|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| | Conformer | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt |
| | + CTC Compression | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt |
| 🪲 | Conformer | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt |
| 🪲 | + CTC Compression | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt | ckp.pt |
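A downloaded checkpoint can be scored with the generation command of the Evaluation section by pointing `--path` to it. As a sketch (all paths below are placeholders), assuming the common files above are stored as `config.yaml`, `srcdict.*`, and `tgtdict.*` inside `${MUSTC_ROOT}`, where fairseq resolves them relative to the data root:

```sh
# Placeholders: adapt the paths to where the checkpoint was downloaded
mkdir -p ${SAVE_DIR}
cp /path/to/downloaded/ckp.pt ${SAVE_DIR}/avg5.pt
# then run the generate.py command of the Evaluation section
```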

## Citation

```bibtex
@inproceedings{papi-et-al-2024-when,
  title = {{When Good and Reproducible Results are a Giant with Feet of Clay: The Importance of Software Quality in NLP}},
  author = {Papi, Sara and Gaido, Marco and Pilzer, Andrea and Negri, Matteo},
  booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
  address = {Bangkok, Thailand},
  year = {2024}
}
```