llava-llama3-8b llm+llm adapter merge error #662

ztfmars opened this issue May 8, 2024 · 2 comments
ztfmars commented May 8, 2024

I used llava_llama3_8b_instruct_qlora_clip_vit_large_p14_336_lora_e1_finetune.py to fine-tune on my dataset, aiming to get a llava-llama3-8b multimodal model trained on my own data.
After training and converting pth -> hf, I got the LLM adapter, the visual encoder adapter, and the projector.

However, I can't merge the LLM and the LLM adapter together, and I can't get the merged LLM weights as described in the tutorial:
https://github.com/InternLM/xtuner/tree/main/xtuner/configs/llava/llama3_8b_instruct_clip_vit_large_p14_336


The error is listed below:

/llava_train_20240506$ xtuner convert merge /home/fusionai/.cache/modelscope/hub/LLM-Research/Meta-Llama-3-8B-Instruct /home/fusionai/project/internllm_demo/llama3/llama3-ft/llava_train_20240506/iter_1000_hf/llm_adapter /home/fusionai/project/internllm_demo/llama3/llama3-ft/llava_train_20240506/iter_1000_llava
[2024-05-08 09:51:48,946] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
 [WARNING]  using untested triton version (2.1.0), only 1.0.0 is known to be compatible
[2024-05-08 09:51:53,816] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
 [WARNING]  using untested triton version (2.1.0), only 1.0.0 is known to be compatible
Loading checkpoint shards: 100%|███████████████████████████████████████████| 4/4 [00:05<00:00,  1.25s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Traceback (most recent call last):
  File "/home/fusionai/project/internllm/xtuner/xtuner/tools/model_converters/merge.py", line 73, in <module>
    main()
  File "/home/fusionai/project/internllm/xtuner/xtuner/tools/model_converters/merge.py", line 56, in main
    model_unmerged = PeftModel.from_pretrained(
  File "/home/fusionai/anaconda3/envs/llama3/lib/python3.10/site-packages/peft/peft_model.py", line 324, in from_pretrained
    config = PEFT_TYPE_TO_CONFIG_MAPPING[
  File "/home/fusionai/anaconda3/envs/llama3/lib/python3.10/site-packages/peft/config.py", line 151, in from_pretrained
    return cls.from_peft_type(**kwargs)
  File "/home/fusionai/anaconda3/envs/llama3/lib/python3.10/site-packages/peft/config.py", line 118, in from_peft_type
    return config_cls(**kwargs)
TypeError: LoraConfig.__init__() got an unexpected keyword argument 'layer_replication'
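
For reference, a quick check that narrows this down (a minimal sketch; it reuses the adapter path from the merge command above, and adapter_config.json is the config file peft saves alongside the adapter weights):

python -c "import peft; print(peft.__version__)"
grep layer_replication /home/fusionai/project/internllm_demo/llama3/llama3-ft/llava_train_20240506/iter_1000_hf/llm_adapter/adapter_config.json

If the grep finds the layer_replication key but the installed peft is too old to know that field, PeftModel.from_pretrained fails with exactly this TypeError.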

Additional description:

  • merge cmd
xtuner convert merge /home/fusionai/.cache/modelscope/hub/LLM-Research/Meta-Llama-3-8B-Instruct \
/home/fusionai/project/internllm_demo/llama3/llama3-ft/llava_train_20240506/iter_1000_hf/llm_adapter \
/home/fusionai/project/internllm_demo/llama3/llama3-ft/llava_train_20240506/iter_1000_llava
  • training configs
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, CLIPImageProcessor,
                          CLIPVisionModel)

from xtuner.dataset import LLaVADataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook
from xtuner.engine.runner import TrainLoop
from xtuner.model import LLaVAModel
from xtuner.utils import PROMPT_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
llm_name_or_path = '/home/fusionai/.cache/modelscope/hub/LLM-Research/Meta-Llama-3-8B-Instruct'
visual_encoder_name_or_path = '/home/fusionai/.cache/modelscope/hub/AI-ModelScope/clip-vit-large-patch14-336'
# Specify the pretrained pth
pretrained_pth = '/home/fusionai/project/internllm_demo/llama3/pretrained-model/llama3-llava-iter_2181.pth'  # noqa: E501

# Data
data_root = '/home/fusionai/project/datasets/llama3_test001/'
data_path = data_root + 'repeated_data.json'
image_folder = data_root
prompt_template = PROMPT_TEMPLATE.llama3_chat
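# Reserve room in the 2048-token context for the (336 / 14) ** 2 = 576 visual
# patch tokens produced by the 336px CLIP ViT-L/14 encoder.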
max_length = int(2048 - (336 / 14)**2)

# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 1
dataloader_num_workers = 0
max_epochs = 1
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Save
save_steps = 500
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = ''
evaluation_images = '/home/fusionai/project/datasets/llama3_test001/imgs/test0001.png'
evaluation_inputs = ['此图表示什么逻辑?', '图中都有哪些逻辑符号?']  # "What logic does this diagram represent?", "What logic symbols are in the diagram?"

#######################################################################
#            PART 2  Model & Tokenizer & Image Processor              #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=llm_name_or_path,
    trust_remote_code=True,
    padding_side='right')

image_processor = dict(
    type=CLIPImageProcessor.from_pretrained,
    pretrained_model_name_or_path=visual_encoder_name_or_path,
    trust_remote_code=True)

model = dict(
    type=LLaVAModel,
    freeze_llm=True,
    freeze_visual_encoder=True,
    pretrained_pth=pretrained_pth,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    llm_lora=dict(
        type=LoraConfig,
        r=512,
        lora_alpha=256,
        lora_dropout=0.05,
        bias='none',
        task_type='CAUSAL_LM'),
    visual_encoder=dict(
        type=CLIPVisionModel.from_pretrained,
        pretrained_model_name_or_path=visual_encoder_name_or_path),
    visual_encoder_lora=dict(
        type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05, bias='none'))

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
llava_dataset = dict(
    type=LLaVADataset,
    data_path=data_path,
    image_folder=image_folder,
    tokenizer=tokenizer,
    image_processor=image_processor,
    dataset_map_fn=llava_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=True)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=llava_dataset,
    sampler=dict(
        type=LengthGroupedSampler,
        length_property='modality_length',
        per_device_batch_size=batch_size * accumulative_counts),
    collate_fn=dict(type=default_collate_fn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        image_processor=image_processor,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        evaluation_images=evaluation_images,
        system=SYSTEM,
        prompt_template=prompt_template)
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)

How can I solve this problem? Waiting for help, thanks!

LZHgrla (Collaborator) commented May 8, 2024

@ztfmars Hi

This issue is caused by a version mismatch between transformers and peft.

This PR https://github.com/huggingface/peft/pull/1368/files adds layer_replication support to LoraConfig, so we recommend updating your peft to v0.10.0 and re-running your merge script.
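
Concretely, that amounts to something like the following (a sketch reusing the same model and adapter paths as the merge command above):

pip install -U peft==0.10.0
xtuner convert merge \
    /home/fusionai/.cache/modelscope/hub/LLM-Research/Meta-Llama-3-8B-Instruct \
    /home/fusionai/project/internllm_demo/llama3/llama3-ft/llava_train_20240506/iter_1000_hf/llm_adapter \
    /home/fusionai/project/internllm_demo/llama3/llama3-ft/llava_train_20240506/iter_1000_llava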

ztfmars (Author) commented May 10, 2024


Yes, it works! But there are obvious version conflicts between xtuner and lmdeploy over peft, so I will set up a separate virtual environment for lmdeploy and continue from there.

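A minimal sketch of that separate-environment approach (the environment name and Python version are just illustrative choices, not something prescribed by xtuner or lmdeploy):

conda create -n lmdeploy python=3.10 -y
conda activate lmdeploy
pip install lmdeploy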

Thanks very much!
