Sōseki

Sōseki is an implementation of an end-to-end question answering (QA) system.

Currently, Sōseki uses the Binary Passage Retriever (BPR), an efficient passage retrieval model for large document collections. BPR was originally developed to improve the computational efficiency of the QA system submitted to the Systems under 6GB track of the NeurIPS 2020 EfficientQA competition.
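As a rough illustration of the idea behind BPR, the sketch below follows the two-stage scheme described in the paper cited at the end of this README: candidates are first selected by Hamming distance between binary codes, then reranked by the inner product between the continuous query vector and the binary passage codes. This is a toy, pure-Python sketch and not Sōseki's implementation; the function names (`binarize`, `hamming`, `retrieve`), the parameters, and the vectors are all hypothetical.

```python
from typing import List, Sequence


def binarize(vec: Sequence[float]) -> List[int]:
    """Map each dimension to +1 or -1 by its sign (a simple hash)."""
    return [1 if x >= 0 else -1 for x in vec]


def hamming(a: Sequence[int], b: Sequence[int]) -> int:
    """Count the positions where two binary codes differ."""
    return sum(1 for x, y in zip(a, b) if x != y)


def retrieve(query: Sequence[float], passage_codes: List[List[int]],
             num_candidates: int, top_k: int) -> List[int]:
    """Two-stage retrieval: Hamming-distance candidates, then reranking."""
    query_code = binarize(query)
    # Stage 1: keep the passages whose codes are closest to the query code.
    by_distance = sorted(range(len(passage_codes)),
                         key=lambda i: hamming(query_code, passage_codes[i]))
    candidates = by_distance[:num_candidates]

    # Stage 2: rerank the candidates by the inner product between the
    # continuous query vector and each binary passage code.
    def score(i: int) -> float:
        return sum(q * p for q, p in zip(query, passage_codes[i]))

    return sorted(candidates, key=score, reverse=True)[:top_k]


# Toy data: four passages with 4-dimensional codes (hypothetical values).
passage_codes = [
    [1, 1, -1, -1],
    [1, -1, 1, -1],
    [-1, -1, 1, 1],
    [1, 1, 1, -1],
]
query = [0.9, 0.1, 0.8, -0.5]
top = retrieve(query, passage_codes, num_candidates=3, top_k=2)
print(top)
```

In the commands later in this README, `--binary_k` presumably plays the role of `num_candidates` here, but that mapping is an assumption rather than something this README states.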

Installation

# Before installation, upgrade pip and setuptools.
$ pip install -U pip setuptools

# Install the PyTorch package.
# You may want to check the install option for your CUDA environment.
# https://pytorch.org/get-started/locally/
$ pip install 'torch==1.11.0'

# Install other dependencies.
$ pip install -r requirements.txt

# Install the soseki package.
$ pip install .
# Or if you want to install it in editable mode:
$ pip install -e .

Note: If your GPU environment uses a CUDA version other than 10.2, you may need to reinstall PyTorch by following the official documentation.

Example Usage

Before you start, download the datasets provided in the DPR repository into <DPR_DATASET_DIR>.

We used 4 GPUs with 12GB memory each for the experiments.

1. Build passage database

$ python build_passage_db.py \
    --passage_file <DPR_DATASET_DIR>/wikipedia_split/psgs_w100.tsv \
    --db_file <WORK_DIR>/passages.db \
    --db_map_size 21000000000

2. Train a biencoder

$ python train_biencoder.py \
    --train_file <DPR_DATASET_DIR>/retriever/nq-train.json \
    --val_file <DPR_DATASET_DIR>/retriever/nq-dev.json \
    --output_dir <WORK_DIR>/biencoder \
    --max_question_length 64 \
    --max_passage_length 192 \
    --num_negative_passages 1 \
    --shuffle_hard_negative_passages \
    --shuffle_normal_negative_passages \
    --base_pretrained_model bert-base-uncased \
    --binary \
    --train_batch_size 16 \
    --eval_batch_size 16 \
    --learning_rate 1e-5 \
    --warmup_proportion 0.1 \
    --gradient_clip_val 2.0 \
    --max_epochs 40 \
    --gpus 4 \
    --precision 16 \
    --strategy ddp

3. Build passage embeddings

$ python build_passage_embeddings.py \
    --biencoder_file <WORK_DIR>/biencoder/lightning_logs/version_0/checkpoints/last.ckpt \
    --passage_db_file <WORK_DIR>/passages.db \
    --output_file <WORK_DIR>/passage_embeddings.idx \
    --max_passage_length 192 \
    --batch_size 2048 \
    --device_ids 0 1 2 3

4. Evaluate the retriever and create datasets for the reader

$ mkdir <WORK_DIR>/reader_data

$ python evaluate_retriever.py \
    --biencoder_file <WORK_DIR>/biencoder/lightning_logs/version_0/checkpoints/last.ckpt \
    --passage_db_file <WORK_DIR>/passages.db \
    --passage_embeddings_file <WORK_DIR>/passage_embeddings.idx \
    --qa_file <DPR_DATASET_DIR>/retriever/qas/nq-train.csv \
    --output_file <WORK_DIR>/reader_data/nq_train.jsonl \
    --batch_size 64 \
    --max_question_length 64 \
    --top_k 1 2 5 10 20 50 100 \
    --binary \
    --binary_k 2048 \
    --answer_match_type dpr_string \
    --include_title_in_passage \
    --device_ids 0 1 2 3
# The result should be logged as follows:
# Recall at 1: 0.4993 (39532/79168)
# Recall at 2: 0.6175 (48886/79168)
# Recall at 5: 0.7353 (58213/79168)
# Recall at 10: 0.7919 (62690/79168)
# Recall at 20: 0.8288 (65613/79168)
# Recall at 50: 0.8597 (68061/79168)
# Recall at 100: 0.8751 (69281/79168)

$ python evaluate_retriever.py \
    --biencoder_file <WORK_DIR>/biencoder/lightning_logs/version_0/checkpoints/last.ckpt \
    --passage_db_file <WORK_DIR>/passages.db \
    --passage_embeddings_file <WORK_DIR>/passage_embeddings.idx \
    --qa_file <DPR_DATASET_DIR>/retriever/qas/nq-dev.csv \
    --output_file <WORK_DIR>/reader_data/nq_dev.jsonl \
    --batch_size 64 \
    --max_question_length 64 \
    --top_k 1 2 5 10 20 50 100 \
    --binary \
    --binary_k 2048 \
    --answer_match_type dpr_string \
    --include_title_in_passage \
    --device_ids 0 1 2 3
# The result should be logged as follows:
# Recall at 1: 0.4047 (3544/8757)
# Recall at 2: 0.5143 (4504/8757)
# Recall at 5: 0.6398 (5603/8757)
# Recall at 10: 0.7117 (6232/8757)
# Recall at 20: 0.7595 (6651/8757)
# Recall at 50: 0.8134 (7123/8757)
# Recall at 100: 0.8420 (7373/8757)

$ python evaluate_retriever.py \
    --biencoder_file <WORK_DIR>/biencoder/lightning_logs/version_0/checkpoints/last.ckpt \
    --passage_db_file <WORK_DIR>/passages.db \
    --passage_embeddings_file <WORK_DIR>/passage_embeddings.idx \
    --qa_file <DPR_DATASET_DIR>/retriever/qas/nq-test.csv \
    --output_file <WORK_DIR>/reader_data/nq_test.jsonl \
    --batch_size 64 \
    --max_question_length 64 \
    --top_k 1 2 5 10 20 50 100 \
    --binary \
    --binary_k 2048 \
    --answer_match_type dpr_string \
    --include_title_in_passage \
    --device_ids 0 1 2 3
# The result should be logged as follows:
# Recall at 1: 0.4136 (1493/3610)
# Recall at 2: 0.5208 (1880/3610)
# Recall at 5: 0.6452 (2329/3610)
# Recall at 10: 0.7194 (2597/3610)
# Recall at 20: 0.7737 (2793/3610)
# Recall at 50: 0.8283 (2990/3610)
# Recall at 100: 0.8518 (3075/3610)
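The "Recall at k" lines above can be read as the fraction of questions for which at least one of the top k retrieved passages contains an answer. The sketch below shows that calculation under the assumption that each question is represented by a ranked list of booleans marking whether each passage contains an answer. The function name and the toy data are hypothetical and are not taken from `evaluate_retriever.py`.

```python
from typing import Dict, List, Sequence


def recall_at_k(hits_per_question: List[List[bool]],
                ks: Sequence[int]) -> Dict[int, float]:
    """Fraction of questions with an answer-bearing passage in the top k."""
    total = len(hits_per_question)
    return {
        k: sum(any(hits[:k]) for hits in hits_per_question) / total
        for k in ks
    }


# Toy example: three questions, each with five ranked passages.
hits = [
    [True, False, False, False, False],   # answer found at rank 1
    [False, False, True, False, False],   # answer found at rank 3
    [False, False, False, False, False],  # answer never found
]
recall = recall_at_k(hits, ks=[1, 2, 5])
for k, value in recall.items():
    print(f"Recall at {k}: {value:.4f}")
```

Read this way, the first line reported for the training set corresponds to 39532 / 79168 ≈ 0.4993.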

5. Train a reader

$ python train_reader.py \
    --train_file <WORK_DIR>/reader_data/nq_train.jsonl \
    --val_file <WORK_DIR>/reader_data/nq_dev.jsonl \
    --output_dir <WORK_DIR>/reader \
    --train_num_passages 24 \
    --eval_num_passages 100 \
    --max_input_length 256 \
    --shuffle_positive_passage \
    --shuffle_negative_passage \
    --num_dataloader_workers 1 \
    --base_pretrained_model bert-base-uncased \
    --answer_normalization_type dpr \
    --train_batch_size 1 \
    --eval_batch_size 2 \
    --learning_rate 1e-5 \
    --warmup_proportion 0.1 \
    --accumulate_grad_batches 4 \
    --gradient_clip_val 2.0 \
    --max_epochs 20 \
    --gpus 4 \
    --precision 16 \
    --strategy ddp
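The reader command combines a small batch size with gradient accumulation and several GPUs. Assuming the usual PyTorch Lightning DDP behavior, in which `--train_batch_size` applies to each GPU process, the number of questions contributing to each optimizer step can be estimated as follows. The helper name is hypothetical, and the per-GPU interpretation is an assumption rather than something this README states.

```python
def effective_batch_size(per_gpu_batch_size: int, num_gpus: int,
                         accumulate_grad_batches: int = 1) -> int:
    """Estimate the number of examples per optimizer step under DDP."""
    # Assumption: each GPU process loads its own batch, and gradients are
    # accumulated over several batches before each optimizer step.
    return per_gpu_batch_size * num_gpus * accumulate_grad_batches


# Values taken from the reader training command above.
reader_batch = effective_batch_size(
    per_gpu_batch_size=1, num_gpus=4, accumulate_grad_batches=4
)
print(reader_batch)
```

If you train on fewer GPUs, raising `--accumulate_grad_batches` proportionally would keep this estimate unchanged.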

6. Evaluate the reader

$ python evaluate_reader.py \
    --reader_file <WORK_DIR>/reader/lightning_logs/version_0/checkpoints/best.ckpt \
    --test_file <WORK_DIR>/reader_data/nq_dev.jsonl \
    --test_num_passages 100 \
    --test_max_load_passages 100 \
    --test_batch_size 4 \
    --gpus 4 \
    --strategy ddp
# The result should be printed as follows:
# ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
# ┃        Test metric        ┃       DataLoader 0        ┃
# ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
# │   test_answer_accuracy    │    0.39294278621673584    │
# │ test_classifier_precision │    0.5889003276824951     │
# └───────────────────────────┴───────────────────────────┘
$ python evaluate_reader.py \
    --reader_file <WORK_DIR>/reader/lightning_logs/version_0/checkpoints/best.ckpt \
    --test_file <WORK_DIR>/reader_data/nq_test.jsonl \
    --test_num_passages 100 \
    --test_max_load_passages 100 \
    --test_batch_size 4 \
    --gpus 4 \
    --strategy ddp
# The result should be printed as follows:
# ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
# ┃        Test metric        ┃       DataLoader 0        ┃
# ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
# │   test_answer_accuracy    │    0.3900277018547058     │
# │ test_classifier_precision │    0.5836564898490906     │
# └───────────────────────────┴───────────────────────────┘

7. (optional) Convert the trained models into ONNX format

$ python convert_models_to_onnx.py \
    --biencoder_ckpt_file <WORK_DIR>/biencoder/lightning_logs/version_0/checkpoints/last.ckpt \
    --reader_ckpt_file <WORK_DIR>/reader/lightning_logs/version_0/checkpoints/best.ckpt \
    --output_dir <WORK_DIR>/onnx

8. Run demo

$ streamlit run demo.py --browser.serverAddress localhost --browser.serverPort 8501 -- \
    --biencoder_ckpt_file <WORK_DIR>/biencoder/lightning_logs/version_0/checkpoints/last.ckpt \
    --reader_ckpt_file <WORK_DIR>/reader/lightning_logs/version_0/checkpoints/best.ckpt \
    --passage_db_file <WORK_DIR>/passages.db \
    --passage_embeddings_file <WORK_DIR>/passage_embeddings.idx \
    --device cuda:0

Alternatively, if you have exported the models to ONNX format:

$ streamlit run demo.py --browser.serverAddress localhost --browser.serverPort 8501 -- \
    --onnx_model_dir <WORK_DIR>/onnx \
    --passage_db_file <WORK_DIR>/passages.db \
    --passage_embeddings_file <WORK_DIR>/passage_embeddings.idx

Then open http://localhost:8501.

The demo can also be launched with Docker:

$ docker build -t soseki --build-arg TRANSFORMERS_BASE_MODEL_NAME='bert-base-uncased' .
$ docker run --rm -v $(realpath <WORK_DIR>):/app/model -p 8501:8501 -it soseki \
    streamlit run demo.py --browser.serverAddress localhost --browser.serverPort 8501 -- \
        --onnx_model_dir /app/model/onnx \
        --passage_db_file /app/model/passages.db \
        --passage_embeddings_file /app/model/passage_embeddings.idx

License

This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.

Citation

If you find this work useful, please cite the following paper:

Efficient Passage Retrieval with Hashing for Open-domain Question Answering

@inproceedings{yamada2021bpr,
  title={Efficient Passage Retrieval with Hashing for Open-domain Question Answering},
  author={Ikuya Yamada and Akari Asai and Hannaneh Hajishirzi},
  booktitle={ACL},
  year={2021}
}
