This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* dialogpt * TA agent * spelling errors * test
- Loading branch information
Emily Dinan
committed
Aug 26, 2020
1 parent
0f9b034
commit c7f4b64
Showing
7 changed files
with
196 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from parlai.agents.hugging_face.dict import DialoGPTDictionaryAgent | ||
from parlai.agents.hugging_face.gpt2 import Gpt2Agent, GPT2Decoder, HFGPT2Model | ||
from parlai.utils.misc import warn_once | ||
|
||
try: | ||
from transformers import GPT2Model | ||
except ImportError: | ||
raise ImportError('Please run `pip install transformers`.') | ||
|
||
|
||
############################################ | ||
## Modules | ||
############################################ | ||
|
||
|
||
class DialoGPTDecoder(GPT2Decoder):
    """
    Decoder for DialoGPT.

    Its weights come from the pretrained DialoGPT model distributed through
    Hugging Face.
    """

    def _init_from_pretrained(self, opt):
        """
        Load the pretrained GPT2Model selected by ``opt['gpt2_size']``.
        """
        # The pretrained identifier is 'microsoft/DialoGPT-' followed by the size.
        size = opt['gpt2_size']
        pretrained_name = f'microsoft/DialoGPT-{size}'
        return GPT2Model.from_pretrained(pretrained_name)
|
||
|
||
class DialoGPTModel(HFGPT2Model):
    """
    DialoGPT model built on the Hugging Face GPT2 model wrapper.
    """

    def _get_decoder(self, opt, dict):
        """
        Construct the DialoGPT decoder from the given options and dictionary.
        """
        decoder = DialoGPTDecoder(opt, dict)
        return decoder
|
||
|
||
############################################ | ||
## Agent | ||
############################################ | ||
|
||
|
||
class DialogptAgent(Gpt2Agent):
    """
    Hugging Face DialoGPT Agent.

    DialoGPT builds on GPT2, a multi-layer decoder-only Transformer, and its
    decoder starts from pretrained Hugging Face weights. More details:
    <https://huggingface.co/transformers/model_doc/dialogpt.html>.

    Three sizes are available: small, medium, large.

    When finetuning this agent for dialogue, pass `--add-special-tokens True`.
    To evaluate the agent out of the box, pass `--add-special-tokens False`
    and keep the batch size at 1.
    """

    @classmethod
    def add_cmdline_args(cls, argparser):
        """
        Register the DialoGPT argument group and the agent's default options.

        Adds `--gpt2-size`, overrides several parser defaults, then lets the
        parent class add its own arguments. Returns the DialoGPT argument group.
        """
        group = argparser.add_argument_group('DialoGPT Args')
        group.add_argument(
            '--gpt2-size',
            type=str,
            default='small',
            choices=['small', 'medium', 'large'],
            help='Which size model to initialize.',
        )

        overrides = {
            'delimiter': '<|endoftext|>',
            'history_add_global_end_token': '<|endoftext|>',
            'text_truncate': 768,
            'label_truncate': 256,
            'dict_maxexs': 0,  # skip building dictionary
        }
        argparser.set_defaults(**overrides)

        super().add_cmdline_args(argparser)
        warn_once('WARNING: this model is in beta and the API is subject to change.')
        return group

    @staticmethod
    def dictionary_class():
        """
        Return the dictionary class this agent expects to use.

        Override this if a more complex dictionary is required.
        """
        return DialoGPTDictionaryAgent

    def build_model(self, states=None):
        """
        Construct the DialoGPT model from this agent's options and dictionary.
        """
        model = DialoGPTModel(self.opt, self.dict)
        return model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
import parlai.utils.testing as testing_utils | ||
|
||
|
||
@testing_utils.skipUnlessGPU
class TestDialogptModel(unittest.TestCase):
    """
    Integration test for the DialoGPT model.

    Trains on the integration test task and checks that perplexity stays
    within the expected bound.
    """

    @testing_utils.retry(ntries=3, log_retry=True)
    def test_dialogpt(self):
        """
        Train DialoGPT briefly and check valid/test perplexity is at most 4.0.
        """
        train_opts = {
            'task': 'integration_tests:nocandidate',
            'model': 'hugging_face/dialogpt',
            'add_special_tokens': True,
            'add_start_token': True,
            'optimizer': 'sgd',
            'learningrate': 1,
            'batchsize': 4,
            'num_epochs': 4,
            'short_final_eval': True,
            'validation_max_exs': 12,
        }
        valid, test = testing_utils.train_model(train_opts)

        # Validation is checked first, then test, against the same threshold.
        for report in (valid, test):
            self.assertLessEqual(report['ppl'], 4.0)
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |