ref_model not needed in Fine_tune_a_Mistral_7b_model_with_DPO.ipynb #44

Open

alvarobartt opened this issue Jan 31, 2024 · 4 comments

@alvarobartt

Hi here @mlabonne! Congratulations on your awesome work with this course!

After going through Fine_tune_a_Mistral_7b_model_with_DPO.ipynb I realised that there's no need to define the ref_model for DPO. When fine-tuning with LoRA, no separate reference model is required: the model with the adapters disabled is used to compute the reference logprobs. So you can remove the ref_model, the result will still be the same, and it will use even less memory.
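For reference, a minimal sketch of what that looks like, based on the DPOTrainer signature shown in the traceback further down. The model name, dataset id, LoRA values, and training arguments here are illustrative placeholders, not the notebook's exact settings:

# Hedged sketch: DPO fine-tuning with LoRA adapters and no ref_model.
# All names and hyperparameters below are placeholders.
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model_name = "mistralai/Mistral-7B-v0.1"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Placeholder dataset id: any preference dataset already formatted with
# prompt/chosen/rejected columns works here.
train_dataset = load_dataset("your-org/your-dpo-dataset", split="train")

peft_config = LoraConfig(r=16, lora_alpha=16, lora_dropout=0.05, task_type="CAUSAL_LM")

trainer = DPOTrainer(
    model,
    ref_model=None,  # with a peft_config, TRL scores reference logprobs by disabling the adapters
    args=TrainingArguments(
        output_dir="dpo-mistral",
        per_device_train_batch_size=1,
        remove_unused_columns=False,
    ),
    beta=0.1,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)
trainer.train()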

Finally, as a tip: when using the DPOTrainer for full fine-tunes, you can set precompute_ref_log_probs=True to compute the reference logprobs in advance, before the actual fine-tune starts, so that a ref_model is not needed there either.
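A hedged sketch of that variant as well, reusing model, tokenizer, and train_dataset from the sketch above; note there is no peft_config this time, since it's a full fine-tune:

# Full fine-tune sketch: reference logprobs are precomputed in one pass
# before training starts (while the policy still equals the reference),
# so no separate ref_model has to be kept in memory during training.
trainer = DPOTrainer(
    model,
    ref_model=None,                  # no second copy of the model
    precompute_ref_log_probs=True,   # compute reference logprobs up front
    args=TrainingArguments(
        output_dir="dpo-mistral-full",
        per_device_train_batch_size=1,
        remove_unused_columns=False,
    ),
    beta=0.1,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()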

@AzizCode92

Hey @alvarobartt, thanks a lot for the hints. I am using the above notebook, and your suggestion solved my memory issue on Google Colab.

@corticalstack

Yep, if you try to run DPOTrainer while passing the ref model, you get the runtime error below. To fix it, just pass ref_model=None in the DPOTrainer call (and clean up the declaration of ref_model).

Thanks @mlabonne for this super notebook, which got me started going beyond SFT with my first DPO tune.

File /usr/local/lib/python3.10/dist-packages/trl/trainer/dpo_trainer.py:217, in DPOTrainer.__init__(self, model, ref_model, beta, label_smoothing, loss_type, args, data_collator, label_pad_token_id, padding_value, truncation_mode, train_dataset, eval_dataset, tokenizer, model_init, callbacks, optimizers, preprocess_logits_for_metrics, max_length, max_prompt_length, max_target_length, peft_config, is_encoder_decoder, disable_dropout, generate_during_eval, compute_metrics, precompute_ref_log_probs, dataset_num_proc, model_init_kwargs, ref_model_init_kwargs, model_adapter_name, ref_adapter_name, reference_free)
    214     model = model.merge_and_unload()
    216 if ref_model is not None:
--> 217     raise ValueError(
    218         "You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference"
    219         " model. Please pass `ref_model=None` in case you want to train PEFT adapters."
    220     )
    222 if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
    223     _support_gc_kwargs = hasattr(
    224         args, "gradient_checkpointing_kwargs"
    225     ) and "gradient_checkpointing_kwargs" in list(
    226         inspect.signature(prepare_model_for_kbit_training).parameters
    227     )

ValueError: You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference model. Please pass `ref_model=None` in case you want to train PEFT adapters.

@RGaonkar

Thanks @alvarobartt for opening this issue! I faced the same problem, and following your suggestion solved it: I removed the declaration of ref_model as @corticalstack suggested, and I also removed the ref_model argument from the DPOTrainer.
Has anyone opened a PR for this fix? If not, I am happy to do so!

@mlabonne
Owner

I updated the notebook and removed the ref_model. Please let me know if it broke something; I couldn't test it.
