Support LLaVA #273

Open · briankariuki opened this issue Nov 8, 2023 · 10 comments
Labels
kind:feature (New feature or request), note:discussion (Details up for discussion)

Comments

@briankariuki

I'm working on adding LLaVA to bumblebee as a learning exercise.

I need some guidance on a few things:

  1. The official implementation of LLaVA (linked here) uses ClipVisionModel from the Hugging Face transformers package to extract image features. Should I reimplement this, or reuse the ClipVisionModel implementation already in Bumblebee?
  2. The existing implementations have a params_mapping section, for example for LLaMA here. How do I go about identifying the layers of the model and what they map to in the Axon model?
  3. I would also appreciate some guidance on implementing the core logic of the model.

The transformers package has not added support for LLaVA yet; there's an ongoing PR (linked here) that has not been merged.

Thanks.

@jonatanklosko
Member

Hey @briankariuki! It looks like Llava is a composite model, so it will likely be closest to Clip, which is composed of ClipVision and ClipText. Consequently, the implementation can be broken down: you can implement and test each individual model first.

  1. If they go with a separate model in hf/transformers, then yeah, we should have LlavaVision too.
  2. The PyTorch implementation uses a hierarchy of modules. For example, the top-level class LlamaForCausalLM defines self.model -> LlamaModel, which defines self.embed_tokens -> nn.Embedding (a built-in layer). The PyTorch parameter names are inferred from the instance variable names, so in that case the embedding layer is "model.embed_tokens". Assuming that LlavaText is similar to Llama, it may be a good starting point to copy the mapping from there (see the rough sketch after this list).
  3. We have a transformer abstraction, which usually takes care of most of the implementation, such as here. Sometimes a model introduces a specific tweak, in which case we add an option to our generic implementation. If you have any specific questions, let us know.
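
To make the shape of such a mapping concrete, here is a minimal, hypothetical sketch in the style of Bumblebee's params_mapping. The Axon-side layer names on the left are illustrative, and the PyTorch names on the right just follow the "model.embed_tokens" pattern described above, so the real mapping for LlavaText may differ:

```elixir
# Hypothetical sketch of a params_mapping/1 for a Llama-like text model.
# Left: Axon layer names (illustrative placeholders).
# Right: PyTorch parameter prefixes derived from the module hierarchy
# (LlamaForCausalLM -> self.model -> self.embed_tokens, self.layers, ...).
def params_mapping(_spec) do
  %{
    "embedder.token_embedding" => "model.embed_tokens",
    "decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj",
    "decoder.blocks.{n}.self_attention.key" => "model.layers.{n}.self_attn.k_proj",
    "decoder.blocks.{n}.self_attention.value" => "model.layers.{n}.self_attn.v_proj",
    "decoder.blocks.{n}.self_attention.output" => "model.layers.{n}.self_attn.o_proj",
    "output_norm" => "model.norm",
    "language_modeling_head.output" => "lm_head"
  }
end
```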

> The transformers package has not added support for LLaVA yet; there's an ongoing PR (linked here) that has not been merged.

Once it is officially implemented, they will probably update the HF repo, or have a separate one that reflects that implementation. It may be worth waiting for the HF PR to be finalized to see what decisions they make, but we can also prototype sooner and sync once they merge :)

@briankariuki
Author

Hey @jonatanklosko, thanks for the explanation above.

I've been able to implement a few things. One is the LlavaVision module, which is similar to ClipVision. You can check that here.

I've also implemented LlavaText by following the implementation for Llama here, and created the multimodal class for Llava here.

One problem I'm facing is how to convert the outputs of the vision model so that I can pass them to the text model. In the official implementation and the hf pull request there's a function that prepares the inputs for the multimodal model. I'm not sure how I would go about implementing that in Bumblebee. You can find the implementation of the function here.

Thanks.

@jonatanklosko
Member

> there's a function that prepares the inputs for the multimodal model

From a brief look, this function processes the output of the vision model before we feed it into the text model, and it actually uses some NN layers (self.text_model.mm_projector). So in our case we would have some layers on top of the vision model output, a combination of Axon layers and Axon.nx with custom code (a rough sketch is below). It is a bit concerning that there are many calls to self.get_model().embed_tokens (which is an embedding layer) inside a while loop, which may be tricky to replicate. This function is quite elaborate though, and maybe not all parts are relevant for inference; I can't really tell without diving deeper into the model.
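
As a very rough illustration of that idea (a sketch only, with made-up shapes and layer names, not Bumblebee's actual ones):

```elixir
# Hypothetical sketch: project vision hidden states into the text model's
# embedding space with a small MLP (the mm_projector equivalent), then
# combine them with the token embeddings.
image_features = Axon.input("image_features", shape: {nil, 576, 1024})
token_embeddings = Axon.input("token_embeddings", shape: {nil, nil, 4096})

projected =
  image_features
  |> Axon.dense(4096, name: "multimodal_projector.intermediate")
  |> Axon.activation(:gelu)
  |> Axon.dense(4096, name: "multimodal_projector.output")

# Simplest possible combination: prepend the projected image embeddings to
# the token embeddings along the sequence axis. The more elaborate splicing
# done in the PyTorch code would live in Axon.nx / a custom layer instead.
combined = Axon.concatenate(projected, token_embeddings, axis: 1)
```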

@briankariuki
Author

briankariuki commented Nov 28, 2023

Thanks @jonatanklosko. Did you get a chance to look at the LLaVA model and code?

@jonatanklosko
Member

It looks like the upstream PR moved to this one and is closer to crossing the finish line. Looking again, I think this part is going to be really challenging, unfortunately. In a way the implementation is really stitching the models together: it embeds the image with the vision model, embeds the text with a specific layer from the text model, then combines these embeddings (separately for each batch entry) and passes the result through the text model. A rough sketch of that combination step is below.
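
A heavily simplified sketch of the combination step, assuming the tokenized prompt already reserves one placeholder position per image patch (so the sequence length stays fixed) and the projected image embeddings have already been laid out to the full {batch, seq_len, hidden} shape; the real implementation handles more general layouts per batch entry:

```elixir
defmodule LlavaMerge do
  import Nx.Defn

  # Hypothetical sketch of merging image and token embeddings.
  #   token_embeds: {batch, seq_len, hidden} - output of the text embedding layer
  #   image_embeds: {batch, seq_len, hidden} - projected image features, padded
  #                 so they sit at the placeholder positions
  #   image_mask:   {batch, seq_len} - 1 where a position belongs to the image
  defn merge_embeddings(token_embeds, image_embeds, image_mask) do
    mask =
      image_mask
      |> Nx.new_axis(-1)
      |> Nx.broadcast(token_embeds)

    # Take image embeddings at image positions, token embeddings elsewhere.
    Nx.select(mask, image_embeds, token_embeds)
  end
end
```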

@jkbbwr

jkbbwr commented Jan 24, 2024

Any ideas or progress on this? I ran up against it again today and am wondering if there is anything I can do to help push this over the line.

@briankariuki
Author

> Any ideas or progress on this? I ran up against it again today and am wondering if there is anything I can do to help push this over the line.

Hello. I got stuck on how to implement the projector part that extracts the image features and embeds them into the LLM as tokens.

I was able to implement LlavaVision and LlavaText, which are very similar to ClipVision and LlamaText. The missing piece is the multimodal projector.

@jonatanklosko added the kind:feature (New feature or request) and note:discussion (Details up for discussion) labels on Feb 21, 2024
@jonatanklosko changed the title from "Add LLaVA" to "Support LLaVA" on Feb 21, 2024
@wadestuart

It looks like llama.cpp ran into some of the same types of issues pulling it into their interfaces -- here is a PR that shows how they are working through them: ggerganov/llama.cpp#5267. I am still trying to map this back to Bumblebee to see if there are any parallels to be had.

@seanmor5
Contributor

seanmor5 commented Mar 1, 2024

The multimodal projector should just be an FFN between the image features and the LLM; I can take a look at this.
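
As a hypothetical follow-up to the projector sketch above (dense -> GELU -> dense), the corresponding parameter mapping might look like this. The PyTorch-side names are placeholders (an nn.Sequential of Linear/GELU/Linear would yield indexed names such as "model.mm_projector.0") and need to be checked against the final checkpoint layout:

```elixir
# Hypothetical parameter mapping for the projector sketched earlier.
# The PyTorch-side names depend on the final upstream checkpoint layout.
projector_mapping = %{
  "multimodal_projector.intermediate" => "model.mm_projector.0",
  "multimodal_projector.output" => "model.mm_projector.2"
}
```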
