modify_bert.diff
diff --git a/transformers/__init__.py b/transformers/__init__.py
index fbc92f0..39320f4 100644
--- a/transformers/__init__.py
+++ b/transformers/__init__.py
@@ -67,6 +67,7 @@ if is_torch_available():
BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForMultipleChoice,
BertForTokenClassification, BertForQuestionAnswering,
+ BertForVQR,
load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_openai import (OpenAIGPTPreTrainedModel, OpenAIGPTModel,
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py
index 8c92241..0882dfd 100644
--- a/transformers/modeling_bert.py
+++ b/transformers/modeling_bert.py
@@ -1151,3 +1151,132 @@ class BertForQuestionAnswering(BertPreTrainedModel):
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
+
+class BertForVQR(BertPreTrainedModel):
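+    # BERT with up to three heads: question- and response-relevance classifiers
+    # over the pooled [CLS] output, and a span-extraction head over the token
+    # encodings (mirroring BertForQuestionAnswering above)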
+ def __init__(self, config, num_labels=2, q_relevance=True, r_relevance=True, answer_extraction=True, answer_verification=False):
+ super(BertForVQR, self).__init__(config)
+ self.num_labels = num_labels
+ self.bert = BertModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+        # set up which tasks the network will perform
+        self.q_relevance = q_relevance
+        self.r_relevance = r_relevance
+        self.answer_extraction = answer_extraction
+        self.answer_verification = answer_verification
+
+ if self.q_relevance:
+ self.q_relevance_classifier = nn.Linear(config.hidden_size, num_labels)
+ if self.r_relevance:
+ self.r_relevance_classifier = nn.Linear(config.hidden_size, num_labels)
+        if self.answer_extraction:
+            self.span_classifier = nn.Linear(config.hidden_size, 2)
+        if self.answer_verification:
+            # scale and shift used when folding the span loss into the response
+            # relevance logits; these starting values are assumptions
+            self.loss_multiplier = 1.0
+            self.loss_bias = 0.0
+
+ self.init_weights()
+
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None,
+ q_relevance_ids=None, r_relevance_ids=None,
+ start_positions=None, end_positions=None, original_examples=None):
+        output = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
+        encoded_layers, pooled_output = output[:2]
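+        # encoded_layers: (batch, seq_len, hidden); pooled_output: (batch, hidden)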
+
+        def classify_confusion(input_type='r', weighting=(1, 1), answer_span_loss=None):
+            # shared relevance head: 'q' scores the question, 'r' the response
+ if input_type == 'r':
+ logits = self.r_relevance_classifier(self.dropout(pooled_output))
+ labels = r_relevance_ids
+ elif input_type == 'q':
+ logits = self.q_relevance_classifier(self.dropout(pooled_output))
+ labels = q_relevance_ids
+
+ loss = 0
+ if labels is not None:
+ weights = labels.new(weighting).float()
+ if self.num_labels > 1:
+ loss_fct = CrossEntropyLoss(weight=weights)
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+                else:
+                    # BCEWithLogitsLoss's `weight` rescales batch elements, not
+                    # classes; pos_weight is the right knob for class weighting
+                    loss_fct = nn.BCEWithLogitsLoss(pos_weight=weights[1] / weights[0])
+                    loss = loss_fct(logits.view(-1), labels.view(-1).float())
+
+            # convert logits to probabilities; they are re-normalised below if a
+            # verification adjustment is applied
+            if self.num_labels == 1:
+                logits = torch.sigmoid(logits)
+            else:
+                logits = torch.nn.functional.softmax(logits, dim=1)
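+            # answer-verification adjustment: a high span-extraction loss is taken
+            # as evidence that the response is irrelevant (boosting class 1);
+            # examples with the -1 sentinel loss are left unadjusted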
+ if answer_span_loss is not None:
+ logit_adjustment = self.loss_multiplier*torch.stack([torch.zeros_like(answer_span_loss), answer_span_loss], dim=1)
+ logit_adjustment[:, 1] = logit_adjustment[:, 1] + self.loss_bias
+ logit_adjustment = logit_adjustment.where(answer_span_loss.view(-1, 1) != -1, answer_span_loss.new([[0, 0]]))
+ logits = logits + logit_adjustment
+ logits = torch.nn.functional.softmax(logits, dim=1)
+ return loss, logits
+
+ def extract_answer(loss_mask=None):
+ logits = self.span_classifier(encoded_layers)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1)
+ end_logits = end_logits.squeeze(-1)
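+            # start/end logits: (batch, seq_len) scores for the span boundaries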
+            total_loss = 0
+            per_example_loss = None
+ if start_positions is not None and end_positions is not None:
+ # ignore start/end positions outside model inputs
+ ignored_index = start_logits.size(1)
+ # avoid modifying input
+ tmp_start_positions = start_positions.clamp(0, ignored_index)
+ tmp_end_positions = end_positions.clamp(0, ignored_index)
+
+                loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
+                start_loss = loss_fct(start_logits, tmp_start_positions)
+                end_loss = loss_fct(end_logits, tmp_end_positions)
+                if loss_mask is not None:
+                    # zero the span loss for examples flagged by the mask
+                    start_loss = torch.where(loss_mask == 1, torch.zeros_like(start_loss), start_loss)
+                    end_loss = torch.where(loss_mask == 1, torch.zeros_like(end_loss), end_loss)
+
+                per_example_loss = start_loss + end_loss
+                total_loss = torch.mean(per_example_loss)
+            return total_loss, start_logits, end_logits, per_example_loss
+
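+        # run whichever heads are enabled and collect their losses and logits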
+ q_loss, r_loss, span_loss = 0, 0, 0
+ retvals = {}
+ if self.q_relevance:
+            # valid questions are roughly 3x more common, so upweight the rarer class
+            q_loss, q_logits = classify_confusion('q', weighting=(1, 3))
+ retvals['q_logits'] = q_logits
+
+        answer_span_losses = None
+        if self.answer_extraction:
+            # the response-relevance labels double as the loss mask, so spans are
+            # not scored for responses labelled irrelevant
+            span_loss, start_logits, end_logits, per_example_span_loss = extract_answer(r_relevance_ids)
+            retvals['span_logits'] = [start_logits, end_logits]
+            if self.answer_verification:
+                answer_span_losses = per_example_span_loss
+
+        if self.r_relevance:
+            # valid responses are roughly 2x more common, so upweight the rarer class
+            r_loss, r_logits = classify_confusion('r', weighting=(1, 2), answer_span_loss=answer_span_losses)
+ retvals['r_logits'] = r_logits
+
+        retvals['loss'] = q_loss + r_loss + span_loss
+        # keys are returned in sorted order: loss, q_logits, r_logits, span_logits
+        return tuple(retvals[key] for key in sorted(retvals))
+