SelfAttention bug on Scores * V #165

Open
huberemanuel opened this issue Oct 14, 2023 · 1 comment

Comments

@huberemanuel

Hey Aladdin, thanks for your tutorials!

I've been implementing the Transformer architecture and learning about einsum. Comparing your einsum-based implementation against one without einsum, I found differences in the final result. Here is the code for reproducibility:

import torch

b, s, h, d = 2, 2, 2, 2
q = torch.randn((b, s, h, d))
k = torch.randn((b, s, h, d))
v = torch.randn((b, s, h, d))

# Classic (matmul) attention
q_mod = q.permute(0, 2, 1, 3)  # [b, h, s, d]
k_mod = k.permute(0, 2, 3, 1)  # [b, h, d, s]
classic_scores = torch.matmul(q_mod, k_mod)  # [b, h, s, s]
classic_scores = torch.softmax(classic_scores / (d ** (1 / 2)), dim=3)
v_mod = v.permute(0, 2, 1, 3)  # [b, h, s, d]
classic_att = torch.matmul(classic_scores, v_mod).reshape(b, s, h * d)

# Einsum attention (as in the tutorial)
einstein_scores = torch.einsum("bqhd,bkhd->bhqk", q, k)  # [b, h, s, s]
einstein_scores = torch.softmax(einstein_scores / (d ** (1 / 2)), dim=3)
einstein_att = torch.einsum("bhql,blhd->bqhd", einstein_scores, v).reshape(b, s, h * d)

assert torch.all(classic_scores == einstein_scores), "Scores don't match"
assert torch.all(classic_att == einstein_att), "Attention doesn't match"

The attention scores match perfectly, but the final attention output doesn't. With my inputs, here is the result:

>>> print(classic_att)
tensor([[[ 1.1246,  0.1376,  1.2368, -0.6316],
         [-2.1842, -0.0181, -2.2082, -0.0023]],

        [[ 0.5911,  0.2132, -0.1727,  0.8552],
         [ 0.2701,  0.0846,  0.2370,  0.1205]]])
>>> print(einstein_att)
tensor([[[ 1.1246,  0.1376, -2.1842, -0.0181],
         [ 1.2368, -0.6316, -2.2082, -0.0023]],

        [[ 0.5911,  0.2132,  0.2701,  0.0846],
         [-0.1727,  0.8552,  0.2370,  0.1205]]])

It seems the values aren't wrong, they are just transposed? I'm a newbie with einsum and couldn't figure it out. Hope someone can find the solution for this :)
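A minimal sketch that checks the transposition hypothesis, assuming the intended output layout is [b, s, h * d] with the heads concatenated per query position: permuting the classic result from [b, h, s, d] back to [b, s, h, d] before flattening makes the two paths agree.

import torch

b, s, h, d = 2, 2, 2, 2
q, k, v = (torch.randn((b, s, h, d)) for _ in range(3))

q_mod, v_mod = q.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)  # [b, h, s, d]
k_mod = k.permute(0, 2, 3, 1)  # [b, h, d, s]
scores = torch.softmax(torch.matmul(q_mod, k_mod) / d ** 0.5, dim=3)  # [b, h, s, s]

# Reshaping [b, h, s, d] directly flattens head-major and mixes positions into
# the feature axis; permuting back to [b, s, h, d] first keeps positions intact.
classic_att = torch.matmul(scores, v_mod).permute(0, 2, 1, 3).reshape(b, s, h * d)

einstein_att = torch.einsum("bhql,blhd->bqhd", scores, v).reshape(b, s, h * d)
assert torch.allclose(classic_att, einstein_att), "Attention doesn't match"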

@Tinghao-NTU

I believe it's a bug, bro. You may check this link to find out the correct method for calculating the self-attention.
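For reference, a minimal sketch of the standard multi-head scaled dot-product attention with inputs in the [b, s, h, d] layout, cross-checked against PyTorch's built-in F.scaled_dot_product_attention (assumes PyTorch >= 2.0); this is a generic reference, not the repository's implementation.

import torch
import torch.nn.functional as F

b, s, h, d = 2, 3, 2, 4
q, k, v = (torch.randn((b, s, h, d)) for _ in range(3))

# Compute attention per head in [b, h, s, d], then move the head axis back
# next to the feature axis before flattening to [b, s, h * d].
qh, kh, vh = (t.transpose(1, 2) for t in (q, k, v))  # [b, h, s, d]
scores = torch.softmax(qh @ kh.transpose(-2, -1) / d ** 0.5, dim=-1)  # [b, h, s, s]
out = (scores @ vh).transpose(1, 2).reshape(b, s, h * d)  # [b, s, h * d]

ref = F.scaled_dot_product_attention(qh, kh, vh).transpose(1, 2).reshape(b, s, h * d)
assert torch.allclose(out, ref, atol=1e-6)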
