
Confusion about the batch in model.py(I think batch should be the second dimension.) #15

Open
WenXiuxiu opened this issue May 22, 2019 · 3 comments


@WenXiuxiu

I read the code very carefully, and I am confused about line 65 of model.py.
I think the second dimension of "hidden" is the batch, not the first one.
Even though the encoder has been set up with "batch_first=True", only the output has batch as its first dimension; the hidden state does not.
I have tested this on my own computer.

Of course, the code runs without errors. I am just confused because the two dimensions are mixed up in the code.
Can anyone help me with this?
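A minimal sketch of the behavior described above (the sizes here are illustrative, not taken from model.py): with `batch_first=True`, only `output` is batch-first; `hidden` keeps the shape `(num_layers * num_directions, batch, hidden_size)`.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
batch, seq_len, in_dim, hid_dim = 4, 7, 5, 3

enc = nn.GRU(in_dim, hid_dim, batch_first=True, bidirectional=True)
x = torch.randn(batch, seq_len, in_dim)
output, hidden = enc(x)

# batch_first=True affects `output` only, not `hidden`:
print(output.shape)  # torch.Size([4, 7, 6]): (batch, seq, 2 * hid_dim)
print(hidden.shape)  # torch.Size([2, 4, 3]): (num_layers * 2, batch, hid_dim)
```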

@Lijiachen1018

I found the same issue.
A simple example:

> t = torch.Tensor([[[1,2,3,4],[5,6,7,8],[9,0,1,2]],[[2,2,3,4],[5,6,7,8],[9,0,1,2]]])

> t
tensor([[[1., 2., 3., 4.],
         [5., 6., 7., 8.],
         [9., 0., 1., 2.]],

        [[2., 2., 3., 4.],
         [5., 6., 7., 8.],
         [9., 0., 1., 2.]]])

> t.shape
torch.Size([2, 3, 4]) 
# num_layers * 2 (if bidirectional), batch_size, hidden_size

> t.view(3,4*2)
tensor([[1., 2., 3., 4., 5., 6., 7., 8.],
        [9., 0., 1., 2., 2., 2., 3., 4.],
        [5., 6., 7., 8., 9., 0., 1., 2.]])

> torch.cat([t[i] for i in range(2)],dim=1)
tensor([[1., 2., 3., 4., 2., 2., 3., 4.],
        [5., 6., 7., 8., 5., 6., 7., 8.],
        [9., 0., 1., 2., 9., 0., 1., 2.]])

@YGJYG-qzq

I agree with you; I ran into the same thing.

@YGJYG-qzq

YGJYG-qzq commented Oct 10, 2023

I fixed this issue by using 'permute', as follows:

import torch

factor = 2  # num_layers * num_directions
bs = 4      # batch size
hs = 3      # hidden size
eh = torch.randn(factor, bs, hs)  # shape of the GRU hidden state
eh_p = eh.permute(1, 0, 2)        # move batch to the front

# correct: each row holds one batch element's concatenated hidden states
eh_t = eh_p.reshape(bs, factor * hs)

# wrong: rows mix hidden states from different batch elements
eh_t2 = eh.reshape(bs, factor * hs)

Here I replaced 'view' with 'reshape', as suggested by the error message: the permuted tensor is no longer contiguous, so 'view' would fail on it.
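To check that the permute-then-reshape fix matches @Lijiachen1018's `torch.cat` version (a quick sanity check, using the same illustrative sizes as above):

```python
import torch

factor, bs, hs = 2, 4, 3
eh = torch.randn(factor, bs, hs)

# The fix: move batch to the front before flattening.
eh_t = eh.permute(1, 0, 2).reshape(bs, factor * hs)

# Reference: concatenate the `factor` slices along the feature dimension.
ref = torch.cat([eh[i] for i in range(factor)], dim=1)

print(torch.equal(eh_t, ref))  # True

# The naive reshape mixes rows from different batch elements:
print(torch.equal(eh.reshape(bs, factor * hs), ref))  # False
```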
