# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build BERT Examples from text (source, target) pairs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from bert import tokenization
import tagging
import tagging_converter
import tensorflow as tf
from typing import Mapping, MutableSequence, Optional, Sequence, Text


class BertExample(object):
  """Class for training and inference examples for BERT.

  Attributes:
    editing_task: The EditingTask from which this example was created. Needed
      when realizing labels predicted for this example.
    features: Feature dictionary.
  """

  def __init__(self, input_ids,
               input_mask,
               segment_ids, labels,
               labels_mask,
               token_start_indices,
               task, default_label):
    input_len = len(input_ids)
    if not (input_len == len(input_mask) and input_len == len(segment_ids) and
            input_len == len(labels) and input_len == len(labels_mask)):
      raise ValueError(
          'All feature lists should have the same length ({})'.format(
              input_len))
    self.features = collections.OrderedDict([
        ('input_ids', input_ids),
        ('input_mask', input_mask),
        ('segment_ids', segment_ids),
        ('labels', labels),
        ('labels_mask', labels_mask),
    ])
    self._token_start_indices = token_start_indices
    self.editing_task = task
    self._default_label = default_label

  def pad_to_max_length(self, max_seq_length, pad_token_id):
    """Pad the feature vectors so that they all have max_seq_length.

    Args:
      max_seq_length: The length that features will have after padding.
      pad_token_id: input_ids feature is padded with this ID, other features
        with ID 0.
    """
    pad_len = max_seq_length - len(self.features['input_ids'])
    for key in self.features:
      pad_id = pad_token_id if key == 'input_ids' else 0
      self.features[key].extend([pad_id] * pad_len)
      if len(self.features[key]) != max_seq_length:
        raise ValueError('{} has length {} (should be {}).'.format(
            key, len(self.features[key]), max_seq_length))

  def to_tf_example(self):
    """Returns this object as a tf.Example."""

    def int_feature(values):
      return tf.train.Feature(
          int64_list=tf.train.Int64List(value=list(values)))

    tf_features = collections.OrderedDict([
        (key, int_feature(val)) for key, val in self.features.items()
    ])
    return tf.train.Example(features=tf.train.Features(feature=tf_features))

  def get_token_labels(self):
    """Returns labels/tags for the original tokens, not for wordpieces."""
    labels = []
    for idx in self._token_start_indices:
      # For unmasked and untruncated tokens, use the label in the features,
      # and for the truncated tokens, use the default label.
      if (idx < len(self.features['labels']) and
          self.features['labels_mask'][idx]):
        labels.append(self.features['labels'][idx])
      else:
        labels.append(self._default_label)
    return labels


class BertExampleBuilder(object):
  """Builder class for BertExample objects."""

  def __init__(self, label_map, vocab_file,
               max_seq_length, do_lower_case,
               converter):
    """Initializes an instance of BertExampleBuilder.

    Args:
      label_map: Mapping from tags to tag IDs.
      vocab_file: Path to BERT vocabulary file.
      max_seq_length: Maximum sequence length.
      do_lower_case: Whether to lower case the input text. Should be True for
        uncased models and False for cased models.
      converter: Converter from text targets to tags.
    """
    self._label_map = label_map
    self._tokenizer = tokenization.FullTokenizer(vocab_file,
                                                 do_lower_case=do_lower_case)
    self._max_seq_length = max_seq_length
    self._converter = converter
    self._pad_id = self._get_pad_id()
    self._keep_tag_id = self._label_map['KEEP']
  def build_bert_example(
      self,
      sources,
      target=None,
      use_arbitrary_target_ids_for_infeasible_examples=False):
    """Constructs a BERT Example.

    Args:
      sources: List of source texts.
      target: Target text or None when building an example during inference.
      use_arbitrary_target_ids_for_infeasible_examples: Whether to build an
        example with arbitrary target ids even if the target can't be obtained
        via tagging.

    Returns:
      BertExample, or None if the conversion from text to tags was infeasible
      and use_arbitrary_target_ids_for_infeasible_examples == False.
    """
    # Compute target labels.
    task = tagging.EditingTask(sources)
    if target is not None:
      tags = self._converter.compute_tags(task, target)
      if not tags:
        if use_arbitrary_target_ids_for_infeasible_examples:
          # Create a tag sequence [KEEP, DELETE, KEEP, DELETE, ...] which is
          # unlikely to be predicted by chance.
          tags = [tagging.Tag('KEEP') if i % 2 == 0 else tagging.Tag('DELETE')
                  for i, _ in enumerate(task.source_tokens)]
        else:
          return None
    else:
      # If target is not provided, we set all target labels to KEEP.
      tags = [tagging.Tag('KEEP') for _ in task.source_tokens]
    labels = [self._label_map[str(tag)] for tag in tags]

    tokens, labels, token_start_indices = self._split_to_wordpieces(
        task.source_tokens, labels)
    tokens = self._truncate_list(tokens)
    labels = self._truncate_list(labels)

    input_tokens = ['[CLS]'] + tokens + ['[SEP]']
    labels_mask = [0] + [1] * len(labels) + [0]
    labels = [0] + labels + [0]

    input_ids = self._tokenizer.convert_tokens_to_ids(input_tokens)
    input_mask = [1] * len(input_ids)
    segment_ids = [0] * len(input_ids)

    example = BertExample(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        labels=labels,
        labels_mask=labels_mask,
        token_start_indices=token_start_indices,
        task=task,
        default_label=self._keep_tag_id)
    example.pad_to_max_length(self._max_seq_length, self._pad_id)
    return example

  def _split_to_wordpieces(self, tokens, labels):
    """Splits tokens (and the labels accordingly) to WordPieces.

    Args:
      tokens: Tokens to be split.
      labels: Labels (one per token) to be split.

    Returns:
      3-tuple with the split tokens, split labels, and the indices of the
      WordPieces that start a token.
    """
    bert_tokens = []  # Original tokens split into wordpieces.
    bert_labels = []  # Label for each wordpiece.
    # Index of each wordpiece that starts a new token.
    token_start_indices = []
    for i, token in enumerate(tokens):
      # '+ 1' is because bert_tokens will be prepended by [CLS] token later.
      token_start_indices.append(len(bert_tokens) + 1)
      pieces = self._tokenizer.tokenize(token)
      bert_tokens.extend(pieces)
      bert_labels.extend([labels[i]] * len(pieces))
    return bert_tokens, bert_labels, token_start_indices

  def _truncate_list(self, x):
    """Returns a truncated version of x according to self._max_seq_length."""
    # Save two slots for the first [CLS] token and the last [SEP] token.
    return x[:self._max_seq_length - 2]

  def _get_pad_id(self):
    """Returns the ID of the [PAD] token (or 0 if it's not in the vocab)."""
    try:
      return self._tokenizer.convert_tokens_to_ids(['[PAD]'])[0]
    except KeyError:
      return 0
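

if __name__ == '__main__':
  # Illustrative usage sketch, not part of the original module. The label
  # map, vocabulary path, and empty phrase vocabulary below are assumptions
  # made for demonstration; the real pipeline derives them from the training
  # data and a downloaded BERT checkpoint.
  demo_label_map = {'KEEP': 1, 'DELETE': 2}
  demo_builder = BertExampleBuilder(
      label_map=demo_label_map,
      vocab_file='/path/to/bert/vocab.txt',  # Hypothetical path.
      max_seq_length=128,
      do_lower_case=True,
      converter=tagging_converter.TaggingConverter([]))
  # Inference-style call: no target given, so every token is labeled KEEP.
  demo_example = demo_builder.build_bert_example(
      ['Turing was born in 1912 . Turing died in 1954 .'])
  if demo_example is not None:
    print(demo_example.get_token_labels())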