"""
This is a tutorial on how to define a simple XLNet which has a single
attention head from scratch.
src https://github.com/kimiyoung/transformer-xl/tree/master/pytorch
"""
import torch
import torch.nn as nn
from utils.embedding import RelativePositionalEmbedding
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda:0")
class MultiHeadAttention(nn.Module):
def __init__(self, d_input, d_inner, n_heads=4, dropout=0.1, dropouta=.0):
super(MultiHeadAttention, self).__init__()
self.d_input = d_input
self.d_inner = d_inner
self.n_heads = n_heads
# this layer applies the linear transformation required
# for the keys and values for all heads at once for efficiency
self.linear_kv = nn.Linear(
d_input,
(d_inner * n_heads * 2), # 2 is for keys and values
bias=False, # we don't apply bias, making this a simple matrix multiplication
)
# for queries (will not be concatenated with memorized states so separate)
self.linear_q = nn.Linear(
d_input, d_inner * n_heads,
bias=False
)
# for positional embeddings
self.linear_p = nn.Linear(
d_input, d_inner * n_heads,
bias=False
)
self.scale = 1 / (d_inner ** 0.5) # for scaled dot product attention
self.dropa = nn.Dropout(dropouta)
# we will use this to project back to the input dimension
self.lout = nn.Linear(self.d_inner * self.n_heads, self.d_input, bias=False)
self.norm = nn.LayerNorm(self.d_input)
self.dropo = nn.Dropout(dropout)
@staticmethod
def _rel_shift(x):
zero_pad = torch.zeros((x.size(0), 1, x.size(2), x.size(3)),
device=x.device, dtype=x.dtype)
return (torch.cat([zero_pad, x], dim=1)
.view(x.size(1) + 1, x.size(0), *x.size()[2:])[1:]
.view_as(x))
def forward(self, input_, # (cur_seq, b, d_in)
pos_embs, # (cur_seq + prev_seq, d_in)
memory, # (prev_seq, b, d_in)
u, # (H, d)
v, # (H, d)
mask=None,
):
"""
pos_embs: we pass the positional embeddings in separately
because we need to handle relative positions
input shape: (seq, bs, self.d_input)
        pos_embs shape: (seq + prev_seq, self.d_input)
output shape: (seq, bs, self.d_input)
"""
cur_seq = input_.shape[0] # sequence length of current segment
prev_seq = memory.shape[0] # sequence length of previous segment
H, d = self.n_heads, self.d_inner
input_with_memory = torch.cat([memory, input_], dim=0) # concatenate recurrent memory
# across sequence dimension
# we will use the following symbols to represent the shape of the tensors
# cs: current sequence length, b: batch, H: number of heads
# d: inner dimension, ps: previous sequence length
# The key and value are now conditioned on the preceding context
k_tfmd, v_tfmd = \
torch.chunk(self.linear_kv(input_with_memory), 2, dim=-1) # (cs + ps, b, H * d)
q_tfmd = self.linear_q(input_) # (cs, b, H * d)
# apply scaled dot product attention
# look at the following dimensions carefully, since this is the key operation
# in the Transformer/Transformer XL architecture
_, bs, _ = q_tfmd.shape
assert bs == k_tfmd.shape[1]
# content-based attention term ((a) + (c) in the paper)
# this is the standard attention term in the original Transformer, except without positional embeddings
# which are handled separately in the Transformer XL (see below)
# here, i corresponds to the number of queries = number of current inputs/targets (seq-wise)
# j corresponds to the number of key/values = number of vectors that we can use to compute the
# vector for each query
content_attn = torch.einsum("ibhd,jbhd->ijbh", (
(q_tfmd.view(cur_seq, bs, H, d) + # (a)
u), # (c): u represents the global (independent of the query)
# bias towards certain key/values = words
# Note: maybe this could be a per-attention head parameter?
k_tfmd.view(cur_seq + prev_seq, bs, H, d) # There is no positional information to be found here
)) # (cs, cs + ps, b, H)
# position-based attention term ((b) + (d) in the paper)
# this attention is solely based on the position of the key/values
# (i.e. it does not take the content of the key/values into account)
        p_tfmd = self.linear_p(pos_embs)  # (cs + ps, H * d)
position_attn = torch.einsum("ibhd,jhd->ijbh", (
(q_tfmd.view(cur_seq, bs, H, d) + # (b)
v), # (d): v represents the global (independent of the query)
# bias towards certain positions
            p_tfmd.view(cur_seq + prev_seq, H, d)  # Notice there is no content information
                                                   # regarding the keys and values here!
)) # (cs, cs + ps, b, H)
        # compute the positional attention efficiently with the "relative shift" trick,
        # which turns absolute-position scores into relative-distance scores
        position_attn = self._rel_shift(position_attn)
# the attention is the sum of content-based and position-based attention
attn = content_attn + position_attn
if mask is not None and mask.any().item():
attn = attn.masked_fill(
mask[..., None], -float('inf'))
attn = torch.softmax(attn * self.scale, # rescale to prevent values from exploding
dim=1) # normalize across the value sequence dimension
attn = self.dropa(attn)
attn_weighted_values = (torch.einsum("ijbh,jbhd->ibhd",
(attn, # (cs, cs + ps, b, H)
v_tfmd.view(cur_seq + prev_seq, bs, H, d), # (cs + ps, b, H, d)
)) # (cs, b, H, d)
.contiguous() # we need to change the memory layout to make `view` work
.view(cur_seq, bs, H * d)) # (cs, b, H * d)
# Project back to input dimension and add residual connection
output = input_ + self.dropo(self.lout(attn_weighted_values))
output = self.norm(output)
return output
mha = MultiHeadAttention(32, 17, n_heads=4)
inpt = torch.rand(7, 3, 32)
pos = torch.rand(13, 32)
mem = torch.rand(6, 3, 32)
u, v = torch.rand(4, 17), torch.rand(4, 17)
x1 = mha(inpt, pos, mem, u, v)
print(x1.shape)
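# A small illustration (added for this walkthrough, not in the source repo) of the
# relative-shift trick in _rel_shift: zero-padding along the key axis and re-viewing the
# tensor shifts each query row one position further left than the row below it (the last
# row stays put), which converts absolute-position scores into relative-distance scores.
# The entries that wrap around at the end of a row are junk, but they are exactly the
# future positions that the causal attention mask removes.
_x = torch.arange(12, dtype=torch.float).view(3, 4, 1, 1)  # (queries=3, keys=4, batch=1, heads=1)
print(MultiHeadAttention._rel_shift(_x).squeeze())
# tensor([[ 2.,  3.,  0.,  4.],
#         [ 5.,  6.,  7.,  0.],
#         [ 8.,  9., 10., 11.]])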
'''Building the decoder
To construct the decoder block, all we need in addition to the MultiHeadAttention layer
is the Positionwise Feed Forward layer.
'''
class PositionwiseFF(nn.Module):
def __init__(self, d_input, d_inner, dropout):
super(PositionwiseFF, self).__init__()
self.d_input = d_input
self.d_inner = d_inner
self.dropout = dropout
self.ff = nn.Sequential(
nn.Linear(d_input, d_inner), nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(d_inner, d_input),
nn.Dropout(dropout),
)
self.layer_norm = nn.LayerNorm(d_input)
def forward(self, input_, #: torch.FloatTensor, # (cur_seq, bs, d_input)
): # -> torch.FloatTensor: # (cur_seq, bs, d_input)
ff_out = self.ff(input_)
output = self.layer_norm(input_ + ff_out)
return output
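# Quick shape check (an illustrative addition): the position-wise feed-forward block maps
# (cur_seq, bs, d_input) back to the same shape, with a residual connection and LayerNorm.
pff = PositionwiseFF(32, 71, dropout=0.1)
print(pff(torch.rand(7, 3, 32)).shape)  # torch.Size([7, 3, 32])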
class DecoderBlock(nn.Module):
def __init__(self, n_heads, d_input,
d_head_inner, d_ff_inner,
dropout, dropouta=0.):
super(DecoderBlock, self).__init__()
self.mha = MultiHeadAttention(d_input, d_head_inner, n_heads=n_heads,
dropout=dropout, dropouta=dropouta)
self.ff = PositionwiseFF(d_input, d_ff_inner, dropout)
def forward(self, input_, #: torch.FloatTensor, # (cur_seq, bs, d_input)
pos_embs, #: torch.FloatTensor, # (cur_seq + prev_seq, d_input),
u, #: torch.FloatTensor, # (H, d_input),
v, #: torch.FloatTensor, # (H, d_input),
mask=None,
mems=None,
):
return self.ff(self.mha(input_, pos_embs, mems, u, v, mask=mask))
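# Illustrative check (an addition for this walkthrough): a full decoder block preserves the
# input shape, so blocks can be stacked; it reuses the tensors from the attention demo above.
block = DecoderBlock(n_heads=4, d_input=32, d_head_inner=17, d_ff_inner=71, dropout=0.1)
x2 = block(inpt, pos, u, v, mems=mem)
print(x2.shape)  # torch.Size([7, 3, 32])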
'''The full Transformer XL
'''
class StandardWordEmbedding(nn.Module):
def __init__(self, num_embeddings, embedding_dim, div_val=1, sample_softmax=False):
super(StandardWordEmbedding, self).__init__()
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
self.scale = embedding_dim ** 0.5
def forward(self, input_): #: torch.LongTensor):
return self.embedding(input_) * self.scale
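# Quick shape check (an illustrative addition): token ids map to embedding_dim-dimensional
# vectors, scaled by sqrt(embedding_dim) as in the original Transformer.
emb = StandardWordEmbedding(1000, 32)
print(emb(torch.randint(1000, (5, 9))).shape)  # torch.Size([5, 9, 32])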
class TransformerXL(nn.Module):
def __init__(self, num_embeddings, n_layers, n_heads,
d_model, d_head_inner, d_ff_inner,
dropout=0.1, dropouta=0.,
seq_len=0, mem_len=0):
super(TransformerXL, self).__init__()
self.n_layers, self.n_heads, self.d_model, self.d_head_inner, self.d_ff_inner = \
n_layers, n_heads, d_model, d_head_inner, d_ff_inner
# Embedding layers
self.word_embs = StandardWordEmbedding(num_embeddings, d_model)
self.pos_embs = RelativePositionalEmbedding(d_model)
# Core transformer
self.drop = nn.Dropout(dropout)
self.layers = nn.ModuleList([DecoderBlock(n_heads, d_model, d_head_inner=d_head_inner,
d_ff_inner=d_ff_inner,
dropout=dropout, dropouta=dropouta)
for _ in range(n_layers)])
# tie weights
self.output_projection = nn.Linear(d_model, num_embeddings)
self.output_projection.weight = self.word_embs.embedding.weight
self.loss_fn = nn.CrossEntropyLoss()
self.seq_len, self.mem_len = seq_len, mem_len
        # u and v are the global content and position biases, shape (n_heads, d_head_inner);
        # they are shared across all layers here (the reference implementation also supports
        # untied, per-layer biases via its untie_r option)
        # initialize them with small random values (a bare torch.Tensor would be uninitialized)
        self.u, self.v = (nn.Parameter(torch.randn(self.n_heads, self.d_head_inner) * 0.02),
                          nn.Parameter(torch.randn(self.n_heads, self.d_head_inner) * 0.02))
def init_memory(self, device=torch.device("cpu")): # -> torch.FloatTensor:
return [torch.empty(0, dtype=torch.float).to(device) for _ in range(self.n_layers + 1)]
def update_memory(self,
previous_memory, #: List[torch.FloatTensor],
hidden_states, #: List[torch.FloatTensor],
):
assert len(hidden_states) == len(previous_memory)
mem_len, seq_len = previous_memory[0].size(0), hidden_states[0].size(0)
# For the updated memory, we use the most recent `self.mem_len`
# states, including the previous memory
# In other words, if `seq_len` < `self.mem_len` some of the previous memory
# will carry over to the next memory
with torch.no_grad():
new_memory = []
end_idx = mem_len + seq_len
beg_idx = max(0, end_idx - self.mem_len)
for m, h in zip(previous_memory, hidden_states):
cat = torch.cat([m, h], dim=0) # (mem_len + seq_len, bs, d)
new_memory.append(cat[beg_idx:end_idx].detach()) # (self.mem_len, bs, d)
return new_memory
def reset_length(self, seq_len, ext_len, mem_len):
self.seq_len = seq_len
self.mem_len = mem_len
def forward(self, idxs, #: torch.LongTensor, # (cs, bs)
target, #: torch.LongTensor, # (cs, bs)
memory=None, #: Optional[List[torch.FloatTensor]] = None,
): # -> Dict[str, torch.Tensor]:
if memory is None:
memory = self.init_memory(idxs.device)
assert len(memory) == len(self.layers) + 1
cur_seq, bs = idxs.size()
prev_seq = memory[0].size(0)
# Construct attention mask
dec_attn_mask = torch.triu(
torch.ones((cur_seq, cur_seq + prev_seq)),
diagonal=1 + prev_seq,
).bool()[..., None].to(idxs.device)
word_embs = self.drop(self.word_embs(idxs))
pos_idxs = torch.arange(cur_seq + prev_seq - 1, -1, -1.0, dtype=torch.float).to(word_embs.device)
pos_embs = self.drop(self.pos_embs(pos_idxs))
# Main part of forward pass
hidden_states = [word_embs]
layer_out = word_embs
for mem, layer in zip(memory, self.layers):
layer_out = layer(layer_out, pos_embs, self.u, self.v,
mask=dec_attn_mask, mems=mem)
hidden_states.append(layer_out)
logits = self.output_projection(self.drop(layer_out))
loss = self.loss_fn(logits.view(-1, logits.size(-1)), target.view(-1))
# Update memory
        # ensure the memory is treated as a constant
        # and that we do not backpropagate through it
new_memory = self.update_memory(memory, hidden_states)
return {"loss": loss, "logits": logits, "memory": new_memory}
transformer = TransformerXL(1000, 4, 3, 32, 17, 71, mem_len=5).to(device)
idxs = torch.randint(1000, (5, 9)).to(device)
tgts = torch.randint(1000, (5, 9)).to(device)
out = transformer(idxs, tgts)
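# Illustrative continuation (an addition for this walkthrough, not from the source repo):
# the returned memory can be fed back in for the next segment, which is the segment-level
# recurrence that lets the model attend beyond the current 5-token segment.
print(out["loss"].item(), out["logits"].shape)
next_idxs = torch.randint(1000, (5, 9)).to(device)
next_tgts = torch.randint(1000, (5, 9)).to(device)
out_next = transformer(next_idxs, next_tgts, memory=out["memory"])
print(out_next["loss"].item(), [m.shape for m in out_next["memory"]])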