Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Add DialoGPT Agent (#3007)
Browse files Browse the repository at this point in the history
* dialogpt

* TA agent

* spelling errors

* test
  • Loading branch information
Emily Dinan committed Aug 26, 2020
1 parent 0f9b034 commit c7f4b64
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 16 deletions.
25 changes: 23 additions & 2 deletions parlai/agents/hugging_face/README.md
Expand Up @@ -24,5 +24,26 @@ Enter Your Message: Parrots are
parlai train_model -m hugging_face/gpt2 --add-special-tokens True --add-start-token True --gpt2-size medium -t convai2 -bs 2 -mf <modelfile>
```

## Other models
_Other models are coming soon -- stay tuned!_
## DialoGPT

To use DialoGPT, run your command with the flag: `-m hugging_face/dialogpt`.

### Examples
**Talk to DialoGPT large in interactive mode, with beam size 10, 3-gram beam blocking, and minimum beam length 25:**
```bash
parlai interactive -m hugging_face/dialogpt --add-special-tokens False --gpt2-size large --inference beam --beam-size 10 --beam-context-block-ngram 3 --beam-block-ngram 3 --beam-min-length 25
```
_Note:_ In the above command, we must have the flag `--add-special-tokens False` if we want to use the model _without_ finetuning it.

Here is example output from the above command:
```
Enter Your Message: What do you think of parrots?
[Dialogpt]: I love parrots. They are the best. I love them so much. I wish I had a pet parrot.
```


**Fine-tune DialoGPT medium on the ConvAI2 task:**
```bash
parlai train_model -m hugging_face/dialogpt --add-special-tokens True --delimiter '\n' --add-start-token True --gpt2-size medium -t convai2 -bs 2 -mf <modelfile>
```
_Note:_ In the above command, we override the default delimiter (`--delimiter '<|endoftext|>'`) with `'\n'`; this is a personal choice.
101 changes: 101 additions & 0 deletions parlai/agents/hugging_face/dialogpt.py
@@ -0,0 +1,101 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from parlai.agents.hugging_face.dict import DialoGPTDictionaryAgent
from parlai.agents.hugging_face.gpt2 import Gpt2Agent, GPT2Decoder, HFGPT2Model
from parlai.utils.misc import warn_once

try:
from transformers import GPT2Model
except ImportError:
raise ImportError('Please run `pip install transformers`.')


############################################
## Modules
############################################


class DialoGPTDecoder(GPT2Decoder):
    """
    Decoder backed by a pretrained DialoGPT checkpoint.

    Only the checkpoint lookup is overridden here: the transformer is loaded from
    the ``microsoft/DialoGPT-<size>`` pretrained model via Hugging Face.
    """

    def _init_from_pretrained(self, opt):
        """
        Load the pretrained transformer selected by ``opt['gpt2_size']``.
        """
        checkpoint = 'microsoft/DialoGPT-{}'.format(opt['gpt2_size'])
        return GPT2Model.from_pretrained(checkpoint)


class DialoGPTModel(HFGPT2Model):
    """
    Hugging Face DialoGPT model: an ``HFGPT2Model`` that uses a DialoGPT decoder.
    """

    def _get_decoder(self, opt, dict):
        """
        Build the decoder module used by this model.
        """
        decoder = DialoGPTDecoder(opt, dict)
        return decoder


############################################
## Agent
############################################


class DialogptAgent(Gpt2Agent):
    """
    Hugging Face DialoGPT Agent.

    DialoGPT builds on GPT2, a multi-layer decoder-only Transformer; the decoder
    starts from pretrained Hugging Face weights. More about the model:
    <https://huggingface.co/transformers/model_doc/dialogpt.html>.

    Three sizes are available: small, medium and large.

    When fine-tuning the agent as a dialogue agent, pass
    `--add-special-tokens True`. To examine out-of-the-box performance, pass
    `--add-special-tokens False` and make sure the batch size is 1.
    """

    @classmethod
    def add_cmdline_args(cls, argparser):
        """
        Register the DialoGPT argument group and defaults on ``argparser``.

        Returns the argument group that was created.
        """
        group = argparser.add_argument_group('DialoGPT Args')
        group.add_argument(
            '--gpt2-size',
            type=str,
            default='small',
            choices=['small', 'medium', 'large'],
            help='Which size model to initialize.',
        )
        # Defaults are installed before the parent class registers its arguments.
        overrides = {
            'delimiter': '<|endoftext|>',
            'history_add_global_end_token': '<|endoftext|>',
            'text_truncate': 768,
            'label_truncate': 256,
            'dict_maxexs': 0,  # skip building dictionary
        }
        argparser.set_defaults(**overrides)
        super().add_cmdline_args(argparser)
        warn_once('WARNING: this model is in beta and the API is subject to change.')
        return group

    @staticmethod
    def dictionary_class():
        """
        Return the dictionary class that this agent expects to use.

        Can be overridden if a more complex dictionary is required.
        """
        return DialoGPTDictionaryAgent

    def build_model(self, states=None):
        """
        Build and return the DialoGPT model from this agent's options and dictionary.
        """
        model = DialoGPTModel(self.opt, self.dict)
        return model
10 changes: 10 additions & 0 deletions parlai/agents/hugging_face/dict.py
Expand Up @@ -111,3 +111,13 @@ def override_special_tokens(self, opt):
self.ind2tok[self.end_idx] = self.end_token
self.ind2tok[self.start_idx] = self.start_token
self.ind2tok[self.null_idx] = self.null_token


class DialoGPTDictionaryAgent(Gpt2DictionaryAgent):
    """
    Dictionary agent that uses the pretrained DialoGPT tokenizer.
    """

    def get_tokenizer(self, opt):
        """
        Instantiate the tokenizer for the size given by ``opt['gpt2_size']``.
        """
        checkpoint = 'microsoft/DialoGPT-{}'.format(opt['gpt2_size'])
        return GPT2Tokenizer.from_pretrained(checkpoint)
26 changes: 16 additions & 10 deletions parlai/agents/hugging_face/gpt2.py
Expand Up @@ -30,15 +30,7 @@ class GPT2Decoder(torch.nn.Module):

def __init__(self, opt, dict):
super().__init__()
# load model
model_sz = opt['gpt2_size']
if model_sz == 'small':
fle_key = 'gpt2'
elif model_sz == 'distilgpt2':
fle_key = 'distilgpt2'
else:
fle_key = f'gpt2-{model_sz}'
self.transformer = GPT2Model.from_pretrained(fle_key)
self.transformer = self._init_from_pretrained(opt)
# add special tokens
self.start_idx = dict.start_idx
self.null_idx = dict.null_idx
Expand All @@ -49,6 +41,17 @@ def __init__(self, opt, dict):
# use cuda
self.use_cuda = not opt['no_cuda'] and torch.cuda.is_available()

def _init_from_pretrained(self, opt):
    """
    Load the pretrained GPT2 transformer selected by ``opt['gpt2_size']``.

    ``small`` maps to the ``gpt2`` checkpoint and ``distilgpt2`` to
    ``distilgpt2``; any other size ``<size>`` maps to ``gpt2-<size>``.
    """
    size = opt['gpt2_size']
    checkpoint = (
        'gpt2'
        if size == 'small'
        else 'distilgpt2'
        if size == 'distilgpt2'
        else f'gpt2-{size}'
    )
    return GPT2Model.from_pretrained(checkpoint)

def forward(self, input, encoder_state, incr_state=None):
attention_mask = None
if incr_state is None:
Expand Down Expand Up @@ -101,7 +104,7 @@ def __init__(self, opt, dict):

# init the model
self.encoder = IdentityLayer()
self.decoder = GPT2Decoder(opt, dict)
self.decoder = self._get_decoder(opt, dict)
self.config = self.decoder.transformer.config
self.lm_head = torch.nn.Linear(
self.config.n_embd, self.config.vocab_size, bias=False
Expand All @@ -112,6 +115,9 @@ def __init__(self, opt, dict):
# used to reverse concatenation of context and labels
self.text_lengths = None

def _get_decoder(self, opt, dict):
    """
    Build the decoder module for this model.

    Kept as a separate hook so that subclasses can supply a different decoder.
    """
    decoder = GPT2Decoder(opt, dict)
    return decoder

def _tie_weights(self, output_embeddings, input_embeddings):
    """
    Make ``output_embeddings.weight`` refer to ``input_embeddings.weight``.
    """
    shared_weight = input_embeddings.weight
    output_embeddings.weight = shared_weight

Expand Down
7 changes: 4 additions & 3 deletions parlai/agents/hugging_face/hugging_face.py
Expand Up @@ -7,8 +7,9 @@
"""
Integration with Hugging Face Transformers.
Please see <https://huggingface.co/transformers/>. Currently, the only implementations
are GPT2 and DialoGPT. To use these models, run with `-m hugging_face/gpt2` or `-m
hugging_face/dialogpt`.
"""
try:
import transformers # noqa: F401
Expand All @@ -20,5 +21,5 @@ class HuggingFaceAgent:
def __init__(self, opt, shared=None):
    """
    Always raise: ``hugging_face`` on its own does not name a model.

    :param opt: options passed at construction; not used.
    :param shared: shared state passed at construction; not used.
    :raises RuntimeError: unconditionally, naming the valid model choices.
    """
    # Only the current message is kept; the superseded gpt2-only sentence
    # would otherwise be concatenated in front of it.
    raise RuntimeError(
        '`-m hugging_face` is not a valid choice. Please run with '
        '`-m hugging_face/gpt2` or `-m hugging_face/dialogpt`.'
    )
2 changes: 1 addition & 1 deletion parlai/core/torch_agent.py
Expand Up @@ -646,7 +646,6 @@ def add_cmdline_args(cls, argparser):
type='nonestr',
default=None,
hidden=True,
choices=[None, 'end'],
help='Add special token to the end of history encoding.',
)
agent.add_argument(
Expand Down Expand Up @@ -1360,6 +1359,7 @@ def _set_text_vec(self, obs, history, truncate):
obs['text_vec'], truncate, truncate_left
)
obs.force_set('text_vec', torch.LongTensor(truncated_vec))

return obs

def _set_label_vec(self, obs, add_start, add_end, truncate):
Expand Down
41 changes: 41 additions & 0 deletions tests/nightly/gpu/test_dialogpt.py
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
import parlai.utils.testing as testing_utils


@testing_utils.skipUnlessGPU
class TestDialogptModel(unittest.TestCase):
    """
    Integration test for the DialoGPT agent.

    Trains on the integration test task and checks the perplexity reached on the
    validation and test reports.
    """

    @testing_utils.retry(ntries=3, log_retry=True)
    def test_dialogpt(self):
        train_opt = {
            'task': 'integration_tests:nocandidate',
            'model': 'hugging_face/dialogpt',
            'add_special_tokens': True,
            'add_start_token': True,
            'optimizer': 'sgd',
            'learningrate': 1,
            'batchsize': 4,
            'num_epochs': 4,
            'short_final_eval': True,
            'validation_max_exs': 12,
        }
        valid_report, test_report = testing_utils.train_model(train_opt)

        # Both reports must reach the same perplexity ceiling.
        for report in (valid_report, test_report):
            self.assertLessEqual(report['ppl'], 4.0)


if __name__ == '__main__':
unittest.main()

0 comments on commit c7f4b64

Please sign in to comment.