Mamba: use_cache is not passed through in prepare_inputs_for_generation #30849

Open · uwu-420 opened this issue May 16, 2024 · 2 comments · May be fixed by #31116
Comments


uwu-420 commented May 16, 2024

Hi :)

I think that use_cache is supposed to be passed through here as well:

```python
def prepare_inputs_for_generation(
    self, input_ids, cache_params: Optional[MambaCache] = None, inputs_embeds=None, attention_mask=None, **kwargs
):
    # only last token for inputs_ids if the state is passed along.
    if cache_params is not None:
        input_ids = input_ids[:, -1].unsqueeze(-1)

    if inputs_embeds is not None and cache_params is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs["cache_params"] = cache_params
    return model_inputs
```
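For clarity, the change I have in mind is just forwarding the flag into `model_inputs`. A rough sketch (illustrative only, the extra parameter is my suggestion, not existing code):

```python
# Sketch only: forward use_cache so the forward pass knows to build and
# return the MambaCache. The use_cache parameter is the proposed addition.
def prepare_inputs_for_generation(
    self, input_ids, cache_params=None, inputs_embeds=None, attention_mask=None, use_cache=None, **kwargs
):
    # only last token for input_ids if the state is passed along.
    if cache_params is not None:
        input_ids = input_ids[:, -1].unsqueeze(-1)

    if inputs_embeds is not None and cache_params is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs["cache_params"] = cache_params
    model_inputs["use_cache"] = use_cache  # <- proposed addition
    return model_inputs
```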

I noticed this when I tried to retrieve the cache after calling model.generate: it was not there, even though I had set use_cache=True.
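Here is roughly how I hit it (a minimal sketch; the checkpoint name is just an example):

```python
# Minimal reproduction sketch; "state-spaces/mamba-130m-hf" is just an
# example Mamba checkpoint.
from transformers import AutoTokenizer, MambaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

inputs = tokenizer("Hey, how are you?", return_tensors="pt")
out = model.generate(
    **inputs,
    max_new_tokens=10,
    use_cache=True,
    return_dict_in_generate=True,
)

# I expected to be able to get the cache back from `out`, but there is
# nothing there: no cache_params field, and past_key_values stays None.
print(getattr(out, "past_key_values", None))  # -> None
```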

Edit: I just saw that GenerateDecoderOnlyOutput would have to be adjusted as well. It would need to carry cache_params, similarly to past_key_values. I don't know whether you're okay with bloating that class even more.

```python
return GenerateDecoderOnlyOutput(
    sequences=input_ids,
    scores=scores,
    logits=raw_logits,
    attentions=decoder_attentions,
    hidden_states=decoder_hidden_states,
    past_key_values=model_kwargs.get("past_key_values"),
)
```
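Concretely, something like this (cache_params here is a hypothetical new field, just to show what I mean):

```python
# Sketch of what I mean: GenerateDecoderOnlyOutput growing a cache_params
# field alongside past_key_values. The field name is just my suggestion.
return GenerateDecoderOnlyOutput(
    sequences=input_ids,
    scores=scores,
    logits=raw_logits,
    attentions=decoder_attentions,
    hidden_states=decoder_hidden_states,
    past_key_values=model_kwargs.get("past_key_values"),
    cache_params=model_kwargs.get("cache_params"),  # <- hypothetical new field
)
```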

Cheers and thanks for the great work!

@amyeroberts (Collaborator) commented:
cc @gante

@gante (Member) commented May 28, 2024:

@zucchini-nlp could you have a look at this issue? 🤗

zucchini-nlp linked pull request #31116 on May 29, 2024 that will close this issue.