# coding=utf-8
# Copyright 2024 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hourglass - a hierarchical Transformer language model."""

import trax.layers as tl

from trax.layers.research.rel_attention import get_rel_att_inputs
from trax.layers.research.rel_attention import RelativeAttentionWrapper
from trax.layers.research.resampling import AttentionResampling
from trax.layers.research.resampling import AveragePooling
from trax.layers.research.resampling import FeedForwardBlock
from trax.layers.research.resampling import LinearUpsampling
from trax.models.research.configurable_transformer import ApplyAttentionLayer


def _RelativeDecoderBlock(attention_type, d_model, d_ff, n_heads, dropout,
                          dropout_shared_axes, mode, ff_activation,
                          context_bias_layer, location_bias_layer,
                          total_pooling):
  """Returns a list of layers.

  The layers implement a Transformer decoder block with relative attention
  parametrization. The input to the block is an activation tensor; attention
  is causal, so no separate padding mask is passed.

  Args:
    attention_type: attention type.
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
      along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
      way to save memory and apply consistent masks to activation vectors at
      different sequence positions.
    mode: If `'train'`, each block will include dropout; else, it will pass
      all values through unaltered.
    ff_activation: Type of activation function at the end of each block; must
      be an activation-type subclass of `Layer`.
    context_bias_layer: context bias layer.
    location_bias_layer: location bias layer.
    total_pooling: The combined pool size of previously used funnel blocks.

  Returns:
    A list of layers that maps activations to activations.
  """
  if attention_type == RelativeAttentionWrapper:
    attention = RelativeAttentionWrapper(
        d_model,
        n_heads,
        dropout,
        mode=mode,
        context_bias_layer=context_bias_layer,
        location_bias_layer=location_bias_layer,
        total_pooling=total_pooling)
  else:
    attention = ApplyAttentionLayer(
        attention_type,
        d_model,
        n_heads,
        d_model // n_heads,
        d_model // n_heads,
        causal=True,
        masked=False,
        attention_dropout=dropout,
        output_dropout=dropout,
        attention_chunk_size=0,  # Disables tl.Chunk in ApplyAttentionLayer.
        mode=mode,
    )

  feed_forward = FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes,
                                  mode, ff_activation)

  def _Dropout():
    return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

  return [
      tl.Residual(  # vecs
          tl.LayerNorm(),
          attention,
          _Dropout(),
      ),  # vecs
      tl.Residual(
          tl.LayerNorm(),
          feed_forward,
          _Dropout(),
      ),  # vecs
  ]
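
# A minimal sketch of building one such block on its own (hyperparameter
# values here are illustrative, not canonical):
#
#   ctx_bias, loc_bias = get_rel_att_inputs(512, 8)
#   block = _RelativeDecoderBlock(
#       RelativeAttentionWrapper, d_model=512, d_ff=2048, n_heads=8,
#       dropout=0.1, dropout_shared_axes=None, mode='train',
#       ff_activation=tl.FastGelu, context_bias_layer=ctx_bias,
#       location_bias_layer=loc_bias, total_pooling=1)
#
# `block` is a list of two pre-norm residual sublayers (attention, then
# feed-forward) and composes directly inside tl.Serial.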


def _parse_hierarchy(hierarchy_str):  # pylint: disable = invalid-name
  """Parse hierarchy for Hourglass definition."""
  levels = hierarchy_str.split(' ')

  if levels != levels[::-1]:
    raise ValueError('Hierarchy is not a palindrome')

  layer_level_pairs = [(x.split('@')) for x in levels[:1 + (len(levels) // 2)]]
  hierarchy_n_layers = [int(x[0]) for x in layer_level_pairs]
  total_sf_per_level = [int(x[1]) for x in layer_level_pairs]

  hierarchy_shorten_factors = []
  for current_sf, prev_sf in zip(total_sf_per_level,
                                 [1] + total_sf_per_level[:-1]):
    if current_sf % prev_sf != 0:
      raise ValueError(
          f'Hierarchy not divisible by previous level: {current_sf}, {prev_sf}')
    hierarchy_shorten_factors.append(current_sf // prev_sf)

  return hierarchy_n_layers, hierarchy_shorten_factors
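
# A worked example (hierarchy string chosen for illustration):
# _parse_hierarchy('1@2 2@6 1@2') returns ([1, 2], [2, 3]) -- one block at
# total shortening 2 (shorten factor 2), then two blocks at total shortening 6
# (a further factor of 3). Only the first half of the palindrome plus the
# middle level is parsed; the upsampling half mirrors it.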


def HourglassLM(vocab_size,
                d_model=512,
                d_ff=2048,
                vanilla_layers=(1, 1),
                hierarchy='6@3',
                n_heads=8,
                dropout=0.1,
                dropout_shared_axes=None,
                mode='train',
                ff_activation=tl.FastGelu,
                vanilla_attn_type=RelativeAttentionWrapper,
                middle_attn_type=RelativeAttentionWrapper,
                downsampling_fn=AttentionResampling,
                upsampling_fn=AttentionResampling,
                attention_downsampling_fn=AveragePooling,
                attention_upsampling_fn=LinearUpsampling):
  """Returns a hierarchical Transformer language model.

  This model performs autoregressive language modeling:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The
      tensor elements are integers in `range(vocab_size)`, and `0` values
      mark padding positions.

    - output: rank 3 tensor representing a batch of log-probability
      distributions for each sequence position over possible token IDs; shape
      is (batch_size, sequence_length, `vocab_size`).

  This model uses only the decoder part of the overall Transformer.

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
      should be an integer in `range(vocab_size)`. These integers typically
      represent token IDs from a vocabulary-based tokenizer.
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
      block.
    vanilla_layers: (pre_layers, post_layers) tuple - number of full
      token-level Transformer decoder layers before and after shortening.
    hierarchy: string - shortening hierarchy, as described in the paper.
      Hierarchy levels must form a palindrome, e.g. '1@2 2@6 1@2'.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
      along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
      way to save memory and apply consistent masks to activation vectors at
      different sequence positions.
    mode: str: 'train' or 'eval'.
    ff_activation: Type of activation function at the end of each encoder
      block; must be an activation-type subclass of `Layer`.
    vanilla_attn_type: class: attention class such as SelfAttention to use in
      the layers before and after shortening (vanilla layers).
    middle_attn_type: class: attention class to use in the middle layers
      (those operating on the shortened sequence).
    downsampling_fn: function that takes full token-level vectors of length
      `l` and transforms them into `l` / `k` vectors, where `k` denotes the
      `shorten_factor` parameter.
    upsampling_fn: function that takes shortened representations of a
      sequence, consisting of `l` / `k` vectors, and transforms them into
      full token-level representations of length `l`.
    attention_downsampling_fn: Downsampling function that transforms
      token-level vectors into query vectors with reduced length. Necessary
      only when AttentionResampling is used as `downsampling_fn`.
    attention_upsampling_fn: Upsampling function for AttentionResampling.
      Valid only when AttentionResampling is used as `upsampling_fn`.

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
  assert mode != 'predict'  # For now, 'predict' mode is unsupported.

  hierarchy_n_layers, hierarchy_shorten_factors = _parse_hierarchy(hierarchy)

  token_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)
  ]

  context_bias_layer, location_bias_layer = get_rel_att_inputs(d_model, n_heads)

  n_pre_decoder_blocks, n_post_decoder_blocks = vanilla_layers

  def create_decoder_blocks(n_layers, total_pooling,  # pylint: disable = invalid-name
                            attention_type):
    decoder_blocks = [
        # pylint: disable=g-complex-comprehension
        _RelativeDecoderBlock(attention_type, d_model, d_ff, n_heads, dropout,
                              dropout_shared_axes, mode, ff_activation,
                              context_bias_layer, location_bias_layer,
                              total_pooling) for _ in range(n_layers)
    ]
    return decoder_blocks + [tl.LayerNorm()]

  def create_hourglass_valley(rest_shorten_factors, rest_n_funnel_blocks,  # pylint: disable = invalid-name
                              current_total_pooling):
    assert rest_shorten_factors
    assert len(rest_shorten_factors) == len(rest_n_funnel_blocks)

    current_sf = rest_shorten_factors[0]
    current_n_layers = rest_n_funnel_blocks[0]

    shortening_layer = downsampling_fn(
        current_sf,
        d_model,
        is_upsampling=False,
        d_ff=d_ff,
        n_heads=n_heads,
        dropout=dropout,
        dropout_shared_axes=dropout_shared_axes,
        mode=mode,
        ff_activation=ff_activation,
        context_bias_layer=context_bias_layer,
        location_bias_layer=location_bias_layer,
        total_pooling=current_total_pooling,
        resampling_fn=attention_downsampling_fn)

    upsampling_layer = upsampling_fn(
        current_sf,
        d_model=d_model,
        is_upsampling=True,
        d_ff=d_ff,
        n_heads=n_heads,
        dropout=dropout,
        dropout_shared_axes=dropout_shared_axes,
        mode=mode,
        ff_activation=ff_activation,
        context_bias_layer=context_bias_layer,
        location_bias_layer=location_bias_layer,
        total_pooling=current_total_pooling,
        resampling_fn=attention_upsampling_fn)

    if len(rest_shorten_factors) > 1:  # we need to go deeper again
      pre_stage_blocks = create_decoder_blocks(
          current_n_layers, current_total_pooling * current_sf,
          middle_attn_type)
      post_stage_blocks = create_decoder_blocks(
          current_n_layers, current_total_pooling * current_sf,
          middle_attn_type)

      return [
          tl.Dup(),
          tl.ShiftRight(current_sf - 1, mode=mode), shortening_layer,
          pre_stage_blocks, *create_hourglass_valley(
              rest_shorten_factors[1:], rest_n_funnel_blocks[1:],
              current_total_pooling * current_sf), post_stage_blocks,
          upsampling_layer,
          tl.LayerNorm(),
          tl.Add()
      ]
    else:
      blocks = create_decoder_blocks(current_n_layers,
                                     current_total_pooling * current_sf,
                                     middle_attn_type)

      return [
          tl.Dup(),
          tl.ShiftRight(current_sf - 1), shortening_layer, blocks,
          upsampling_layer,
          tl.LayerNorm(),
          tl.Add()
      ]

  pre_decoder_blocks = create_decoder_blocks(n_pre_decoder_blocks, 1,
                                             vanilla_attn_type)
  post_decoder_blocks = create_decoder_blocks(n_post_decoder_blocks, 1,
                                              vanilla_attn_type)

  valley = create_hourglass_valley(hierarchy_shorten_factors,
                                   hierarchy_n_layers, 1)

  # Assemble and return the model.
  return tl.Serial(  # tokens (or chunked tuple of tokens)
      tl.ShiftRight(mode=mode),  # toks
      token_encoder,  # vecs
      pre_decoder_blocks,  # vecs
      valley,  # shortened vecs
      post_decoder_blocks,  # vecs
      tl.Dense(vocab_size),  # vecs
  )
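
# Minimal usage sketch (values are illustrative). The sequence length should
# be divisible by the product of the shorten factors in `hierarchy`:
#
#   import numpy as np
#   import trax
#
#   model = HourglassLM(vocab_size=256, hierarchy='6@3',
#                       vanilla_layers=(1, 1), mode='eval')
#   model.init(trax.shapes.ShapeDtype((1, 96), dtype=np.int32))
#   out = model(np.zeros((1, 96), dtype=np.int32))
#   # `out` holds activations over the vocab, shape (1, 96, 256).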