combine_bidir causes batch to be mixed up #138

Open
mingruimingrui opened this issue Apr 5, 2020 · 0 comments
combine_bidir is a function in source/embed.py#L230.
It's used to concatenate the forward and backward hidden tensors from a bidirectional LSTM.

def combine_bidir(outs):
    return torch.cat([
        torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view(1, bsz, self.output_units)
        for i in range(self.num_layers)
    ], dim=0)

Here outs is a tensor of shape [num_layers * num_dir, bsz, hidden_size] (the hidden state of a bidirectional LSTM, laid out layer-major).
The goal is to rearrange it into the form [num_layers, bsz, num_dir * hidden_size].
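For context, this layout matches the hidden state returned by PyTorch's nn.LSTM when bidirectional=True. A minimal sketch with made-up sizes (not the project's code):

import torch
import torch.nn as nn

num_layers, bsz, hidden_size, seq_len, input_size = 2, 3, 4, 5, 6
lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, bidirectional=True)
x = torch.randn(seq_len, bsz, input_size)
_, (h_n, _) = lstm(x)

# h_n is laid out layer-major: index 2*i is the forward direction of layer i,
# index 2*i + 1 the backward direction.
print(h_n.shape)  # torch.Size([4, 3, 4]) == [num_layers * num_dir, bsz, hidden_size]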

The error is that the inner torch.cat should join the forward and backward tensors along the final dimension, not the first. As written, the current implementation mixes up tensor values between different entries in the same batch. A quick fix is to change dim=0 to dim=-1 in the inner concatenation (a toy demonstration follows the fixed function below).

def combine_bidir(outs):
    return torch.cat([
        torch.cat([outs[2 * i], outs[2 * i + 1]], dim=-1).view(1, bsz, self.output_units)
        for i in range(self.num_layers)
    ], dim=0)
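
A toy example (made-up sizes, outside the class) showing why dim=0 mixes batch entries while dim=-1 does not:

import torch

bsz, hidden_size = 4, 3
fwd = torch.arange(bsz * hidden_size).float().reshape(bsz, hidden_size)        # forward hidden state
bwd = 100 + torch.arange(bsz * hidden_size).float().reshape(bsz, hidden_size)  # backward hidden state

wrong = torch.cat([fwd, bwd], dim=0).view(1, bsz, 2 * hidden_size)
right = torch.cat([fwd, bwd], dim=-1).view(1, bsz, 2 * hidden_size)

print(wrong[0, 0])  # rows fwd[0] and fwd[1] -- two different batch entries mixed together
print(right[0, 0])  # fwd[0] followed by bwd[0] -- the intended result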

However, the code is still rather convoluted and includes an unnecessary for loop, which hampers readability. I suggest using pure reshape and transpose operations for this task.

def combine_bidir(outs):
    # [num_layers * num_dir, bsz, hidden_size]
    #   -> [num_layers, num_dir, bsz, hidden_size]
    #   -> [num_layers, bsz, num_dir, hidden_size]
    #   -> [num_layers, bsz, num_dir * hidden_size]
    outs = outs.reshape(self.num_layers, 2, bsz, self.hidden_size)
    outs = outs.transpose(1, 2)
    return outs.reshape(self.num_layers, bsz, self.output_units)
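
As a sanity check (a standalone sketch with dummy values, so the self.* attributes are replaced by plain variables), the reshape/transpose variant produces exactly the same result as the corrected concatenation version:

import torch

num_layers, bsz, hidden_size = 2, 3, 4
output_units = 2 * hidden_size
outs = torch.randn(num_layers * 2, bsz, hidden_size)

via_cat = torch.cat([
    torch.cat([outs[2 * i], outs[2 * i + 1]], dim=-1).view(1, bsz, output_units)
    for i in range(num_layers)
], dim=0)

via_reshape = outs.reshape(num_layers, 2, bsz, hidden_size).transpose(1, 2).reshape(num_layers, bsz, output_units)

print(torch.equal(via_cat, via_reshape))  # True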