
_is_peft_model update to recognise peft submodules, allowing training quantised models with peft submodules #30884

Open
wants to merge 5 commits into base: main


ambroser53

What does this PR do?

PEFT models are not necessarily the top-level wrapper of a model, especially in custom-built multimodal models, where LoRA adapters may be attached only to individual submodules. For example:

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForVision2Seq, BitsAndBytesConfig

# `args`, `bits`, `compute_dtype`, `int8_quant_skip_modules`, `resume_from_checkpoint`,
# `ft_checkpoint_dir` and `get_target_modules` are defined elsewhere in the training script.
model = AutoModelForVision2Seq.from_pretrained(
    args.pretrained_ckpt,
    torch_dtype=compute_dtype,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=bits == 4,
        load_in_8bit=bits == 8,
        llm_int8_threshold=6.0,
        llm_int8_skip_modules=int8_quant_skip_modules,  # modules left unquantised
        llm_int8_has_fp16_weight=False,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',  # {'fp4', 'nf4'}
    ) if bits < 16 else None,
    attn_implementation=args.attn_implementation,
)

# LoRA adapters are attached to individual submodules (text model, vision model,
# connector), so the top-level `model` is never itself a PeftModel.
if args.use_lora and not resume_from_checkpoint and not ft_checkpoint_dir:
    target_modules = get_target_modules(model.model.text_model, args, bits)
    peft_config = LoraConfig(
        target_modules=target_modules,
        inference_mode=args.inference_mode,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        use_dora=args.use_dora,
    )
    model.model.text_model = get_peft_model(model.model.text_model, peft_config)

    # Optionally add LoRA adapters to the vision encoder.
    if args.vit_train:
        target_modules = get_target_modules(model.model.vision_model, args, args.vit_bits, vit=True)
        peft_config = LoraConfig(
            target_modules=target_modules,
            inference_mode=args.inference_mode,
            r=args.vit_lora_r,
            lora_alpha=args.vit_lora_alpha,
            lora_dropout=args.lora_dropout,
            use_dora=args.use_dora_vit,
        )
        model.model.vision_model = get_peft_model(model.model.vision_model, peft_config)

    # Optionally add LoRA adapters to the vision-language connector (abstractor).
    if args.lora_abstractor:
        target_modules = get_target_modules(model.model.connector, args, args.bits)
        peft_config = LoraConfig(
            target_modules=target_modules,
            inference_mode=args.inference_mode,
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            use_dora=args.use_dora,
        )
        model.model.connector = get_peft_model(model.model.connector, peft_config)

This allows the Hugging Face `Trainer` to recognise such models as PEFT models even though the top-level module is not a `PeftModel`, and therefore to allow quantised training (QLoRA).
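A minimal sketch of what a submodule-aware check could look like, assuming the model is a `torch.nn.Module`, whose `named_modules()` yields the model itself under the empty name `""` followed by every nested submodule under a dotted name. The helper names `find_peft_submodules` and `is_peft_model` are hypothetical and this is not necessarily the code in this PR; `peft_classes` is left as a parameter (for example `(PeftModel,)`) rather than imported.

```python
def find_peft_submodules(model, peft_classes):
    """Return the qualified names of every module in `model` that is an instance of `peft_classes`.

    Assumes `model` is a torch.nn.Module: `named_modules()` yields the model itself
    (name "") and then every nested submodule with a dotted name.
    """
    return [name for name, module in model.named_modules() if isinstance(module, peft_classes)]


def is_peft_model(model, peft_classes) -> bool:
    """Submodule-aware variant of a plain top-level `isinstance(model, peft_classes)` check.

    True if `model` itself or any of its submodules is a PEFT wrapper, so a model whose
    adapters were attached by calling `get_peft_model` on submodules is still recognised.
    """
    return any(isinstance(module, peft_classes) for _, module in model.named_modules())
```

With the model above, `find_peft_submodules(model, (PeftModel,))` would be expected to include names such as `model.text_model`, while a plain `isinstance(model, PeftModel)` returns `False`.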

Fixes #30878

Before submitting

Who can review?

@younesbelkada

Contributor

@younesbelkada younesbelkada left a comment


Thanks for adding submodule support for PEFT + Trainer! Left one suggestion; what do you think?

Review thread on src/transformers/trainer.py (outdated, resolved)
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@amyeroberts amyeroberts left a comment


Thanks for adding! Could you add a test, using a dummy model with PEFT submodules, that passes with this change but fails on main?
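A minimal sketch of the kind of test being asked for, assuming `torch` is installed. To stay independent of `peft` and of quantisation hardware, it uses a hypothetical `StandInPeftModel` wrapper in place of `peft.PeftModel` and checks the detection logic in isolation rather than constructing a `Trainer`. `top_level_check` mirrors a plain top-level `isinstance` check and `submodule_check` a submodule-aware one; all class and function names here are illustrative.

```python
import unittest

import torch.nn as nn


class StandInPeftModel(nn.Module):
    """Hypothetical stand-in for peft.PeftModel: simply wraps a base module."""

    def __init__(self, base_model: nn.Module):
        super().__init__()
        self.base_model = base_model


class DummyMultiModalModel(nn.Module):
    """Hypothetical model whose text tower, not the model itself, is PEFT-wrapped."""

    def __init__(self):
        super().__init__()
        self.vision_model = nn.Linear(8, 8)
        self.text_model = StandInPeftModel(nn.Linear(8, 8))


def top_level_check(model, peft_classes):
    # Only the outermost module is inspected.
    return isinstance(model, peft_classes)


def submodule_check(model, peft_classes):
    # `modules()` yields the model itself and every nested submodule.
    return any(isinstance(module, peft_classes) for module in model.modules())


class PeftSubmoduleDetectionTest(unittest.TestCase):
    def test_peft_submodule_is_detected(self):
        model = DummyMultiModalModel()
        self.assertFalse(top_level_check(model, (StandInPeftModel,)))
        self.assertTrue(submodule_check(model, (StandInPeftModel,)))

    def test_plain_model_is_not_detected(self):
        self.assertFalse(submodule_check(nn.Linear(8, 8), (StandInPeftModel,)))
```

A test in the repository itself would more likely wrap a real submodule with `peft.get_peft_model` and pass a quantised model to `Trainer`, which additionally needs `bitsandbytes` and a GPU.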

Contributor

@younesbelkada younesbelkada left a comment


Thanks @ambroser53!
For the styling checks, can you try running `pip install -U ".[quality]"` and then re-running `make fixup`?

@ambroser53
Author

@younesbelkada I've just done as you asked; it ran successfully, but there are no working-tree changes to commit.

@younesbelkada
Contributor

Thanks! Hmm, I think something is currently off with our CI; let's wait for #30932 to be merged first.

Collaborator

@amyeroberts amyeroberts left a comment


Thanks for iterating on this.

Judging from the tests, I'm a bit confused about the intended behaviour here.

Review threads on tests/trainer/test_trainer.py (2, resolved)
Development

Successfully merging this pull request may close these issues.

Have _is_peft_model check if there's any peft submodule/Allow quantised training
4 participants