Currently, the forward method of the TransformerDecoder class requires a tokens tensor of shape [b, s] to be passed as an argument, which is then passed to self.tok_embeddings.
But the capabilities of transformers go far beyond working with text, and sometimes you want to use them with data more complex than sequences of integers.
Perhaps it would be worth relaxing the TransformerDecoder implementation to make it easier to use in such cases?
Specifically: allow the input data to be of any shape [b, s, ...], and change the type of tok_embeddings from nn.Embedding to any module that inherits from nn.Module and returns a tensor of shape [b, s, d].
Alternatively, do it like the Hugging Face library, which allows inputs_embeds to be passed directly instead of input_ids.
This is a good suggestion. We've been discussing redesigning the transformer decoder to output more information (e.g., hidden states from intermediate layers). I think making the embedding layer more generic can be part of this change. I'll put up an RFC some time next week and share it here.