
PK-Chat

PK-Chat: Pointer Network Guided Knowledge Driven Generative Dialogue Model (paper: https://arxiv.org/abs/2304.00592)

Requirements

- python >= 3.6
- paddlepaddle == 1.6.1
- numpy
- nltk
- tqdm
- visualdl >= 1.3.0 (optional)
- regex

We recommend installing the Python packages with the following command:

pip install -r requirement.txt

Pre-trained dialogue generation model

You can download the pre-trained PK-Chat dialogue model from the link, then move and extract it:

mv /path/to/model.tar.gz .
tar xzf model.tar.gz

Fine-tuning

We also provide instructions to fine-tune the PK-Chat model on different conversation datasets (chit-chat, knowledge-grounded dialogue and conversational question answering).

Data preparation

Download data from the link. The tar file contains three processed datasets: DailyDialog, PersonaChat and DSTC7_AVSD.

mv /path/to/data.tar.gz .
tar xzf data.tar.gz

Data format

Our model supports two kinds of data formats for dialogue context: multi and multi_knowledge.

  • multi: multi-turn dialogue context.
u_1 __eou__ u_2 __eou__ ... u_n \t r
  • multi_knowledge: multi-turn dialogue context with background knowledge.
k_1 __eou__ k_2 __eou__ ... k_m \t u_1 __eou__ u_2 __eou__ ... u_n \t r

If you want to use this model on other datasets, you can process your data accordingly.
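
For example, toy lines in each format could look like the following, where \t stands for a literal tab character (the utterance and knowledge content here is invented purely for illustration):

# multi: two context utterances, then the response after a tab
how is the weather today ? __eou__ it is sunny outside\tgreat , let us go for a walk

# multi_knowledge: knowledge pieces, then the context, then the response, tab-separated
the weather is sunny __eou__ the temperature is 25 degrees\thow is the weather today ?\tit is sunny , about 25 degrees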

Train

To fine-tune the pre-trained model on a chosen ${DATASET}:

# DailyDialog / PersonaChat / DSTC7_AVSD / ACE_Dialog_topic
DATASET=ACE_Dialog_topic
sh scripts/${DATASET}/train.sh

After training, you can find the output folder outputs/${DATASET} (by default). It contains best.model (the checkpoint with the best results on the validation dataset), hparams.json (the hyper-parameters of the training script) and trainer.log (the training log).
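
For example, after fine-tuning on ACE_Dialog_topic, the output directory should look roughly like this (exact contents may vary with the training configuration):

outputs/ACE_Dialog_topic/
├── best.model
├── hparams.json
└── trainer.log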

To fine-tune the pre-trained model on multiple GPUs:

Note: You need to install the NCCL library and set the LD_LIBRARY_PATH environment variable properly.
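
For example, assuming NCCL is installed under /usr/local/nccl (a hypothetical path; adjust it to your installation):

export LD_LIBRARY_PATH=/usr/local/nccl/lib:$LD_LIBRARY_PATH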

sh scripts/ACE_Dialog_topic/multi_gpu_train.sh

Recommended settings

Fine-tuning the pre-trained model usually takes about 10 epochs to reach convergence with learning rate = 1e-5, and about 2-3 epochs with learning rate = 5e-5.

GPU Memory   batch size   max len
16G          6            256
32G          12           256

Infer

To run inference on the test dataset:

# DailyDialog / PersonaChat / DSTC7_AVSD / ACE_Dialog_topic
DATASET=ACE_Dialog_topic
sh scripts/${DATASET}/infer.sh

After inference, you can find the output folder outputs/${DATASET}.infer (by default). It contains infer_0.result.json (the inference result), hparams.json (the hyper-parameters of the inference script) and trainer.log (the inference log).

If you want to use top-k sampling (beam search is used by default), you can follow the example script:

sh scripts/DailyDialog/topk_infer.sh

Pipeline

See ./png/model.jpg for an illustration of the model pipeline.

Result

See ./png/result.jpg for the experimental results.

Citation

If you find PK-Chat useful in your work, please cite the following paper:

@misc{deng2023pkchat,
      title={PK-Chat: Pointer Network Guided Knowledge Driven Generative Dialogue Model}, 
      author={Cheng Deng and Bo Tong and Luoyi Fu and Jiaxin Ding and Dexing Cao and Xinbing Wang and Chenghu Zhou},
      year={2023},
      eprint={2304.00592},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Contact information

For help or issues using PK-Chat, please submit a GitHub issue.

For personal communication related to PK-Chat, please contact Bo Tong (bool_tbb@alumni.sjtu.edu.cn) or Cheng Deng (davendw@sjtu.edu.cn).
