reshape may not match #36

Open
pengxingang opened this issue Mar 12, 2020 · 0 comments


pengxingang commented Mar 12, 2020

Hi, thanks a lot for your code. I think I have found a bug.

In the MultiHeadAttention layer, the reshape1 function does the following:

x = tf.reshape(x, [s[0], s[1], n_head, s[2]//n_head])  # (batch, len, head, d_head)
x = tf.transpose(x, [2, 0, 1, 3])                       # (head, batch, len, d_head)
x = tf.reshape(x, [-1, s[1], s[2]//n_head])             # (head*batch, len, d_head)

The transpose moves the head axis in front of the batch axis, so after the final reshape the merged first axis is ordered head-major, like this (assuming N samples and 2 heads):

sample_0_head_0
sample_1_head_0
sample_2_head_0
...
sample_N-1_head_0
sample_0_head_1
sample_1_head_1
sample_2_head_1
...
sample_N-1_head_1
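To check this ordering, here is a minimal NumPy sketch of reshape1. It assumes np.reshape and np.transpose behave like tf.reshape and tf.transpose (both use row-major order). Every element is tagged with 100*sample + head so we can read off which (head, sample) pair ends up on each row. The function name reshape1_original and the toy sizes are illustrative only.

```python
import numpy as np

def reshape1_original(x, n_head):
    # NumPy mirror of the layer's reshape1; x has shape (batch, len, d_model)
    s = x.shape
    x = x.reshape(s[0], s[1], n_head, s[2] // n_head)  # (batch, len, head, d_head)
    x = x.transpose(2, 0, 1, 3)                        # (head, batch, len, d_head)
    return x.reshape(-1, s[1], s[2] // n_head)         # (head*batch, len, d_head)

N, seq_len, n_head, d_head = 3, 4, 2, 5
# Tag every element with 100*sample + head so each row's origin is recoverable.
sample = np.arange(N).reshape(N, 1, 1, 1)
head = np.arange(n_head).reshape(1, 1, n_head, 1)
x = (100 * sample + head) * np.ones((N, seq_len, n_head, d_head), dtype=int)
x = x.reshape(N, seq_len, n_head * d_head)

out = reshape1_original(x, n_head)
# (head, sample) pair found on each row of the merged first axis
row_labels = [(int(r[0, 0]) % 100, int(r[0, 0]) // 100) for r in out]
print(row_labels)  # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```

Row k holds head k // N of sample k % N, i.e. all samples for head 0 come first.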

But repeating the mask with:

mask = Lambda(lambda x:K.repeat_elements(x, n_head, 0))(mask)

returns a mask ordered like this, with each sample's mask repeated n_head times in a row:

mask_0,
mask_0,
mask_1,
mask_1,
...
mask_N-1,
mask_N-1
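The Keras backend describes repeat_elements as repeating elements along an axis like np.repeat, so the mask ordering can be reproduced with np.repeat. This sketch assumes a (batch, len, len) mask filled with its sample index so rows are easy to identify; the sizes are illustrative.

```python
import numpy as np

N, seq_len, n_head = 3, 4, 2
# Toy mask: every entry of sample i's mask equals i.
mask = np.arange(N).reshape(N, 1, 1) * np.ones((N, seq_len, seq_len), dtype=int)

# NumPy analogue of K.repeat_elements(mask, n_head, 0)
repeated = np.repeat(mask, n_head, axis=0)
mask_ids = [int(m[0, 0]) for m in repeated]
print(mask_ids)  # [0, 0, 1, 1, 2, 2]  (sample-major)

# Sample index that reshape1 (head-major layout) puts on each row of x
x_ids = [b for h in range(n_head) for b in range(N)]
print(mask_ids == x_ids)  # False: mask rows do not line up with x rows
```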

(see the Keras backend documentation for repeat_elements, which repeats elements like np.repeat)

What we actually need is a mask ordered like this, so that row i of the mask lines up with row i of x:

mask_0,
mask_1,
...
mask_N-1,
mask_0,
mask_1,
...
mask_N-1
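For comparison, this desired ordering is what tiling the whole batch n_head times produces. In NumPy that is np.tile(mask, (n_head, 1, 1)); the Keras counterpart would presumably be K.tile(mask, [n_head, 1, 1]). That would be an alternative to changing the transpose, not the fix proposed below. The mask shape and sizes are assumptions for illustration.

```python
import numpy as np

N, seq_len, n_head = 3, 4, 2
# Toy mask: every entry of sample i's mask equals i.
mask = np.arange(N).reshape(N, 1, 1) * np.ones((N, seq_len, seq_len), dtype=int)

# Copy the whole batch n_head times along axis 0: mask_0..mask_N-1, mask_0..mask_N-1
tiled = np.tile(mask, (n_head, 1, 1))
mask_ids = [int(m[0, 0]) for m in tiled]
x_ids = [b for h in range(n_head) for b in range(N)]  # head-major rows from reshape1
print(mask_ids)            # [0, 1, 2, 0, 1, 2]
print(mask_ids == x_ids)   # True
```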

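So I think reshape1 should change x = tf.transpose(x, [2, 0, 1, 3]) to x = tf.transpose(x, [0, 2, 1, 3]). That keeps each sample's heads adjacent on the merged axis (sample-major order), which matches the order produced by repeat_elements. reshape2 needs the corresponding change so that it still undoes reshape1.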
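A NumPy sketch of the proposed fix: with transpose [0, 2, 1, 3] the merged axis has row index sample * n_head + head, the same order np.repeat (and so K.repeat_elements) produces. The original reshape2 is not quoted in this issue, so reshape2_fixed below is a hypothetical inverse written for this sketch; both function names are illustrative.

```python
import numpy as np

def reshape1_fixed(x, n_head):
    # (batch, len, d_model) -> (batch*head, len, d_head), sample-major rows
    s = x.shape
    x = x.reshape(s[0], s[1], n_head, s[2] // n_head)  # (batch, len, head, d_head)
    x = x.transpose(0, 2, 1, 3)                        # (batch, head, len, d_head)
    return x.reshape(-1, s[1], s[2] // n_head)

def reshape2_fixed(x, n_head):
    # Hypothetical inverse: (batch*head, len, d_head) -> (batch, len, d_model)
    s = x.shape
    x = x.reshape(-1, n_head, s[1], s[2])  # (batch, head, len, d_head)
    x = x.transpose(0, 2, 1, 3)            # (batch, len, head, d_head)
    return x.reshape(-1, s[1], n_head * s[2])

N, seq_len, n_head, d_model = 3, 4, 2, 10
x = np.arange(N * seq_len * d_model).reshape(N, seq_len, d_model)
heads = reshape1_fixed(x, n_head)

# Sample index on each row vs. the np.repeat (K.repeat_elements) mask order
row_samples = [int(r[0, 0]) // (seq_len * d_model) for r in heads]
mask_samples = np.repeat(np.arange(N), n_head).tolist()
print(row_samples == mask_samples)                       # True
print(np.array_equal(reshape2_fixed(heads, n_head), x))  # True
```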
