
MAB Implementation diverges from Paper #8

Open
jlko opened this issue Nov 17, 2020 · 6 comments

Comments

@jlko

jlko commented Nov 17, 2020

Dear Juho,

is it possible that the implementation of the MAB diverges from the paper?

In more detail: The paper states

Multihead(Q, K, V; λ, ω) = concat(O_1, ..., O_h) W_O
H = LayerNorm(X + Multihead(X, Y, Y; ω))
MAB(X, Y) = LayerNorm(H + rFF(H))

but the code does

A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)  # scaled dot-product attention weights, heads stacked along the batch dim
O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)  # residual on the projected query Q_, then concat heads back along the feature dim; this is the output of multihead
O = O if getattr(self, 'ln0', None) is None else self.ln0(O)  # optional LayerNorm
O = O + F.relu(self.fc_o(O))  # feed-forward block with its own residual
O = O if getattr(self, 'ln1', None) is None else self.ln1(O)  # optional LayerNorm
  • It seems that the matrix W_O is not used in the code at all to mix the outputs of the different heads? (See the sketch below this list for the version with an explicit W_O that I have in mind.)

  • The skip connection Q_ + A.bmm(V_) also diverges from what's stated in the paper, given that Q_ is derived from Q, which gets linearly transformed via Q = self.fc_q(Q) in the first line of forward() and is therefore no longer equal to the original query. (On second thought, this may be a necessary requirement, since the output of the MAB has a different shape than the input; in that case, the paper is imprecise.)
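
For concreteness, this is roughly what I read the paper's equations to describe. It is only a sketch with placeholder names, assuming equal query/key/value dimensions so that the X + Multihead(X, Y, Y) residual is shape-compatible; it is not the repository code.

# Rough sketch of the paper's equations: explicit W_O after the head concat,
# residual on the raw query X. Placeholder names, equal dims assumed.
import math
import torch
import torch.nn as nn

class PaperMAB(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        assert dim % num_heads == 0
        self.dim, self.num_heads = dim, num_heads
        self.w_q = nn.Linear(dim, dim)
        self.w_k = nn.Linear(dim, dim)
        self.w_v = nn.Linear(dim, dim)
        self.w_o = nn.Linear(dim, dim)   # the W_O that mixes the concatenated heads
        self.rff = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        self.ln0 = nn.LayerNorm(dim)
        self.ln1 = nn.LayerNorm(dim)

    def forward(self, X, Y):
        B, n, d = X.shape
        h, d_h = self.num_heads, self.dim // self.num_heads
        # project and split into heads: (B, h, n or m, d_h)
        Q = self.w_q(X).view(B, n, h, d_h).transpose(1, 2)
        K = self.w_k(Y).view(B, -1, h, d_h).transpose(1, 2)
        V = self.w_v(Y).view(B, -1, h, d_h).transpose(1, 2)
        A = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(d_h), dim=-1)
        O = (A @ V).transpose(1, 2).reshape(B, n, d)    # concat(O_1, ..., O_h)
        H = self.ln0(X + self.w_o(O))                   # LayerNorm(X + Multihead(X, Y, Y))
        return self.ln1(H + self.rff(H))                # LayerNorm(H + rFF(H))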

Thanks a lot and best wishes
Jannik

@jingweiz

Hi Jannik,
I think fc_o here (self.fc_o in the code you pasted) is the W_O from the paper. What do you think?

@jlko

jlko commented Nov 26, 2020

Dear Jingweiz,
thanks for your reply!
I would have identified fc_o with the rFF(H) of the MAB and not with W_O.

@juho-lee

Hi, thanks for your interest!

1) Multiplying by W_O after the concat and 2) multiplying the query by W to get Q and then doing split-attend-concat make, in essence, only a small difference (one of them is a restricted version of the other). For the paper, I followed the description in the original Transformer paper; for the code, I chose the current form following the code available for the original Transformer (it also gives cleaner code). They don't make a big empirical difference.
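
Roughly, the point is that splitting one full-width projection along the feature dimension gives exactly the per-head projections of the paper, while leaving out an explicit W_O after the concat corresponds to the paper's form with W_O restricted to the identity. A quick check with placeholder shapes (not the actual modules.py code):

# Placeholder-shape check: projecting once with a full-width fc_q and then
# splitting the features into heads equals applying h per-head projections.
import torch
import torch.nn as nn

torch.manual_seed(0)
B, n, d, h = 2, 5, 8, 4
d_h = d // h

fc_q = nn.Linear(d, d, bias=False)        # single full-width projection, as in the code
X = torch.randn(B, n, d)

heads_split = fc_q(X).split(d_h, dim=-1)  # code style: project once, split features into heads
heads_per_head = [X @ fc_q.weight[i * d_h:(i + 1) * d_h].T for i in range(h)]  # paper style: one W_i^Q per head

print(all(torch.allclose(a, b, atol=1e-6) for a, b in zip(heads_split, heads_per_head)))  # True
# Dropping W_O after the concat then amounts to fixing W_O = I in
# concat(O_1, ..., O_h) W_O, i.e. a restricted version of the paper's form.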

@jlko

jlko commented Nov 26, 2020

Hey Juho Lee!
Thanks for your reply.

It makes sense that this does not give a big empirical difference. I just wanted to check if I missed something.

And LayerNorm(X + Multihead(X, Y, Y; ω)) in the paper should probably be something like LayerNorm(W_q X + Multihead(X, Y, Y; ω)), correct?
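
To make the shape point concrete (placeholder dimensions, not the repository code): when the query dimension and dim_V differ, only the projected query has a shape compatible with the multihead output.

# Placeholder illustration: with dim_Q != dim_V, X itself cannot be added to
# the multihead output; only the projected query fc_q(X) can.
import torch
import torch.nn as nn

B, n, dim_Q, dim_V = 2, 5, 6, 12
X = torch.randn(B, n, dim_Q)

fc_q = nn.Linear(dim_Q, dim_V)
multihead_out = torch.randn(B, n, dim_V)   # stand-in for Multihead(X, Y, Y) with output dim dim_V

print((fc_q(X) + multihead_out).shape)     # torch.Size([2, 5, 12])
# X + multihead_out                        # would fail: last dims 6 vs 12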

@jingweiz

> Dear Jingweiz,
> thanks for your reply!
> I would have identified fc_o with the rFF(H) of the MAB and not with W_O.

Oh right exactly, I got messed up, thanks!

@npielawski

I have a follow-up question linked to this topic.

In the paper, a row-wise FF block is used for the pooling, and unlike the rFF in the MAB, the rFF in the pooling doesn't have an activation function. Should the PMA rFF have an activation function or not?
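
To make the question concrete, this is how I currently read PMA_k(Z) = MAB(S, rFF(Z)) from the paper (placeholder code, not the repository implementation); the rff_act flag below marks exactly the choice I'm asking about.

# Placeholder sketch of PMA_k(Z) = MAB(S, rFF(Z)); not the repository code.
# rff_act marks the open question: should the pooling rFF include a
# nonlinearity, or stay purely linear?
import torch
import torch.nn as nn

class PMASketch(nn.Module):
    def __init__(self, dim, num_seeds, mab, rff_act=False):
        super().__init__()
        self.S = nn.Parameter(torch.randn(1, num_seeds, dim))  # learnable seed vectors
        layers = [nn.Linear(dim, dim)]
        if rff_act:
            layers.append(nn.ReLU())   # variant with an activation
        self.rff = nn.Sequential(*layers)
        self.mab = mab                 # any MAB-style block (e.g. the sketch earlier in this thread)

    def forward(self, Z):
        return self.mab(self.S.repeat(Z.size(0), 1, 1), self.rff(Z))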
