
Add Conservative DPO, IPO, and KTO #78

Open · wants to merge 36 commits into base: main
Conversation

ertkonuk (Collaborator)

What does this PR do?

Adds Conservative DPO (CDPO), IPO, and KTO methods

Changelog

  • Please update CHANGELOG.md under the next version with the high-level changes in this PR.

Usage

  • For CDPO, simply set model.dpo.label_smoothing to a positive non-zero value. For IPO and KTO, set model.dpo.preference_loss to "ipo" or "kto", respectively; a sketch follows below.
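A minimal sketch, assuming an OmegaConf/Hydra-style config as used elsewhere in NeMo; only the model.dpo.preference_loss and model.dpo.label_smoothing keys come from this PR, everything else is illustrative:

    from omegaconf import OmegaConf

    # Illustrative config fragment; only the model.dpo.* keys are defined by this PR.
    cfg = OmegaConf.create(
        {"model": {"dpo": {"preference_loss": "dpo", "label_smoothing": 0.0}}}
    )

    # Conservative DPO: keep the "dpo" loss and set a positive label_smoothing.
    cfg.model.dpo.label_smoothing = 0.1

    # IPO or KTO: switch the preference loss instead.
    cfg.model.dpo.preference_loss = "ipo"  # or "kto"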

Before your PR is "Ready for review"

Pre checks:

Checklist when contributing a new algorithm

  • Does the trainer resume and restore all model states?
  • Does the trainer support all parallelism techniques (PP, TP, DP)?
  • Does the trainer support max_steps=-1 and validation?
  • Does the trainer only call APIs defined in alignable_interface.py?
  • Does the trainer have proper logging?

Additional Information

  • Related to # (issue)

@ertkonuk changed the base branch from dev to main on January 11, 2024 at 00:33
assert self.preference_loss in ["dpo", "ipo", "kto"]
if self.preference_loss == "dpo":
    loss = (
        -torch.nn.functional.logsigmoid(self.ref_policy_kl_penalty * logits) * (1.0 - self.label_smoothing)
Collaborator

nit: do you have to compute self.ref_policy_kl_penalty * logits twice in this loss function?

Collaborator Author

We don't need to compute that twice. I changed the code.
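For reference, a minimal standalone sketch of the refactored conservative-DPO term, computing the scaled logits once and reusing them in both terms (the second, label-smoothing term follows the standard conservative-DPO formulation; the truncated snippet above only shows the first term):

    import torch.nn.functional as F

    def cdpo_loss(logits, ref_policy_kl_penalty, label_smoothing):
        # Scale the reward-difference logits by the KL penalty (beta) once.
        scaled_logits = ref_policy_kl_penalty * logits
        # Label-smoothed DPO: mix the loss on the given preference with the
        # loss on the flipped preference, weighted by label_smoothing.
        return (
            -F.logsigmoid(scaled_logits) * (1.0 - label_smoothing)
            - F.logsigmoid(-scaled_logits) * label_smoothing
        )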

elif self.preference_loss == "kto":
    rewards_kl = self.get_reduced_masked_logps(pi_logprobs - ref_logprobs, labels, average_log_probs=True)
    chosen_kl, reject_kl = self.split_output_tensor(rewards_kl)
    loss = torch.cat(
Collaborator

another nit: I'm not sure what the performance impact is, but this creates a tensor, fills it with these two terms, and then takes a mean. Why not call .sum() on each of these tensors and then divide ourselves?

I think it also makes it more readable that way.

Collaborator Author

Changed it!
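A minimal sketch of the suggested change, assuming the two per-example loss terms are tensors of per-pair losses (names here are illustrative, not the exact ones in the PR):

    import torch

    def mean_preference_loss(chosen_term: torch.Tensor, reject_term: torch.Tensor) -> torch.Tensor:
        # Equivalent to torch.cat((chosen_term, reject_term)).mean(), but without
        # allocating the concatenated tensor: sum each term and divide by the count.
        total = chosen_term.sum() + reject_term.sum()
        return total / (chosen_term.numel() + reject_term.numel())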

  chosen_rewards, reject_rewards = self.split_output_tensor(rewards)
- loss = -torch.nn.functional.logsigmoid(chosen_rewards - reject_rewards)
+ logits = chosen_rewards - reject_rewards
+ assert self.preference_loss in ["dpo", "ipo", "kto"]
Collaborator

can you put the implementation of these 3 loss functions into separate functions, and then put those functions in a dictionary?

something like:

PREFERENCE_LOSS_FUNCTIONS = {
    "dpo": dpo_loss_function,
    ...
}

and then get it in the model __init__?

Collaborator Author

I put this dict outside the class. Alternatively, we can also make it a part of the MegatronGPTDPOModel class.
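A minimal sketch of that registry pattern (function names and signatures are illustrative, not the exact ones in the PR; the KTO entry is omitted because its paired loss also needs the per-pair KL terms shown earlier):

    import torch
    import torch.nn.functional as F

    def dpo_loss(logits, kl_penalty, label_smoothing=0.0):
        scaled = kl_penalty * logits
        return -F.logsigmoid(scaled) * (1.0 - label_smoothing) - F.logsigmoid(-scaled) * label_smoothing

    def ipo_loss(logits, kl_penalty, label_smoothing=0.0):
        # IPO regresses the log-ratio difference toward 1 / (2 * beta).
        return torch.square(logits - 1.0 / (2.0 * kl_penalty))

    # Module-level registry, resolved once in the model __init__:
    PREFERENCE_LOSS_FUNCTIONS = {
        "dpo": dpo_loss,
        "ipo": ipo_loss,
    }

In __init__ the model would then do something like self.preference_loss_fn = PREFERENCE_LOSS_FUNCTIONS[self.preference_loss].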

@gshennvm (Collaborator)

with this addition, we should also rename dpo -> something else to make it clear we have the option to do other things.

Currently I'm thinking of calling this gpt_preference_optimization. @odelalleau, do you have any thoughts on a better name?

@gshennvm (Collaborator)

@ertkonuk please also make sure that we reflect this change in our DPO tutorial, and I will ping you about adding these loss functions to our CI once we finalize the review.

@odelalleau (Collaborator) commented Jan 12, 2024

> with this addition, we should also rename dpo -> something else to make it clear we have the option to do other things.
>
> Currently I'm thinking of calling this gpt_preference_optimization. @odelalleau, do you have any thoughts on a better name?

I'm not sure we should rename it. I feel like it may cause more trouble and confusion than keeping the DPO name and documenting it as "DPO and related algorithms / variants" => I wouldn't rush a renaming right now given the limited scope of the changes in this PR (and this is not a criticism: it's great that it's limited!)

Btw apologies but my full review may come late => feel free to merge without me. I want to spend some time actually looking at these papers, but this won't be possible until next week.

At first glance the code structure looks good to me!

EDIT: will also need a CHANGELOG entry

@kawine commented Feb 16, 2024

thanks for including kto in this @ertkonuk !

this kto implementation is the kto-paired version in huggingface that assumes access to paired preference data. the more powerful (and standard) version of kto can work with purely binary data (+1/-1, good/bad), supports extreme data imbalances (e.g., 5% positive examples and 95% negative examples), and has some minor changes to make training more stable

  • huggingface is close to merging the PR for the unpaired version (see here)
  • open-rlhf has already implemented this version (as has our own repo)

i think NeMo users would find it more useful to have the latter version of KTO, since it would allow them to align with a much more abundant kind of feedback

@ertkonuk (Collaborator Author)

Hi @kawine,

Thanks for your feedback and for providing the reference implementations. I agree that supporting the unpaired version would be more advantageous, and our plan is to eventually have that in NeMo Aligner. I'll begin making the necessary changes to implement the standard version of KTO very soon.
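For context, a minimal sketch of the unpaired objective described in the KTO paper (this is not code from this PR; the batch-level KL reference point is simplified here to a detached batch mean, whereas the reference implementations estimate it from mismatched prompt-completion pairs):

    import torch

    def kto_unpaired_loss(policy_logps, ref_logps, is_desirable,
                          beta=0.1, lambda_d=1.0, lambda_u=1.0):
        # policy_logps / ref_logps: per-example completion log-probs under the
        # policy and reference model; is_desirable: boolean tensor (+1 vs -1 label).
        log_ratios = policy_logps - ref_logps
        # Reference point z0: a detached, clamped estimate of the policy/reference KL.
        z0 = log_ratios.mean().clamp(min=0).detach()
        desirable_value = torch.sigmoid(beta * (log_ratios - z0))
        undesirable_value = torch.sigmoid(beta * (z0 - log_ratios))
        losses = torch.where(
            is_desirable,
            lambda_d * (1.0 - desirable_value),
            lambda_u * (1.0 - undesirable_value),
        )
        return losses.mean()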

@SahilJain314 (Collaborator)

It looks like there are large discrepancies between fp32 and bf16 runs at the moment. Looking into it.
