Add Conservative DPO, IPO, and KTO #78
base: main
Conversation
assert self.preference_loss in ["dpo", "ipo", "kto"]
if self.preference_loss == "dpo":
    loss = (
        -torch.nn.functional.logsigmoid(self.ref_policy_kl_penalty * logits) * (1.0 - self.label_smoothing)
nit: do you have to compute self.ref_policy_kl_penalty * logits twice in this loss function?
We don't need to compute that twice. I changed the code.
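For reference, a minimal standalone sketch of the single-computation version; beta corresponds to self.ref_policy_kl_penalty in the snippet, and the second, label-smoothed term is the usual conservative-DPO counterpart (assumed here, since the excerpt above cuts off after the first term):

import torch
import torch.nn.functional as F

def cdpo_loss(logits: torch.Tensor, beta: float, label_smoothing: float = 0.0) -> torch.Tensor:
    # Conservative DPO; reduces to vanilla DPO when label_smoothing == 0.
    scaled_logits = beta * logits  # computed once and reused for both terms
    return (
        -F.logsigmoid(scaled_logits) * (1.0 - label_smoothing)
        - F.logsigmoid(-scaled_logits) * label_smoothing
    )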
elif self.preference_loss == "kto":
    rewards_kl = self.get_reduced_masked_logps(pi_logprobs - ref_logprobs, labels, average_log_probs=True)
    chosen_kl, reject_kl = self.split_output_tensor(rewards_kl)
    loss = torch.cat(
another nit: I'm not sure what the performance impact is, but this creates a tensor, fills it with these two terms, and then takes a mean. Why not call .sum() on each of the two tensors and divide ourselves?
I think it also makes it more readable that way.
Changed it!
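For reference, a minimal sketch of the suggested restructuring; chosen_losses and reject_losses are illustrative names for the two per-example tensors that were previously concatenated:

import torch

def combine_pair_losses(chosen_losses: torch.Tensor, reject_losses: torch.Tensor) -> torch.Tensor:
    # Equivalent to torch.cat((chosen_losses, reject_losses)).mean(), but avoids
    # materializing the intermediate concatenated tensor.
    total = chosen_losses.sum() + reject_losses.sum()
    count = chosen_losses.numel() + reject_losses.numel()
    return total / count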
loss = -torch.nn.functional.logsigmoid(chosen_rewards - reject_rewards)
chosen_rewards, reject_rewards = self.split_output_tensor(rewards)
logits = chosen_rewards - reject_rewards
assert self.preference_loss in ["dpo", "ipo", "kto"]
can you put the implementations of these 3 loss functions in separate functions, and then put those functions in a dictionary?
something like:
PREFERENCE_LOSS_FUNCTIONS = {
"dpo": dpo_loss_function,
...
}
and then get it in the model __init__?
I put this dict outside the class. Alternatively, we can also make it a part of the MegatronGPTDPOModel class.
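To illustrate, a hedged sketch of the dispatch pattern; the helper names and signatures are hypothetical and may differ from what the PR actually uses:

import torch
import torch.nn.functional as F

def dpo_loss(logits, beta, label_smoothing=0.0):
    # Covers conservative DPO; label_smoothing == 0 gives vanilla DPO.
    return (
        -F.logsigmoid(beta * logits) * (1.0 - label_smoothing)
        - F.logsigmoid(-beta * logits) * label_smoothing
    )

def ipo_loss(logits, beta, **_):
    # IPO regresses the log-ratio margin towards 1 / (2 * beta).
    return torch.square(logits - 1.0 / (2.0 * beta))

PREFERENCE_LOSS_FUNCTIONS = {
    "dpo": dpo_loss,
    "ipo": ipo_loss,
    # "kto": kto_pair_loss,  # needs the extra per-side KL terms, so a wider signature
}

# Looked up once in the model __init__ (hypothetical attribute name):
# self.preference_loss_fn = PREFERENCE_LOSS_FUNCTIONS[self.preference_loss]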
with this addition, we should also rename dpo -> something else to make it clear we have the option to do other things. currently I'm thinking of calling this gpt_preference_optimization. @odelalleau do you have any thoughts on a better name?
I'm not sure we should rename it. I feel like it may cause more trouble and confusion than keeping the DPO name and documenting it as "DPO and related algorithms / variants" => I wouldn't rush a renaming right now given the limited scope of the changes in this PR (and this is not a criticism: it's great that it's limited!)

Btw apologies but my full review may come late => feel free to merge without me. I want to spend some time actually looking at these papers, but this won't be possible until next week. At first glance the code structure looks good to me!

EDIT: will also need a CHANGELOG entry
thanks for including kto in this @ertkonuk! this kto implementation is the kto-paired version in huggingface that assumes access to paired preference data. the more powerful (and standard) version of kto can work with purely binary data (+1/-1, good/bad), supports extreme data imbalances (e.g., 5% positive examples and 95% negative examples), and has some minor changes to make training more stable.
i think NeMo users would find it more useful to have the latter version of KTO, since it would allow them to align with a much more abundant kind of feedback.
Hi @kawine, thanks for your feedback and for providing the reference implementations. I agree that supporting the unpaired version would be more advantageous, and our plan is to eventually have that in NeMo Aligner. I'll begin making the necessary changes to implement the standard version of KTO very soon.
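For context, a rough standalone sketch of the unpaired objective being discussed, paraphrased from the KTO paper; the reference-point (KL) estimate and the weighting scheme below are simplified assumptions, not this PR's code:

import torch

def kto_unpaired_loss(policy_logps, ref_logps, is_desirable, beta=0.1,
                      desirable_weight=1.0, undesirable_weight=1.0):
    # Works on independent good/bad examples instead of chosen/rejected pairs;
    # the two weights let heavily imbalanced datasets be rebalanced.
    log_ratio = policy_logps - ref_logps  # per-example log pi/pi_ref
    # Shared reference point; the paper estimates this KL term from mismatched
    # completions and does not backpropagate through it -- simplified here.
    kl = log_ratio.mean().clamp(min=0).detach()
    desirable = desirable_weight * (1.0 - torch.sigmoid(beta * (log_ratio - kl)))
    undesirable = undesirable_weight * (1.0 - torch.sigmoid(beta * (kl - log_ratio)))
    return torch.where(is_desirable, desirable, undesirable).mean()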
It looks like there are large discrepancies between fp32 and bf16 runs at the moment. Looking into it.
What does this PR do?
Adds Conservative DPO (CDPO), IPO, and KTO methods
Changelog
Usage
# Add a code snippet demonstrating how to use this
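A hypothetical usage sketch only; the config keys below are guesses based on this PR's diff (a preference_loss selector and label_smoothing on the DPO model config), not verified documentation:

from omegaconf import OmegaConf

# Overlay onto the existing DPO training config (key names are assumptions).
overrides = OmegaConf.create(
    {
        "dpo": {
            "preference_loss": "ipo",  # one of "dpo", "ipo", "kto"
            "label_smoothing": 0.1,    # > 0 turns "dpo" into conservative DPO
        }
    }
)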
Before your PR is "Ready for review"
Pre checks:
Checklist when contributing a new algorithm
max_steps=-1 and validation?
Additional Information