-
Notifications
You must be signed in to change notification settings - Fork 197
/
LongNet.py
68 lines (58 loc) · 1.99 KB
/
LongNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
from torchscale.architecture.decoder import Decoder, DecoderLayer
from torchscale.architecture.encoder import Encoder, EncoderLayer
from torchscale.component.dilated_attention import DilatedAttention
from fairscale.nn import checkpoint_wrapper, wrap
class LongNetDecoderLayer(DecoderLayer):
    """Decoder layer that builds its self-attention as ``DilatedAttention``."""

    def build_self_attention(self, embed_dim, args):
        """Create the self-attention module for this decoder layer.

        Args:
            embed_dim: Embedding dimension forwarded to ``DilatedAttention``.
            args: Configuration object; this method reads
                ``decoder_attention_heads``, ``attention_dropout`` and
                ``subln`` from it and also forwards it as the first argument.

        Returns:
            A ``DilatedAttention`` instance constructed with
            ``self_attention=True`` and ``encoder_decoder_attention=False``.
        """
        # NOTE(review): presumably overrides a hook of the same name on
        # DecoderLayer -- confirm against the base class.
        head_count = args.decoder_attention_heads
        attention_options = {
            "dropout": args.attention_dropout,
            "self_attention": True,
            "encoder_decoder_attention": False,
            "subln": args.subln,
        }
        return DilatedAttention(args, embed_dim, head_count, **attention_options)
class LongNetDecoder(Decoder):
    """Decoder whose layers are ``LongNetDecoderLayer`` instances."""

    def build_decoder_layer(
        self, args, depth, is_moe_layer=False, is_encoder_decoder=False
    ):
        """Build one decoder layer, optionally wrapped.

        Args:
            args: Configuration object; ``checkpoint_activations`` and
                ``fsdp`` are read here and the object is forwarded to the
                layer constructor.
            depth: Forwarded to ``LongNetDecoderLayer`` as its second
                positional argument.
            is_moe_layer: Forwarded to the layer constructor by keyword.
            is_encoder_decoder: Forwarded to the layer constructor by keyword.

        Returns:
            The new layer, passed through ``checkpoint_wrapper`` when
            ``args.checkpoint_activations`` is truthy and then through
            ``wrap`` when ``args.fsdp`` is truthy.
        """
        decoder_layer = LongNetDecoderLayer(
            args,
            depth,
            is_moe_layer=is_moe_layer,
            is_encoder_decoder=is_encoder_decoder,
        )
        # Checkpoint wrapping is applied first so that, when both flags are
        # set, ``wrap`` receives the already checkpoint-wrapped layer.
        if args.checkpoint_activations:
            decoder_layer = checkpoint_wrapper(decoder_layer)
        return wrap(decoder_layer) if args.fsdp else decoder_layer
class LongNetEncoderLayer(EncoderLayer):
    """Encoder layer that builds its self-attention as ``DilatedAttention``."""

    def build_self_attention(self, embed_dim, args):
        """Create the self-attention module for this encoder layer.

        Args:
            embed_dim: Embedding dimension forwarded to ``DilatedAttention``.
            args: Configuration object; this method reads
                ``encoder_attention_heads``, ``attention_dropout`` and
                ``subln`` from it and also forwards it as the first argument.

        Returns:
            A ``DilatedAttention`` instance constructed with
            ``self_attention=True`` and ``encoder_decoder_attention=False``.
        """
        # NOTE(review): presumably overrides a hook of the same name on
        # EncoderLayer -- confirm against the base class.
        head_count = args.encoder_attention_heads
        attention_options = {
            "dropout": args.attention_dropout,
            "self_attention": True,
            "encoder_decoder_attention": False,
            "subln": args.subln,
        }
        return DilatedAttention(args, embed_dim, head_count, **attention_options)
class LongNetEncoder(Encoder):
    """Encoder whose layers are ``LongNetEncoderLayer`` instances."""

    def build_encoder_layer(
        self, args, depth, is_moe_layer=False, is_encoder_decoder=False
    ):
        """Build one encoder layer, optionally wrapped.

        Args:
            args: Configuration object; ``checkpoint_activations`` and
                ``fsdp`` are read here and the object is forwarded to the
                layer constructor.
            depth: Forwarded to ``LongNetEncoderLayer`` as its second
                positional argument.
            is_moe_layer: Forwarded to the layer constructor by keyword.
            is_encoder_decoder: Forwarded to the layer constructor by keyword.

        Returns:
            The new layer, passed through ``checkpoint_wrapper`` when
            ``args.checkpoint_activations`` is truthy and then through
            ``wrap`` when ``args.fsdp`` is truthy.
        """
        encoder_layer = LongNetEncoderLayer(
            args,
            depth,
            is_moe_layer=is_moe_layer,
            is_encoder_decoder=is_encoder_decoder,
        )
        # Checkpoint wrapping is applied first so that, when both flags are
        # set, ``wrap`` receives the already checkpoint-wrapped layer.
        if args.checkpoint_activations:
            encoder_layer = checkpoint_wrapper(encoder_layer)
        return wrap(encoder_layer) if args.fsdp else encoder_layer