
tracker: generate composability refactor #30810

Open · 13 tasks
gante opened this issue May 14, 2024 · 4 comments
Labels
WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress

Comments


gante commented May 14, 2024

generate + composability = more use cases with minimal rewrites

As I write this issue, generate is mostly a sequential monolith. Many internal blocks have been carved into functions over the last two years, but the code is still hard for a newcomer to navigate. It is also very challenging to adapt generate to different tasks and/or modalities, which forces us to overwrite the entire generate function (e.g. RAG, MusicGen). All of this makes generate a challenge to use, document, maintain, and test.

This issue is a tracker for the refactor of generate, where we aim to build the structure outlined in this board. Key ideas for this refactor:
👉 All models can use the base generate API
👉 Reduce if/else blocks
👉 Reduce the barriers to entry for new decoding methods/modalities/use cases
👉 Reduce per-model overwrites when possible
👉 Add unit tests
👉 Add documentation regarding the structure of generate

Tasks

  • 1. Isolate prefill into a separate function, pulling it from the decoding methods. Note that
    • a) prefill is done excluding the latest token (i.e. it runs on input_ids[:, :-1]), so we don't compute variables for the latest token twice;
    • b) prefill only runs when use_cache=True and cache length < input length - 1;
    • c) _expand_inputs_for_generation needs to be changed (it used to copy the inputs before prefill; we will now need to copy the prefill outputs)
  • 2. (depends on 1.) Separate generate into the 5 stages described in the diagram, passing around the data structures described therein
  • 3. (depends on 1.) Streaming 2.0
    • a) Add an option to yield/yield from instead of return;
    • b) Deprecate the old streamer classes;
    • c) Add a new class to print the stream to the screen. For beam methods, build a class that prints up to the point where all beams agree with each other;
    • d) Thoroughly document and communicate this feature;
    • e) Enable streaming to the screen with pipeline.
  • 4. (depends on 2.) Separate stage 1 into a set of functions as described in the diagram. Add unit tests.
  • 5. (depends on 2.) Separate stage 2 into a set of functions as described in the diagram. Add unit tests. Move the preparation of common model inputs here, such as position_ids.
  • 6. (depends on 2.) Separate stage 3 into a set of functions as described in the diagram. Add unit tests. Deprecate LogitsWarper in this step (it's a copy of LogitsProcessor)
  • 7. (depends on 2.) Separate stage 5 into a set of functions as described in the diagram. Add unit tests.
  • 8. Add a new document page walking through the structure of generate

[From this point onwards the tasks are only a sketch, need more detailed planning when we get there]

  • 9. Reduce if/elses through templates (e.g. LLMs have a certain default for prepare_inputs_for_generation, VLMs also have their special preprocessing steps, ...)
  • 10. Play around with caching of some blocks to determine whether it speeds up generation
  • 11. Rework prepare_inputs_for_generation ?
  • 12. Remove generate from models that have a custom implementation
  • (other tasks, TBD)
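Tasks 1 and 3a above can be sketched on a toy model. This is a minimal, hypothetical illustration of the target structure, assuming a stand-in "model" whose next token is simply the last token + 1; none of these names are actual transformers APIs:

```python
def toy_model(tokens, cache=None):
    """Toy forward pass: the next token is the last seen token + 1.
    The 'cache' simply remembers every token it has processed."""
    cache = list(cache or []) + list(tokens)
    return cache[-1] + 1, cache

def prefill(input_ids, use_cache=True):
    """Task 1: run the prompt excluding the latest token (note 1a), only when
    a cache is in use and there is something to fill (note 1b)."""
    if not use_cache or len(input_ids) < 2:
        return None
    _, cache = toy_model(input_ids[:-1])
    return cache

def _decode(input_ids, cache, max_new_tokens, use_cache=True):
    """Task 3a: yield each new token instead of returning the full sequence."""
    tokens = list(input_ids)
    for _ in range(max_new_tokens):
        # With a cache, only the latest token needs a forward pass.
        step_input = tokens[-1:] if use_cache else tokens
        next_token, cache = toy_model(step_input, cache)
        tokens.append(next_token)
        if not use_cache:
            cache = None  # no cache: reprocess the full sequence each step
        yield next_token

def generate(input_ids, max_new_tokens, use_cache=True, stream=False):
    cache = prefill(input_ids, use_cache=use_cache)
    steps = _decode(input_ids, cache, max_new_tokens, use_cache=use_cache)
    if stream:
        return steps  # a streamer can consume tokens as they are produced
    return list(input_ids) + list(steps)
```

Whether prefill runs or not, both paths produce the same sequence; the only difference is how much work each decoding step repeats.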
@amyeroberts added the WIP label May 15, 2024
@amyeroberts

Adding the WIP label, so you don't get pinged by the stale bot 🤖

@gante mentioned this issue May 16, 2024

dmarx commented May 20, 2024

Could you elaborate on the "prefill" component? My impression is that this step converts the prompt into a KV cache, i.e. "pre-filling" the KV component for the tokens that are fixed. If that's correct, this component could probably serve double duty (as an exit point from the generate procedure) for users who just want logprobs/scores for the prompt.


dmarx commented May 20, 2024

IMHO, upstream of (part of?) the "generate outputs" step in the decoding loop should be a templated _prepare_outputs function whose job is just to attach output attributes to the appropriate output class. Design POC via #29545, concretely:
https://github.com/coreweave/transformers/blob/dmarx.output_streamer/src/transformers/generation/utils.py#L354-L378
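A minimal sketch of the suggested templated _prepare_outputs, assuming an illustrative output dataclass — the class and its field names are hypothetical stand-ins, not the actual design in #29545:

```python
from dataclasses import dataclass, fields
from typing import Optional, Tuple

@dataclass
class GreedySearchOutput:
    # Illustrative output class; real generate output classes have more fields.
    sequences: Tuple[int, ...]
    scores: Optional[Tuple[float, ...]] = None
    attentions: Optional[tuple] = None

def _prepare_outputs(output_cls, **collected):
    """Attach only the attributes the output class declares; drop the rest."""
    allowed = {f.name for f in fields(output_cls)}
    return output_cls(**{k: v for k, v in collected.items() if k in allowed})

# The decoding loop can collect everything and let the template sort it out:
out = _prepare_outputs(
    GreedySearchOutput,
    sequences=(1, 2, 3),
    scores=(0.5,),
    hidden_states=None,  # not declared on GreedySearchOutput, silently dropped
)
```

The same function then serves every decoding method: each method only swaps the output class it passes in.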


gante commented May 21, 2024

@dmarx

Could you elaborate on the "prefill" component? My impression is that this step converts the prompt into a KV cache, i.e. "pre-filling" the KV component for the tokens that are fixed. If that's correct, this component could probably serve double duty (as an exit point from the generate procedure) for users who just want logprobs/scores for the prompt.

Correct, it can be made a public function with that additional purpose. The difference between the prefill for generate and obtaining the scores for the prompt is that in the former, we only want to keep the past KV. The different output needs suggest to me that a stand-alone public function is preferable to an alternate exit to generate :D Added this to the notes of the prefill stage in the diagram.

IMHO, upstream of (part of?) the "generate outputs" step in the decoding loop should be a templated _prepare_outputs function whose job is just to attach output attributes to the appropriate output class. Design POC via #29545, concretely:
https://github.com/coreweave/transformers/blob/dmarx.output_streamer/src/transformers/generation/utils.py#L354-L378

That's a good idea, adding it to the diagram!

Thank you for the suggestions 💛
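The stand-alone public prefill discussed in this exchange could look roughly like this — a toy sketch where the "model" returns uniform log-probabilities over a 4-token vocabulary, and every name is a hypothetical stand-in rather than a transformers API:

```python
import math

def toy_forward(tokens):
    """Toy forward pass: uniform log-probabilities over a 4-token vocabulary,
    plus a 'KV cache' that simply remembers the processed tokens."""
    logprobs = [[math.log(0.25)] * 4 for _ in tokens]
    return logprobs, list(tokens)

def prefill(input_ids):
    """Process the prompt excluding the latest token and return everything;
    each caller keeps only what it needs."""
    logprobs, past_key_values = toy_forward(input_ids[:-1])
    return {"past_key_values": past_key_values, "logprobs": logprobs}

# generate() would keep only out["past_key_values"] to continue decoding,
# while a prompt-scoring utility would keep out["logprobs"] instead.
out = prefill([1, 2, 3])
```

This matches the point above: the two use cases want different outputs from the same forward pass, which is easier to express as one public function with two kinds of callers than as an early exit from generate.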
