Is it possible to add cross attention aka encoder outputs to SRU++? #164

Open · hadaev8 opened this issue Feb 26, 2021 · 8 comments

hadaev8 commented Feb 26, 2021

No description provided.

taoleicn (Contributor) commented Feb 27, 2021

Hi @hadaev8

At the moment, we haven't implemented an SRU++ "decoder" in which there are both self attention and cross attention. There are two options you could choose from:

  1. You can feed in memory to SRU++ and it will treat it as extra context to attend to. In other words, you can do something like:
# encoder
enc_output, enc_hidden, _ = SRUpp_encoder(enc_input, pad_mask=is_padding_mask_enc)

# decoder
memory = [enc_output] * num_dec_layers  # a list of tensors of size (length, batch size, d)
dec_output, dec_hidden, _ = SRUpp_decoder(dec_input,
                                          pad_mask=is_padding_mask_dec,
                                          attn_mask=attention_mask,   # (dec_length, src_length + dec_length)
                                          memory=memory,
                                          memory_mask_pad=...)

Note we are assuming all input & hidden dimensions are d here; see the attention_mask sketch after this list.

  2. You can customize an SRUpp decoding layer. See here:
    https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L907-L911
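
Regarding the attention_mask in option 1, here is a minimal sketch of one way to build it, assuming an additive float mask (0.0 = attend, -inf = blocked) in which each decoder position sees all memory positions plus earlier decoder positions; the exact mask convention SRUpp expects should be checked against the linked code.

import torch

def build_attn_mask(src_length: int, dec_length: int) -> torch.Tensor:
    # Shape (dec_length, src_length + dec_length):
    # the first src_length columns are memory positions (always visible),
    # the remaining dec_length columns are decoder positions (causally masked).
    mask = torch.zeros(dec_length, src_length + dec_length)
    causal = torch.triu(
        torch.full((dec_length, dec_length), float("-inf")), diagonal=1
    )
    mask[:, src_length:] = causal
    return mask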

hadaev8 (Author) commented Feb 27, 2021

@taoleicn
The first option seems like the usual decoder in a transformer.
It will attend to self outputs and to memory inputs, right?

Where can I find the transform_module definition?

taolei87 (Contributor) commented

@hadaev8 Yes and no.

Yes in the sense that within each SRU++ layer, the layer will attend to both self outputs and the memory inputs.
No in the sense that in a transformer decoder there are two attention sub-layers: one used only for self attention, and another used only for cross attention. In option 1, the memory tensor is first concatenated with the self outputs from the previous layer, and then a single attention is applied over the concatenation. See https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L791-L793
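
To make the difference concrete, here is a conceptual sketch (not SRU's actual code; torch.nn.MultiheadAttention stands in for SRUppAttention, and all sizes are made up) of what "one attention over the concatenation" means:

import torch

S, T, B, d = 5, 7, 2, 16                 # hypothetical memory length, decoder length, batch, width
memory = torch.randn(S, B, d)            # encoder outputs passed in as memory
x = torch.randn(T, B, d)                 # self outputs from the previous layer

context = torch.cat([memory, x], dim=0)  # (S + T, B, d)
attn = torch.nn.MultiheadAttention(d, num_heads=4)
# Queries come from x only; keys/values come from the concatenation,
# so a single softmax spans both memory and self positions.
out, _ = attn(query=x, key=context, value=context)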

Re: the transform_module definition:
https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L90-L94

How SRUpp sets transform_module as the attention sub-module:
https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L1019-L1028
https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L1046

Forward method of SRUppCell:
https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L907-L911
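
If you go the option-2 route, the rough shape of a custom attention sub-module would be something like the skeleton below. This is only a hypothetical outline: the class name, constructor arguments, and forward signature are my assumptions and need to be matched against the SRUppAttention code and the SRUppCell call site linked above before it can be plugged in as transform_module.

import torch
import torch.nn as nn

class CrossAttentionTransform(nn.Module):
    # Hypothetical sketch: separate self-attention and cross-attention
    # sub-layers, in the spirit of a transformer decoder block.
    def __init__(self, in_features: int, out_features: int, num_heads: int = 4):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(in_features, num_heads)
        self.cross_attn = nn.MultiheadAttention(in_features, num_heads)
        self.proj = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x, memory, attn_mask=None):
        # x: (dec_length, batch, in_features); memory: (mem_length, batch, in_features)
        h, _ = self.self_attn(x, x, x, attn_mask=attn_mask)  # self attention (optionally causal)
        h, _ = self.cross_attn(h, memory, memory)            # cross attention over encoder memory
        return self.proj(h)                                  # project to the width the cell expects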

hadaev8 (Author) commented Mar 1, 2021

@taolei87
Do I understand correctly that this expects one memory vector instead of a sequence?

taoleicn (Contributor) commented Mar 1, 2021

@hadaev8 I'm not sure I follow. Can you elaborate on your question?

hadaev8 (Author) commented Mar 1, 2021

@taolei87
What is the expected size of the memory tensor?

taoleicn (Contributor) commented Mar 1, 2021

It is a 3-dimensional tensor of shape (memory_seq_len, batch_size, hidden_size). See the illustration below:

[Screenshot: SRUppCell interface]

SRUpp module takes a list of memory tensors (one for each sub-layer), and SRUppCell takes a single memory tensor.
https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L1088
https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L771
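
In other words (with hypothetical sizes), something like:

import torch

mem_len, batch_size, hidden_size, num_layers = 12, 4, 256, 3

enc_output = torch.randn(mem_len, batch_size, hidden_size)  # (memory_seq_len, batch_size, hidden_size)
memory = [enc_output] * num_layers                          # SRUpp: one tensor per layer
# Each SRUppCell then receives a single (mem_len, batch_size, hidden_size) tensor.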

I updated the pseudo code in the previous reply with a correction.

hadaev8 (Author) commented Mar 1, 2021

@taolei87
Now I get it.
I will try it as is, but I have a feeling it's not a good idea to concatenate self and cross attention under one softmax.
Any plans to add the more common form of cross attention?
Also, how should it work at inference time?

Spotted this thing:
https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L158
