start using a new linear attention from stanford
lucidrains committed Jan 18, 2024
1 parent 434f36e commit 9f49074
Showing 4 changed files with 26 additions and 45 deletions.
11 changes: 5 additions & 6 deletions README.md
@@ -155,11 +155,10 @@ with trainer.trackers(project_name = 'magvit2', run_name = 'baseline'):
```

```bibtex
-@inproceedings{ElNouby2021XCiTCI,
-title = {XCiT: Cross-Covariance Image Transformers},
-author = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
-booktitle = {Neural Information Processing Systems},
-year = {2021},
-url = {https://api.semanticscholar.org/CorpusID:235458262}
+@inproceedings{Arora2023ZoologyMA,
+title = {Zoology: Measuring and Improving Recall in Efficient Language Models},
+author = {Simran Arora and Sabri Eyuboglu and Aman Timalsina and Isys Johnson and Michael Poli and James Zou and Atri Rudra and Christopher R{\'e}},
+year = {2023},
+url = {https://api.semanticscholar.org/CorpusID:266149332}
}
```
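
The newly cited Zoology work is the Stanford research referenced in the commit title: it studies recall in efficient language models and motivates a linear attention whose feature map comes from the second-order Taylor expansion of exp(q·k), so attention can be computed in time linear in sequence length. Below is a rough, non-causal sketch of that idea in plain PyTorch; it is illustrative only, not the implementation inside the `taylor-series-linear-attention` package used in this commit, which may differ in normalization and numerics.

```python
import torch

def taylor_feature_map(x):
    # phi(x) = [1, x, (x outer x) / sqrt(2)], chosen so that
    # phi(q) . phi(k) = 1 + q.k + (q.k)^2 / 2, the 2nd-order Taylor series of exp(q.k)
    ones = torch.ones_like(x[..., :1])
    outer = torch.einsum('... i, ... j -> ... i j', x, x) / (2 ** 0.5)
    return torch.cat((ones, x, outer.flatten(-2)), dim = -1)

def taylor_linear_attention(q, k, v):
    # non-causal linear attention: associativity lets (phi(K)^T V) be computed once,
    # giving O(n) cost in sequence length n instead of O(n^2)
    q, k = taylor_feature_map(q), taylor_feature_map(k)
    kv = torch.einsum('... n e, ... n d -> ... e d', k, v)
    numer = torch.einsum('... n e, ... e d -> ... n d', q, kv)
    denom = torch.einsum('... n e, ... e -> ... n', q, k.sum(dim = -2))
    return numer / denom.unsqueeze(-1)

q, k, v = (torch.randn(1, 8, 64, 8) for _ in range(3))  # (batch, heads, seq, dim_head)
out = taylor_linear_attention(q, k, v)                   # (1, 8, 64, 8)
```

Note that the feature dimension grows quadratically in the head dimension, which is presumably why this commit uses a small `dim_head` of 8 for the linear attention layers.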
57 changes: 19 additions & 38 deletions magvit2_pytorch/magvit2_pytorch.py
@@ -28,6 +28,8 @@

from gateloop_transformer import SimpleGateLoopLayer

+from taylor_series_linear_attention import TaylorSeriesLinearAttn
+
from kornia.filters import filter3d

import pickle
@@ -393,10 +395,8 @@ def __init__(
*,
dim,
dim_cond: Optional[int] = None,
-dim_head = 32,
+dim_head = 8,
heads = 8,
-scale = 8,
-flash = False,
dropout = 0.
):
super().__init__()
@@ -409,23 +409,10 @@ def __init__(
else:
self.norm = RMSNorm(dim)

-self.to_qkv = Sequential(
-nn.Linear(dim, dim_inner * 3, bias = False),
-Rearrange('b n (qkv h d) -> qkv b h d n', qkv = 3, h = heads)
-)
-
-self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
-
-self.attend = Attend(
-scale = scale,
-causal = False,
-dropout = dropout,
-flash = flash
-)
-
-self.to_out = Sequential(
-Rearrange('b h d n -> b n (h d)'),
-nn.Linear(dim_inner, dim)
+self.attn = TaylorSeriesLinearAttn(
+dim = dim,
+dim_head = dim_head,
+heads = heads
)

def forward(
@@ -437,14 +424,7 @@ def forward(

x = self.norm(x, **maybe_cond_kwargs)

-q, k, v = self.to_qkv(x)
-
-q, k = map(l2norm, (q, k))
-q = q * self.temperature.exp()
-
-out = self.attend(q, k, v)
-
-return self.to_out(out)
+return self.attn(x)

class LinearSpaceAttention(LinearAttention):
def forward(self, x, *args, **kwargs):
@@ -613,6 +593,8 @@ def __init__(
max_dim = 512,
attn_heads = 8,
attn_dim_head = 32,
+linear_attn_dim_head = 8,
+linear_attn_heads = 16,
ff_mult = 4,
antialiased_downsample = False
):
@@ -647,9 +629,8 @@ def __init__(
attn_block = Sequential(
Residual(LinearSpaceAttention(
dim = out_chan,
-heads = attn_heads,
-dim_head = attn_dim_head,
-flash = False
+heads = linear_attn_heads,
+dim_head = linear_attn_dim_head
)),
Residual(FeedForward(
dim = out_chan,
@@ -1090,6 +1071,8 @@ def __init__(
attn_dim_head = 32,
attn_heads = 8,
attn_dropout = 0.,
+linear_attn_dim_head = 8,
+linear_attn_heads = 16,
vgg: Optional[Module] = None,
vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
perceptual_loss_weight = 1e-1,
@@ -1209,21 +1192,19 @@ def __init__(
)

elif layer_type == 'linear_attend_space':
-attn_kwargs = dict(
+linear_attn_kwargs = dict(
dim = dim,
-dim_head = attn_dim_head,
-heads = attn_heads,
-dropout = attn_dropout,
-flash = flash_attn
+dim_head = linear_attn_dim_head,
+heads = linear_attn_heads
)

encoder_layer = Sequential(
-Residual(LinearSpaceAttention(**attn_kwargs)),
+Residual(LinearSpaceAttention(**linear_attn_kwargs)),
Residual(FeedForward(dim))
)

decoder_layer = Sequential(
-Residual(LinearSpaceAttention(**attn_kwargs)),
+Residual(LinearSpaceAttention(**linear_attn_kwargs)),
Residual(FeedForward(dim))
)

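
Taken together, the hunks above reduce `LinearAttention` to a thin wrapper around the new dependency: the hand-rolled qkv projection, cosine-similarity `Attend`, and output projection are removed, and the module simply normalizes its input and calls `TaylorSeriesLinearAttn`. A minimal usage sketch follows; the constructor keywords mirror those in the diff, while the `(batch, sequence, dim)` layout and the 512-dim example size are assumptions for illustration.

```python
import torch
from taylor_series_linear_attention import TaylorSeriesLinearAttn

# hyperparameters mirror the new defaults: linear_attn_dim_head = 8, linear_attn_heads = 16
attn = TaylorSeriesLinearAttn(dim = 512, dim_head = 8, heads = 16)

tokens = torch.randn(1, 1024, 512)  # assumed (batch, sequence, dim) layout
out = attn(tokens)                  # used as the residual branch inside LinearSpaceAttention
```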
2 changes: 1 addition & 1 deletion magvit2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.3.4'
+__version__ = '0.4.0'
1 change: 1 addition & 0 deletions setup.py
@@ -32,6 +32,7 @@
'pytorch-custom-utils>=0.0.9',
'numpy',
'vector-quantize-pytorch>=1.11.8',
+'taylor-series-linear-attention>=0.1.5',
'torch',
'torchvision',
'x-transformers'
