import torch
from torch import nn


class CRF(nn.Module):
"""
Linear-chain Conditional Random Field (CRF).
Args:
nb_labels (int): number of labels in your tagset, including special symbols.
bos_tag_id (int): integer representing the beginning of sentence symbol in
your tagset.
eos_tag_id (int): integer representing the end of sentence symbol in your tagset.
pad_tag_id (int, optional): integer representing the pad symbol in your tagset.
If None, the model will treat the PAD as a normal tag. Otherwise, the model
will apply constraints for PAD transitions.
batch_first (bool): Whether the first dimension represents the batch dimension.
"""
def __init__(
self, nb_labels, bos_tag_id, eos_tag_id, pad_tag_id=None, batch_first=True
):
super().__init__()
self.nb_labels = nb_labels
self.BOS_TAG_ID = bos_tag_id
self.EOS_TAG_ID = eos_tag_id
self.PAD_TAG_ID = pad_tag_id
self.batch_first = batch_first
self.transitions = nn.Parameter(torch.empty(self.nb_labels, self.nb_labels))
self.init_weights()

    def init_weights(self):
# initialize transitions from a random uniform distribution between -0.1 and 0.1
nn.init.uniform_(self.transitions, -0.1, 0.1)
        # enforce constraints (rows=from, columns=to) with a big negative number
        # so that exp(-10000.0) will tend to zero
        # no transitions allowed to the beginning of sentence
        self.transitions.data[:, self.BOS_TAG_ID] = -10000.0
        # no transitions allowed from the end of sentence
        self.transitions.data[self.EOS_TAG_ID, :] = -10000.0
if self.PAD_TAG_ID is not None:
# no transitions from padding
self.transitions.data[self.PAD_TAG_ID, :] = -10000.0
# no transitions to padding
self.transitions.data[:, self.PAD_TAG_ID] = -10000.0
# except if the end of sentence is reached
# or we are already in a pad position
self.transitions.data[self.PAD_TAG_ID, self.EOS_TAG_ID] = 0.0
self.transitions.data[self.PAD_TAG_ID, self.PAD_TAG_ID] = 0.0
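
        # For intuition, a sketch of the resulting constrained matrix under an
        # illustrative tagset (ids 0..2 are real tags, 3=BOS, 4=EOS, 5=PAD;
        # "x" marks a -10000.0 entry, "." a learnable weight):
        #
        #              to:  0    1    2   BOS  EOS  PAD
        #   from 0          .    .    .    x    .    x
        #   from 1          .    .    .    x    .    x
        #   from 2          .    .    .    x    .    x
        #   from BOS        .    .    .    x    .    x
        #   from EOS        x    x    x    x    x    x
        #   from PAD        x    x    x    x   0.0  0.0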

    def forward(self, emissions, tags, mask=None):
"""Compute the negative log-likelihood. See `log_likelihood` method."""
nll = -self.log_likelihood(emissions, tags, mask=mask)
return nll

    def log_likelihood(self, emissions, tags, mask=None):
        """Compute the log-likelihood of a sequence of tags given a sequence of
        emission scores.
Args:
emissions (torch.Tensor): Sequence of emissions for each label.
Shape of (batch_size, seq_len, nb_labels) if batch_first is True,
(seq_len, batch_size, nb_labels) otherwise.
tags (torch.LongTensor): Sequence of labels.
Shape of (batch_size, seq_len) if batch_first is True,
(seq_len, batch_size) otherwise.
mask (torch.FloatTensor, optional): Tensor representing valid positions.
If None, all positions are considered valid.
Shape of (batch_size, seq_len) if batch_first is True,
(seq_len, batch_size) otherwise.
Returns:
            torch.Tensor: the summed log-likelihood over all sequences in the
                batch. Shape of (), a scalar.
"""
        # fix tensor order by setting batch as the first dimension
        if not self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            if mask is not None:
                mask = mask.transpose(0, 1)
        if mask is None:
            mask = torch.ones(emissions.shape[:2], dtype=torch.float, device=emissions.device)
scores = self._compute_scores(emissions, tags, mask=mask)
partition = self._compute_log_partition(emissions, mask=mask)
return torch.sum(scores - partition)

    def decode(self, emissions, mask=None):
"""Find the most probable sequence of labels given the emissions using
the Viterbi algorithm.
Args:
emissions (torch.Tensor): Sequence of emissions for each label.
Shape (batch_size, seq_len, nb_labels) if batch_first is True,
(seq_len, batch_size, nb_labels) otherwise.
mask (torch.FloatTensor, optional): Tensor representing valid positions.
If None, all positions are considered valid.
Shape (batch_size, seq_len) if batch_first is True,
(seq_len, batch_size) otherwise.
Returns:
            torch.Tensor: the Viterbi score for each sequence in the batch.
                Shape of (batch_size,)
            list of lists: the best Viterbi sequence of labels for each sequence.
"""
        if not self.batch_first:
            emissions = emissions.transpose(0, 1)
            if mask is not None:
                mask = mask.transpose(0, 1)
        if mask is None:
            mask = torch.ones(emissions.shape[:2], dtype=torch.float, device=emissions.device)
        scores, sequences = self._viterbi_decode(emissions, mask)
        return scores, sequences

    def _compute_scores(self, emissions, tags, mask):
"""Compute the scores for a given batch of emissions with their tags.
Args:
emissions (torch.Tensor): (batch_size, seq_len, nb_labels)
            tags (torch.LongTensor): (batch_size, seq_len)
            mask (torch.FloatTensor): (batch_size, seq_len)
Returns:
torch.Tensor: Scores for each batch.
Shape of (batch_size,)
"""
batch_size, seq_length = tags.shape
        scores = torch.zeros(batch_size, device=emissions.device)
# save first and last tags to be used later
first_tags = tags[:, 0]
last_valid_idx = mask.int().sum(1) - 1
        # squeeze dim 1 explicitly so a batch of size 1 keeps its batch dimension
        last_tags = tags.gather(1, last_valid_idx.unsqueeze(1)).squeeze(1)
# add the transition from BOS to the first tags for each batch
t_scores = self.transitions[self.BOS_TAG_ID, first_tags]
# add the [unary] emission scores for the first tags for each batch
        # for all batches, for the first word, look at the corresponding emissions
# for the first tags (which is a list of ids):
# emissions[:, 0, [tag_1, tag_2, ..., tag_nblabels]]
        e_scores = emissions[:, 0].gather(1, first_tags.unsqueeze(1)).squeeze(1)
        # the score for a word is just the sum of both scores
        scores += e_scores + t_scores
        # now let's do this for each remaining word
        for i in range(1, seq_length):
            # we could iterate over batches, check if we reached a mask symbol
            # and stop the iteration, but vectorizing is faster on a GPU,
            # so instead we perform an element-wise multiplication
is_valid = mask[:, i]
previous_tags = tags[:, i - 1]
current_tags = tags[:, i]
# calculate emission and transition scores as we did before
            e_scores = emissions[:, i].gather(1, current_tags.unsqueeze(1)).squeeze(1)
t_scores = self.transitions[previous_tags, current_tags]
# apply the mask
e_scores = e_scores * is_valid
t_scores = t_scores * is_valid
scores += e_scores + t_scores
# add the transition from the end tag to the EOS tag for each batch
scores += self.transitions[last_tags, self.EOS_TAG_ID]
return scores
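
    # To make the gather step above concrete, a toy sketch (values are
    # illustrative): if emissions[:, 0] is
    #     [[0.1, 0.9],
    #      [0.8, 0.2]]
    # and first_tags is [1, 0], then
    #     emissions[:, 0].gather(1, first_tags.unsqueeze(1)).squeeze(1)
    # picks [0.9, 0.8]: one emission score per sample, for its gold tag.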

    def _compute_log_partition(self, emissions, mask):
        """Compute the partition function in log-space using the forward algorithm.

        Args:
            emissions (torch.Tensor): (batch_size, seq_len, nb_labels)
            mask (torch.FloatTensor): (batch_size, seq_len)
Returns:
torch.Tensor: the partition scores for each batch.
Shape of (batch_size,)
"""
batch_size, seq_length, nb_labels = emissions.shape
# in the first iteration, BOS will have all the scores
alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(0) + emissions[:, 0]
for i in range(1, seq_length):
# (bs, nb_labels) -> (bs, 1, nb_labels)
e_scores = emissions[:, i].unsqueeze(1)
# (nb_labels, nb_labels) -> (bs, nb_labels, nb_labels)
t_scores = self.transitions.unsqueeze(0)
# (bs, nb_labels) -> (bs, nb_labels, 1)
a_scores = alphas.unsqueeze(2)
scores = e_scores + t_scores + a_scores
new_alphas = torch.logsumexp(scores, dim=1)
# set alphas if the mask is valid, otherwise keep the current values
is_valid = mask[:, i].unsqueeze(-1)
alphas = is_valid * new_alphas + (1 - is_valid) * alphas
# add the scores for the final transition
last_transition = self.transitions[:, self.EOS_TAG_ID]
end_scores = alphas + last_transition.unsqueeze(0)
# return a *log* of sums of exps
return torch.logsumexp(end_scores, dim=1)
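
    # Shape sketch for one forward step (bs = batch size, L = nb_labels):
    #     e_scores (bs, 1, L) + t_scores (1, L, L) + a_scores (bs, L, 1)
    # broadcasts to scores of shape (bs, L, L), where scores[b, i, j] is the
    # score of moving from tag i to tag j for sample b; the logsumexp over
    # dim=1 (the "from" axis) then yields the new alphas of shape (bs, L).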

    def _viterbi_decode(self, emissions, mask):
        """Run the Viterbi algorithm to find the most probable sequence of labels
        given a sequence of emissions.

        Args:
            emissions (torch.Tensor): (batch_size, seq_len, nb_labels)
            mask (torch.FloatTensor): (batch_size, seq_len)

        Returns:
            torch.Tensor: the Viterbi score for each sequence in the batch.
                Shape of (batch_size,)
            list of lists of ints: the best Viterbi sequence of labels for each sequence
        """
batch_size, seq_length, nb_labels = emissions.shape
        # in the first iteration, BOS will have all the scores, and then we take the max
alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(0) + emissions[:, 0]
backpointers = []
for i in range(1, seq_length):
# (bs, nb_labels) -> (bs, 1, nb_labels)
e_scores = emissions[:, i].unsqueeze(1)
# (nb_labels, nb_labels) -> (bs, nb_labels, nb_labels)
t_scores = self.transitions.unsqueeze(0)
# (bs, nb_labels) -> (bs, nb_labels, 1)
a_scores = alphas.unsqueeze(2)
# combine current scores with previous alphas
scores = e_scores + t_scores + a_scores
            # so far this is exactly like the forward algorithm,
# but now, instead of calculating the logsumexp,
# we will find the highest score and the tag associated with it
max_scores, max_score_tags = torch.max(scores, dim=1)
# set alphas if the mask is valid, otherwise keep the current values
is_valid = mask[:, i].unsqueeze(-1)
alphas = is_valid * max_scores + (1 - is_valid) * alphas
            # add the max_score_tags to our list of backpointers
            # max_score_tags has shape (batch_size, nb_labels), so we transpose
            # it to be compatible with our previous loopy version of viterbi
backpointers.append(max_score_tags.t())
# add the scores for the final transition
last_transition = self.transitions[:, self.EOS_TAG_ID]
end_scores = alphas + last_transition.unsqueeze(0)
# get the final most probable score and the final most probable tag
max_final_scores, max_final_tags = torch.max(end_scores, dim=1)
# find the best sequence of labels for each sample in the batch
best_sequences = []
emission_lengths = mask.int().sum(dim=1)
for i in range(batch_size):
# recover the original sentence length for the i-th sample in the batch
sample_length = emission_lengths[i].item()
# recover the max tag for the last timestep
sample_final_tag = max_final_tags[i].item()
# limit the backpointers until the last but one
# since the last corresponds to the sample_final_tag
sample_backpointers = backpointers[: sample_length - 1]
# follow the backpointers to build the sequence of labels
sample_path = self._find_best_path(i, sample_final_tag, sample_backpointers)
# add this path to the list of best sequences
best_sequences.append(sample_path)
return max_final_scores, best_sequences
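
    # Backpointer sketch (illustrative): for a sample of length 3,
    # backpointers holds 2 tensors, each of shape (nb_labels, batch_size).
    # If the best final tag is t2, we read backpointers[1][t2][i] to get the
    # tag t1 at step 1, then backpointers[0][t1][i] for step 0, prepending as
    # we go, which is exactly what _find_best_path below does.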

    def _find_best_path(self, sample_id, best_tag, backpointers):
"""Auxiliary function to find the best path sequence for a specific sample.
Args:
sample_id (int): sample index in the range [0, batch_size)
best_tag (int): tag which maximizes the final score
            backpointers (list of tensors): list of (nb_labels, batch_size)
                tensors with length (seq_len_i - 1), where seq_len_i
                represents the length of the ith sample in the batch
Returns:
            list of ints: a list of tag indexes representing the best path
"""
# add the final best_tag to our best path
best_path = [best_tag]
        # traverse the backpointers backwards
for backpointers_t in reversed(backpointers):
# recover the best_tag at this timestep
best_tag = backpointers_t[best_tag][sample_id].item()
# append to the beginning of the list so we don't need to reverse it later
best_path.insert(0, best_tag)
return best_path
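

# A minimal usage sketch, kept out of the class on purpose. Everything below
# (tagset size, tag ids, tensor sizes) is illustrative, not part of the
# original module: we build a toy tagset with BOS/EOS/PAD symbols, score a
# batch of random emissions against random gold tags, and decode it.
if __name__ == "__main__":
    torch.manual_seed(42)

    # toy tagset: ids 0..2 are real labels, 3=BOS, 4=EOS, 5=PAD (illustrative)
    nb_labels, bos, eos, pad = 6, 3, 4, 5
    crf = CRF(nb_labels, bos, eos, pad_tag_id=pad, batch_first=True)

    batch_size, seq_len = 2, 5
    emissions = torch.randn(batch_size, seq_len, nb_labels)
    tags = torch.randint(0, 3, (batch_size, seq_len))

    # pretend the second sample has length 3: pad and mask its last two steps
    mask = torch.ones(batch_size, seq_len)
    tags[1, 3:] = pad
    mask[1, 3:] = 0.0

    # training objective: forward() returns the negative log-likelihood
    nll = crf(emissions, tags, mask=mask)
    nll.backward()  # gradients flow into crf.transitions
    print("nll:", nll.item())

    # inference: Viterbi scores and the best tag sequence for each sample
    scores, sequences = crf.decode(emissions, mask=mask)
    print("viterbi scores:", scores)
    print("best sequences:", sequences)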