LinearAttention Module #169

Open
rachtibat opened this issue Oct 20, 2022 · 1 comment
Labels: model compatibility (Compatibility for new or variations of existing models)

Comments

rachtibat (Contributor) commented Oct 20, 2022

Hi Christopher,

I hope you're doing well. I'm really glad that the zennit community is growing, congratulations!
With a growing community, more nn.Modules need to be explained, which is why I'm writing this issue.
A student in our department is trying to explain a LinearAttention module (the implementation is below for reference).

It contains a series of torch.einsum and torch.transpose operations.

It also uses the rearrange function from the einops library, which provides an alternative syntax for basic torch operations like transpose, reshape, etc.

I think zennit should be able to analyze a series of reshaping and transposing operations, but I am not completely sure.
I'd be glad if you could give your opinion on analyzing such a linear attention module. If you don't know, that's no problem either (: then it's the beginning of a new research topic.

(The softmax function is also a problem, but maybe Arras et al. have a solution to this which the student could implement...)

Best,
Reduan

import torch
from torch import nn
from einops import rearrange


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1),
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        # project to queries, keys and values with a single 1x1 convolution
        qkv = self.to_qkv(x).chunk(3, dim=1)
        # split heads and flatten the spatial dimensions: (b, heads, dim_head, h*w)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        # linear attention: normalize q over the feature axis, k over the spatial axis
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        # aggregate values weighted by keys: (b, heads, dim_head, dim_head)
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        # distribute the aggregated context back to the query positions
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        # restore the spatial layout: (b, heads*dim_head, h, w)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)
chr5tphr (Owner) commented Nov 2, 2022

Hey Reduan,

thank you for the issue!
You can have a look at this work, where they introduce LRP for Transformers (i.e. also attention heads).
I have talked to @tschnake before about bringing transformers to Zennit, which is still as WIP as it gets.

About the implementation details:

The rearrange operation is just a re-indexing, so the correct approach for it is simply the gradient, which means it is already supported by Zennit.
The einsum is a linear operation, so it can be handled like a linear layer in LRP (see the sketch after this list).
The softmax is a little tricky. In the work above, they handle this by viewing the gating terms as constants.
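
To make the first two points concrete, here is a minimal plain-PyTorch sketch (deliberately not using Zennit's API) of an epsilon-style relevance pass through the second einsum above, followed by carrying the relevance through rearrange with the plain gradient; all names and shapes are illustrative, not existing Zennit code.

import torch
from einops import rearrange

# illustrative shapes matching the LinearAttention module above
b, heads, dim_head, n = 2, 4, 32, 64                  # n = x * y spatial positions
context = torch.randn(b, heads, dim_head, dim_head)   # result of the first einsum
q = torch.randn(b, heads, dim_head, n, requires_grad=True)

# the einsum is linear in q, so it can be treated like a linear layer
out = torch.einsum("b h d e, b h d n -> b h e n", context, q)

# epsilon-style LRP for a linear operation: R_q = q * d/dq [ sum(out * R_out / (out + eps)) ]
eps = 1e-6
relevance_out = out.detach()                           # stand-in for the relevance from above
grad_outputs = relevance_out / (out.detach() + eps)
(grad_q,) = torch.autograd.grad(out, q, grad_outputs=grad_outputs)
relevance_q = q.detach() * grad_q

# rearrange is a pure re-indexing: passing relevance through it via the plain
# gradient only re-orders the entries, so the total relevance is conserved
relevance_img = rearrange(relevance_q, "b h c (x y) -> b (h c) x y", x=8, y=8)
print(relevance_q.sum().item(), relevance_img.sum().item())  # identical sums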

In code, we may get away with requiring torch.nn.Softmax to be used and implementing a Constant rule, which sets the gradient to zero, although I need to think a little more about whether this would work as intended.
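
For illustration only, such a Constant rule could look roughly like the sketch below, assuming it uses the same Hook.backward interface as Zennit's existing rules (e.g. Pass); the name Constant and its use for torch.nn.Softmax are just the idea from above and not an existing part of Zennit.

import torch
from zennit.core import Hook


class Constant(Hook):
    '''Hypothetical rule: treat the module's output as a constant, so that no
    relevance is propagated through it during the backward pass.'''
    def backward(self, module, grad_input, grad_output):
        # replace the incoming relevance with zeros instead of passing it on
        return tuple(torch.zeros_like(grad) for grad in grad_output)

A composite could then map torch.nn.Softmax to this rule, so the gating terms produced by the softmax are treated as constants during the relevance pass.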

Otherwise, we could also implement a canonizer (or a meta-rule) for the most popular library implementing attention layers.

chr5tphr added the model compatibility label on Aug 11, 2023