Skip to content

zhengxxn/adaptive-knn-mt

Repository files navigation

Adaptive kNN-MT

Code for our ACL 2021 paper "Adaptive Nearest Neighbor Machine Translation". Please cite our paper if you find this repository helpful in your research:

@inproceedings{zheng-etal-2021-adaptive,
    title = "Adaptive Nearest Neighbor Machine Translation",
    author = "Zheng, Xin  and Zhang, Zhirui  and Guo, Junliang  and Huang, Shujian  and Chen, Boxing  and Luo, Weihua  and Chen, Jiajun",
    booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 2: Short Papers)",
    month = aug,
    year = "2021",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2021.acl-short.47",
    doi = "10.18653/v1/2021.acl-short.47",
    pages = "368--374",
    abstract = "kNN-MT, recently proposed by Khandelwal et al. (2020a), successfully combines pre-trained neural machine translation (NMT) model with token-level k-nearest-neighbor (kNN) retrieval to improve the translation accuracy. However, the traditional kNN algorithm used in kNN-MT simply retrieves a same number of nearest neighbors for each target token, which may cause prediction errors when the retrieved neighbors include noises. In this paper, we propose Adaptive kNN-MT to dynamically determine the number of k for each target token. We achieve this by introducing a light-weight Meta-k Network, which can be efficiently trained with only a few training samples. On four benchmark machine translation datasets, we demonstrate that the proposed method is able to effectively filter out the noises in retrieval results and significantly outperforms the vanilla kNN-MT model. Even more noteworthy is that the Meta-k Network learned on one domain could be directly applied to other domains and obtain consistent improvements, illustrating the generality of our method. Our implementation is open-sourced at https://github.com/zhengxxn/adaptive-knn-mt.",
}

This project implements our Adaptive kNN-MT as well as Vanilla kNN-MT. The implementation is build upon fairseq, and heavily inspired by knn-lm, many thanks to the authors for making their code avaliable.

The NJUNLP/knn-box toolkit also implements Adaptive kNN-MT and other kNN models. We highly recommend to use NJUNLP/knn-box instead of this repo in the future, it can be more clear and easy to use, and can do some visualization.

Requirements and Installation

  • pytorch version >= 1.5.0
  • python version >= 3.6
  • faiss-gpu >= 1.6.5
  • pytorch_scatter = 2.0.5
  • 1.19.0 <= numpy < 1.20.0

You can install this project by

pip install --editable ./

Attention!

2022/12/26 In our earlier implementations (before commit 7997d11), kNN-MT / adaptive kNN-MT were much slower than NMT for implementation-specific reasons. We repaired this problem and now kNN-MT / adaptive kNN-MT are only a little slower than NMT. To get high inference speed, if your code version is older than commit 7997d11, please pull the latest code, or refer to Pull Request #9 to modify your code.

Instructions

We use an example to show how to use our codes.

Pre-trained Model and Data

The pre-trained translation model can be downloaded from this site. We use the De->En Single Model for all experiments.

The raw data can be downloaded in this site, and you should preprocess them with moses toolkits and the bpe-codes provided by pre-trained model. For convenience, We also provide pre-processed data.

Create Datastore

This script will create datastore (includes key.npy and val.npy) for the data.

DSTORE_SIZE=3613350
MODEL_PATH=/path/to/pretrained_model_path
DATA_PATH=/path/to/fairseq_preprocessed_data_path
DATASTORE_PATH=/path/to/save_datastore
PROJECT_PATH=/path/to/ada_knnmt

mkdir -p $DATASTORE_PATH

CUDA_VISIBLE_DEVICES=0 python $PROJECT_PATH/save_datastore.py $DATA_PATH \
    --dataset-impl mmap \
    --task translation \
    --valid-subset train \
    --path $MODEL_PATH \
    --max-tokens 4096 --skip-invalid-size-inputs-valid-test \
    --decoder-embed-dim 1024 --dstore-fp16 --dstore-size $DSTORE_SIZE --dstore-mmap $DATASTORE_PATH
 
# 4096 and 1024 depend on your device and model separately

The DSTORE_SIZE depends on the num of tokens of target language train data. You can get it by two ways:

  • find it in preprocess.log file, which is created by fairseq-process and in data binary folder.
  • calculate wc -l + wc -w of raw data file.

The datastore sizes we used in our paper are listed as below:

IT Medical koran Law
3613350 6903320 524400 19070000

Build Faiss Index

This script will build faiss index for keys, which is used for fast knn search. when the knn_index is build, you can remove keys.npy to save the hard disk space.

PROJECT_PATH=/path/to/ada_knnmt
DSTORE_PATH=/path/to/saved_datastore
DSTORE_SIZE=3613350

CUDA_VISIBLE_DEVICES=0 python $PROJECT_PATH/train_datastore_gpu.py \
  --dstore_mmap $DSTORE_PATH \
  --dstore_size $DSTORE_SIZE \
  --dstore_fp16 \
  --faiss_index ${DSTORE_PATH}/knn_index \
  --ncentroids 4096 \
  --probe 32 \
  --dimension 1024

Train Adaptive kNN-MT Model


DSTORE_SIZE=3613350
DATA_PATH=/path/to/fairseq_preprocessed_data_path
PROJECT_PATH=/path/to/ada_knnmt
MODEL_PATH=/path/to/pretrained_model_path
DATASTORE_PATH=/path/to/saved_datastore

max_k_grid=(4 8 16 32)
batch_size_grid=(32 32 32 32)
update_freq_grid=(1 1 1 1)
valid_batch_size_grid=(32 32 32 32)

for idx in ${!max_k_grid[*]}
do

  MODEL_RECORD_PATH=/path/to/save/model/train-hid32-maxk${max_k_grid[$idx]}
  TRAINING_RECORD_PATH=/path/to/save/tensorboard/train-hid32-maxk${max_k_grid[$idx]}
  mkdir -p "$TRAINING_RECORD_PATH"

  CUDA_VISIBLE_DEVICES=0 python \
  $PROJECT_PATH/fairseq_cli/train.py \
  $DATA_PATH \
  --log-interval 100 --log-format simple \
  --arch transformer_wmt19_de_en_with_datastore \
  --tensorboard-logdir "$TRAINING_RECORD_PATH" \
  --save-dir "$MODEL_RECORD_PATH" --restore-file "$MODEL_PATH" \
  --reset-dataloader --reset-lr-scheduler --reset-meters --reset-optimizer \
  --validate-interval-updates 100 --save-interval-updates 100 --keep-interval-updates 1 --max-update 5000 --validate-after-updates 1000 \
  --save-interval 10000 --validate-interval 100 \
  --keep-best-checkpoints 1 --no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
  --train-subset valid --valid-subset valid --source-lang de --target-lang en \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.001 \
  --max-source-positions 1024 --max-target-positions 1024 \
  --batch-size "${batch_size_grid[$idx]}" --update-freq "${update_freq_grid[$idx]}" --batch-size-valid "${valid_batch_size_grid[$idx]}" \
  --task translation \
  --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --min-lr 3e-05 --lr 0.0003 --clip-norm 1.0 \
  --lr-scheduler reduce_lr_on_plateau --lr-patience 5 --lr-shrink 0.5 \
  --patience 30 --max-epoch 500 \
  --load-knn-datastore --dstore-filename $DATASTORE_PATH --use-knn-datastore \
  --dstore-fp16 --dstore-size $DSTORE_SIZE --probe 32 \
  --knn-sim-func do_not_recomp_l2 \
  --use-gpu-to-search --move-dstore-to-mem --no-load-keys \
  --knn-lambda-type trainable --knn-temperature-type fix --knn-temperature-value 10 --only-train-knn-parameter \
  --knn-k-type trainable --k-lambda-net-hid-size 32 --k-lambda-net-dropout-rate 0.0 --max-k "${max_k_grid[$idx]}" --k "${max_k_grid[$idx]}" \
  --label-count-as-feature
done

The batch size and update-freq should be adjust by yourself depends on your gpu.

Inference with Adaptive kNN-MT

DSTORE_SIZE=3613350
MODEL_PATH=/path/to/trained_model

DATASTORE_PATH=/path/to/datastore
DATA_PATH=/path/to/data
PROJECT_PATH=/path/to/ada_knnmt

OUTPUT_PATH=/path/to/save_output_result

mkdir -p "$OUTPUT_PATH"

CUDA_VISIBLE_DEVICES=0 python $PROJECT_PATH/experimental_generate.py $DATA_PATH \
    --gen-subset test\
    --path "$MODEL_PATH" --arch transformer_wmt19_de_en_with_datastore \
    --beam 4 --lenpen 0.6 --max-len-a 1.2 --max-len-b 10 --source-lang de --target-lang en \
    --scoring sacrebleu \
    --batch-size 32 \
    --tokenizer moses --remove-bpe \
    --model-overrides "{'load_knn_datastore': True, 'use_knn_datastore': True,
    'dstore_filename': '$DATASTORE_PATH', 'dstore_size': $DSTORE_SIZE, 'dstore_fp16': True, 'probe': 32,
    'knn_sim_func': 'do_not_recomp_l2', 'use_gpu_to_search': True, 'move_dstore_to_mem': True, 'no_load_keys': True,
    'knn_temperature_type': 'fix', 'knn_temperature_value': 10,}" \
    | tee "$OUTPUT_PATH"/generate.txt

grep ^S "$OUTPUT_PATH"/generate.txt | cut -f2- > "$OUTPUT_PATH"/src
grep ^T "$OUTPUT_PATH"/generate.txt | cut -f2- > "$OUTPUT_PATH"/ref
grep ^H "$OUTPUT_PATH"/generate.txt | cut -f3- > "$OUTPUT_PATH"/hyp
grep ^D "$OUTPUT_PATH"/generate.txt | cut -f3- > "$OUTPUT_PATH"/hyp.detok

base NMT inference

We also provide scripts to do NMT and vanilla kNN-MT inference

MODEL_PATH=/path/to/pretrained_model_path/
DATA_PATH=/path/to/fairseq_preprocessed_path/
DATASTORE_PATH=/path/to/saved_datastore/
PROJECT_PATH=/path/to/knnmt/

mkdir -p $OUTPUT_PATH

CUDA_VISIBLE_DEVICES=0 python $PROJECT_PATH/fairseq_cli/generate.py $DATA_PATH\
    --gen-subset test \
    --path $MODEL_PATH \
    --beam 4 --lenpen 0.6 --max-len-a 1.2 --max-len-b 10 --source-lang de --target-lang en \
    --scoring sacrebleu \
    --max-tokens 4096 \
    --tokenizer moses --remove-bpe | tee $OUTPUT_PATH/generate.txt

grep ^S "$OUTPUT_PATH"/generate.txt | cut -f2- > "$OUTPUT_PATH"/src
grep ^T "$OUTPUT_PATH"/generate.txt | cut -f2- > "$OUTPUT_PATH"/ref
grep ^H "$OUTPUT_PATH"/generate.txt | cut -f3- > "$OUTPUT_PATH"/hyp
grep ^D "$OUTPUT_PATH"/generate.txt | cut -f3- > "$OUTPUT_PATH"/hyp.detok

Vanilla kNN-MT inference

DSTORE_SIZE=3613350
MODEL_PATH=/path/to/pre_trained_model

DATASTORE_PATH=/path/to/datastore
DATA_PATH=/path/to/data
PROJECT_PATH=/path/to/ada_knnmt

OUTPUT_PATH=/path/to/save_output_result

mkdir -p "$OUTPUT_PATH"

CUDA_VISIBLE_DEVICES=0 python $PROJECT_PATH/experimental_generate.py $DATA_PATH \
    --gen-subset test\
    --path $MODEL_PATH --arch transformer_wmt19_de_en_with_datastore \
    --beam 4 --lenpen 0.6 --max-len-a 1.2 --max-len-b 10 --source-lang de --target-lang en \
    --scoring sacrebleu \
    --batch-size 32 \
    --tokenizer moses --remove-bpe \
    --model-overrides "{'load_knn_datastore': True, 'use_knn_datastore': True,
    'dstore_filename': '$DATASTORE_PATH', 'dstore_size': $DSTORE_SIZE, 'dstore_fp16': True, 'k': 8, 'probe': 32,
    'knn_sim_func': 'do_not_recomp_l2', 'use_gpu_to_search': True, 'move_dstore_to_mem': True, 'no_load_keys': True,
    'knn_lambda_type': 'fix', 'knn_lambda_value': 0.7, 'knn_temperature_type': 'fix', 'knn_temperature_value': 10,
     }" \
    | tee "$OUTPUT_PATH"/generate.txt

grep ^S "$OUTPUT_PATH"/generate.txt | cut -f2- > "$OUTPUT_PATH"/src
grep ^T "$OUTPUT_PATH"/generate.txt | cut -f2- > "$OUTPUT_PATH"/ref
grep ^H "$OUTPUT_PATH"/generate.txt | cut -f3- > "$OUTPUT_PATH"/hyp
grep ^D "$OUTPUT_PATH"/generate.txt | cut -f3- > "$OUTPUT_PATH"/hyp.detok

We recommend you to use below hyper-parameters to replicate the good vanilla knn-mt results. And note that for our adaptive-knn-mt, we set the temperature as same as below.

IT Medical Law Koran
k 8 4 4 16
lambda 0.7 0.8 0.8 0.8
temperature 10 10 10 100

About

No description, website, or topics provided.

Resources

License

Code of conduct

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages