Enable more flexible inputs to the TransformerDecoder #968

Open
marcinwazny opened this issue May 12, 2024 · 1 comment

@marcinwazny

Currently, the forward method of the TransformerDecoder class requires a tokens tensor of shape [b, s] to be passed as an argument, which is then fed to self.tok_embeddings.

But the capabilities of transformers go far beyond working with text, and sometimes you want to use them with data that is more complex than sequences of integers.

Perhaps it would be worth relaxing the TransformerDecoder implementation to make it easier to use in such cases?

Specifically, allow the input data to have any shape [b, s, ...], and change the type of tok_embeddings from nn.Embedding to any model that inherits from nn.Module and returns a tensor of shape [b, s, d].

Alternatively, do it like the Hugging Face transformers library, which allows inputs_embeds to be passed directly instead of input_ids.
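
For illustration, here is a minimal sketch of the kind of forward signature this proposal implies. The class name FlexibleDecoder and the input_embeds keyword are hypothetical, not torchtune's actual API:

```python
from typing import Optional

import torch
import torch.nn as nn


class FlexibleDecoder(nn.Module):
    """Hypothetical decoder whose forward accepts either token ids or precomputed embeddings."""

    def __init__(self, tok_embeddings: nn.Module, layers: nn.ModuleList,
                 norm: nn.Module, output: nn.Module) -> None:
        super().__init__()
        # tok_embeddings can be any nn.Module mapping [b, s, ...] -> [b, s, d],
        # not just nn.Embedding
        self.tok_embeddings = tok_embeddings
        self.layers = layers
        self.norm = norm
        self.output = output

    def forward(
        self,
        tokens: Optional[torch.Tensor] = None,
        *,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if (tokens is None) == (input_embeds is None):
            raise ValueError("Pass exactly one of `tokens` or `input_embeds`.")
        # Embed raw inputs, or use the caller's precomputed [b, s, d] embeddings
        h = self.tok_embeddings(tokens) if input_embeds is None else input_embeds
        for layer in self.layers:
            h = layer(h)
        return self.output(self.norm(h))
```

With a signature like this, a caller could pass precomputed image or audio embeddings of shape [b, s, d] directly and bypass tok_embeddings entirely.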

@kartikayk
Contributor

Thanks for opening this issue!

This is a good suggestion. We've been discussing redesigning the transformer decoder to output more information (e.g. hidden states from intermediate layers). I think making the embedding layer more generic can be part of this change. I'll put up an RFC sometime next week and share it here.
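
As a rough illustration of that direction, a forward pass that also collects per-layer hidden states might look like the sketch below. The function name and the output_hidden_states flag are made up for this sketch, not an agreed design:

```python
from typing import List, Tuple

import torch
import torch.nn as nn


def decoder_forward(
    tok_embeddings: nn.Module,
    layers: nn.ModuleList,
    norm: nn.Module,
    output: nn.Module,
    tokens: torch.Tensor,
    output_hidden_states: bool = False,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """Run a decoder stack and optionally collect each layer's hidden state."""
    h = tok_embeddings(tokens)
    hidden_states: List[torch.Tensor] = []
    for layer in layers:
        h = layer(h)
        if output_hidden_states:
            # keep the intermediate activation for callers that need it
            hidden_states.append(h)
    logits = output(norm(h))
    return logits, hidden_states
```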
