
failed to convert llava-llama3 model to hf format #656

Open
Jason8Kang opened this issue May 7, 2024 · 9 comments
Jason8Kang commented May 7, 2024

The project https://github.com/InternLM/xtuner/tree/main/xtuner/configs/llava/llama3_8b_instruct_clip_vit_large_p14_336 gives an example of how to convert a llava-llama3 model to HF format:

python ./convert_xtuner_weights_to_hf.py --text_model_id ./iter_39620_xtuner --vision_model_id ./iter_39620_visual_encoder --projector_weight ./iter_39620_xtuner/projector/model.safetensors --save_path ./iter_39620_llava

I followed it in this way:

  1. I trained the llava-llama3 model on a single GPU with llava_llama3_8b_instruct_qlora_clip_vit_large_p14_336_e1_gpu1_finetune.py and got the checkpoint llama3_llava_pth;
  2. I converted llama3_llava_pth to Hugging Face format, saved as llama3_llava_pth/hf (the conversion and merge commands for steps 2 and 3 are sketched after this list);
  3. I merged the llm_adapter with the original llama3 using the xtuner merge tool, creating ./llama3_llava_pth/merge;
  4. I tried to convert the format:
    CUDA_VISIBLE_DEVICES=4 python ./xtuner/configs/llava/llama3_8b_instruct_clip_vit_large_p14_336/convert_xtuner_weights_to_hf.py --text_model_id ./llama3_llava_pth/merge --vision_model_id ${vit} --projector_weight llama3_llava_pth/hf/projector --save_path ./llama3_llava_pth/LLava_format
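For reference, steps 2 and 3 roughly correspond to the following xtuner commands (a sketch; the exact config name, checkpoint file name, and base-LLM path below are placeholders for my local setup):

# step 2: convert the saved .pth checkpoint to Hugging Face-format adapter/projector folders
xtuner convert pth_to_hf \
    llava_llama3_8b_instruct_qlora_clip_vit_large_p14_336_e1_gpu1_finetune \
    ./llama3_llava_pth/iter_xxx.pth \
    ./llama3_llava_pth/hf

# step 3: merge the llm_adapter into the original llama3 weights
xtuner convert merge \
    <path_to_original_llama3> \
    ./llama3_llava_pth/hf/llm_adapter \
    ./llama3_llava_pth/merge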

However, running the command in step 4, I get this error:

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'PreTrainedTokenizerFast'. 
The class this function is called from is 'LlamaTokenizerFast'.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Traceback (most recent call last):
  File "/Data_PHD/****/projects/xtuner/./xtuner/configs/llava/llama3_8b_instruct_clip_vit_large_p14_336/convert_xtuner_weights_to_hf.py", line 140, in <module>
    main()
  File "/Data_PHD/****/projects/xtuner/./xtuner/configs/llava/llama3_8b_instruct_clip_vit_large_p14_336/convert_xtuner_weights_to_hf.py", line 135, in main
    convert_to_hf(args.text_model_id, args.vision_model_id,
  File "/Data_PHD/****/projects/xtuner/./xtuner/configs/llava/llama3_8b_instruct_clip_vit_large_p14_336/convert_xtuner_weights_to_hf.py", line 61, in convert_to_hf
    model = LlavaForConditionalGeneration(config)
  File "/Data_PHD/****/anaconda3_35/envs/XTUNER/lib/python3.10/site-packages/transformers/models/llava/modeling_llava.py", line 244, in __init__
    self.multi_modal_projector = LlavaMultiModalProjector(config)
  File "/Data_PHD/****/anaconda3_35/envs/XTUNER/lib/python3.10/site-packages/transformers/models/llava/modeling_llava.py", line 93, in __init__
    self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
  File "/Data_PHD/****/anaconda3_35/envs/XTUNER/lib/python3.10/site-packages/transformers/configuration_utils.py", line 263, in __getattribute__
    return super().__getattribute__(key)
AttributeError: 'CLIPConfig' object has no attribute 'hidden_size'

Is there any error in my steps? Thank you for your answer!

LZHgrla (Collaborator) commented May 8, 2024

@Jason8Kang Hi!
This PR #661 can solve your issue!

Meanwhile, please specify the .safetensors file for --projector_weight, e.g. llama3_llava_pth/hf/projector/model.safetensors
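With that change, the step-4 command would look roughly like this (same paths as above, only --projector_weight now points at the .safetensors file):

CUDA_VISIBLE_DEVICES=4 python ./xtuner/configs/llava/llama3_8b_instruct_clip_vit_large_p14_336/convert_xtuner_weights_to_hf.py \
    --text_model_id ./llama3_llava_pth/merge \
    --vision_model_id ${vit} \
    --projector_weight ./llama3_llava_pth/hf/projector/model.safetensors \
    --save_path ./llama3_llava_pth/LLava_format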

Jason8Kang (Author) commented May 8, 2024

@LZHgrla Thanks, it works.
I have a similar question about converting it to the official LLaVA format. This is my command:
python ./xtuner/configs/llava/llama3_8b_instruct_clip_vit_large_p14_336/convert_xtuner_weights_to_llava.py --text_model_id ./llama3_llava_pth/merge --vision_model_id ${vit} --projector_weight llama3_llava_pth/hf/projector/model.safetensors --save_path ./llama3_llava_pth/LLava_format

These are the warnings. When I load the LLava_format output with Hugging Face, it fails; there may be a similar error.
/Data_PHD/***/anaconda3_35/envs/xtuner/lib/python3.10/site-packages/torch/nn/modules/module.py:2025: UserWarning: for vision_model.encoder.layers.23.mlp.fc2.bias: copying from a non-meta parameter in the checkpoint to a meta parameter in the current model, which is a no-op. (Did you mean to pass assign=True to assign items in the state dictionary to their corresponding key in the module instead of copying them in place?)
(the same UserWarning is repeated for vision_model.encoder.layers.23.layer_norm2.weight, vision_model.encoder.layers.23.layer_norm2.bias, vision_model.post_layernorm.weight, and vision_model.post_layernorm.bias)

LZHgrla (Collaborator) commented May 8, 2024

@Jason8Kang
I think these warnings will not affect the usage, and we can ignore them. Right?

Jason8Kang (Author) commented:
I also thought so at first, but then I loaded the LLava_format output in this way:

from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model

model_path = "./llama3_llava_pth.bak/LLava_format"

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path)
)

It gives these warnings:
/Data_PHD/***/anaconda3_35/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py:2025: UserWarning: for vision_model.encoder.layers.8.self_attn.v_proj.weight: copying from a non-meta parameter in the checkpoint to a meta parameter in the current model, which is a no-op. (Did you mean to pass assign=True to assign items in the state dictionary to their corresponding key in the module instead of copying them in place?)
(the same UserWarning is repeated for vision_model.encoder.layers.8.self_attn.v_proj.bias, q_proj.weight, q_proj.bias, out_proj.weight, out_proj.bias, layer_norm1.weight, layer_norm1.bias, mlp.fc1.weight, and mlp.fc1.bias)

But when I run inference, it gives this error:
Loading checkpoint shards: 100%|██████████| 9/9 [00:01<00:00, 6.28it/s]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Data_PHD/***/projects/LLaVA/llava/eval/run_llava.py", line 115, in eval_model
    output_ids = model.generate(
  File "/Data_PHD/***/anaconda3_35/envs/llava/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/Data_PHD/***/projects/LLaVA/llava/model/language_model/llava_llama.py", line 125, in generate
    ) = self.prepare_inputs_labels_for_multimodal(
  File "/Data_PHD/***/projects/LLaVA/llava/model/llava_arch.py", line 202, in prepare_inputs_labels_for_multimodal
    image_features = self.encode_images(images)
  File "/Data_PHD/***/projects/LLaVA/llava/model/llava_arch.py", line 141, in encode_images
    image_features = self.get_model().get_vision_tower()(images)
  File "/Data_PHD/***/anaconda3_35/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Data_PHD/***/anaconda3_35/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Data_PHD/***/anaconda3_35/envs/llava/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/Data_PHD/***/anaconda3_35/envs/llava/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/Data_PHD/***/projects/LLaVA/llava/model/multimodal_encoder/clip_encoder.py", line 54, in forward
    image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
  File "/Data_PHD/***/anaconda3_35/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Data_PHD/***/anaconda3_35/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Data_PHD/***/anaconda3_35/envs/llava/lib/python3.10/site-packages/accelerate/hooks.py", line 160, in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
  File "/Data_PHD/***/anaconda3_35/envs/llava/lib/python3.10/site-packages/accelerate/hooks.py", line 290, in pre_forward
    return send_to_device(args, self.execution_device), send_to_device(
  File "/Data_PHD/***/anaconda3_35/envs/llava/lib/python3.10/site-packages/accelerate/utils/operations.py", line 151, in send_to_device
    return honor_type(
  File "/Data_PHD/***/anaconda3_35/envs/llava/lib/python3.10/site-packages/accelerate/utils/operations.py", line 83, in honor_type
    return type(obj)(generator)
  File "/Data_PHD/***/anaconda3_35/envs/llava/lib/python3.10/site-packages/accelerate/utils/operations.py", line 152, in <genexpr>
    tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor)
  File "/Data_PHD/***/anaconda3_35/envs/llava/lib/python3.10/site-packages/accelerate/utils/operations.py", line 167, in send_to_device
    return tensor.to(device, non_blocking=non_blocking)
NotImplementedError: Cannot copy out of meta tensor; no data!
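One way to confirm what is happening, reusing the model returned by load_pretrained_model above, is to list which parameters are still on the meta device (a quick sketch, not a definitive diagnosis):

# Sketch: list parameters that were never materialized (still on the "meta" device).
# If vision_model.* names show up here, it matches the no-op copy warnings above
# and would explain the "Cannot copy out of meta tensor; no data!" error at inference.
meta_params = [name for name, p in model.named_parameters() if p.device.type == "meta"]
print(len(meta_params), "parameters still on meta device")
print(meta_params[:10])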

LZHgrla (Collaborator) commented May 8, 2024

@Jason8Kang Try this script?
https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-hf/discussions/1

ztfmars commented May 8, 2024

@Jason8Kang hi,

The official xtuner/llava-llama3 checkpoint can be used with lmdeploy and works well.

I ran into the same llava-llama3 Hugging Face conversion problem and solved it by following the changes in this PR.

But the trained llava-llama3 weights can't be used in the lmdeploy pipeline; the error is as follows:

llama3-ft$ python pipeline_llava.py
[2024-05-08 12:43:31,652] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
 [WARNING]  using untested triton version (2.1.0), only 1.0.0 is known to be compatible
2024-05-08 12:43:33,017 - lmdeploy - ERROR - AttributeError: 'LlavaConfig' object has no attribute 'hidden_size'
2024-05-08 12:43:33,017 - lmdeploy - ERROR - <Model> test failed!
("Checking failed with error 'LlavaConfig' object has no attribute 'hidden_size'", 'Please send issue to LMDeploy with error logs.')


And the lmdeploy code is as follows:

from lmdeploy import pipeline, ChatTemplateConfig
from lmdeploy.vl import load_image

# from modelscope import snapshot_download, AutoModel, AutoTokenizer
# import os

### download
# download from modelscope
# model_dir = snapshot_download('xtuner/llava-llama-3-8b-hf')


#### llava-llama3-8b pipeline
### official huggingface checkpoint
# pipe = pipeline('/home/fusionai/.cache/modelscope/hub/xtuner/llava-llama-3-8b-hf',
#                 chat_template_config=ChatTemplateConfig(model_name='llama3'))

### trained huggingface checkpoint
pipe = pipeline('/home/fusionai/project/internllm_demo/llama3/llama3-ft/train/llava_train_20240508_1/iter_1008_final_hf',
                chat_template_config=ChatTemplateConfig(model_name='llama3'))

image = load_image('/home/fusionai/project/internllm_demo/lldemploy_demo/tiger.jpeg')
response = pipe(('describe this image', image))
print(response.text)

LZHgrla (Collaborator) commented May 8, 2024

@ztfmars
Please use convert_xtuner_weights_to_llava.py (not convert_xtuner_weights_to_hf.py) to get the official LLaVA-format weights. lmdeploy only supports deployment of the official LLaVA-format weights.
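For example, mirroring the command used earlier in this thread (substitute your own paths):

python ./xtuner/configs/llava/llama3_8b_instruct_clip_vit_large_p14_336/convert_xtuner_weights_to_llava.py \
    --text_model_id ./llama3_llava_pth/merge \
    --vision_model_id ${vit} \
    --projector_weight ./llama3_llava_pth/hf/projector/model.safetensors \
    --save_path ./llama3_llava_pth/LLava_format

Then point the lmdeploy pipeline at the resulting save_path instead of the HF-format directory.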

Jason8Kang (Author) commented:

> @Jason8Kang Try this script?
> https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-hf/discussions/1

Thank you, it works with your script. But it still shows the warnings below. Do you know the meaning of this warning? I'm not sure whether it matters.
UserWarning: for vision_model.encoder.layers.23.mlp.fc2.bias: copying from a non-meta parameter in the checkpoint to a meta parameter in the current model, which is a no-op. (Did you mean to pass assign=True to assign items in the state dictionary to their corresponding key in the module instead of copying them in place?)
(the same UserWarning from torch/nn/modules/module.py:2025 is repeated for vision_model.encoder.layers.23.layer_norm2.weight, vision_model.encoder.layers.23.layer_norm2.bias, vision_model.post_layernorm.weight, and vision_model.post_layernorm.bias)

LZHgrla (Collaborator) commented May 9, 2024

@Jason8Kang
I also encountered this warning, but I tested this model on mmbench dev en, and its accuracy was normal. So, I did not delve into resolving this warning.
