
One of the inputs missing for DPO loss #20

Open
chrisliu298 opened this issue Mar 28, 2024 · 4 comments

Comments

@chrisliu298 commented Mar 28, 2024

When using dpo as the forget loss, I encountered the following error:

Traceback (most recent call last):
  File "/root/tofu/forget.py", line 187, in main
    trainer.train()
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1780, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2118, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 3036, in training_step
    loss = self.compute_loss(model, inputs)
  File "/root/tofu/dataloader.py", line 149, in compute_loss
    idk_inputs, forget_inputs, retain_inputs = inputs
ValueError: not enough values to unpack (expected 3, got 2)

After some inspection, I noticed that the inputs variable contains only 2 elements instead of 3, which causes the error. I made no code modifications and double-checked that TextForgetDatasetDPOQA is used and that its __getitem__ returns three items. However, I wasn't able to trace the source of the error.

@molereddy

Also observed the same issue! Wasn't able to root-cause it.

@zhilif (Collaborator) commented Mar 29, 2024

The experiments we ran in the paper use the argument forget_loss=idk rather than forget_loss=dpo. The dpo path hasn't been touched in a long time, so I'll need to spend some time on it. For now, if you want to reproduce what we did in the paper, could you use forget_loss=idk?

@chrisliu298 (Author)

I see. Yes, I'll use the idk loss for now. Thanks for the clarification.

@molereddy

The source of the error is that DPO requires a different data collator. You can implement DPO either by refactoring custom_data_collator_forget(samples) so that its output has three components, or by writing a separate data collator for the DPO case.
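
For anyone who wants to try this before an official fix, here is a minimal sketch of what such a collator could look like. The function name custom_data_collator_forget_dpo and the assumed sample layout (each item from TextForgetDatasetDPOQA.__getitem__ being an (idk, forget, retain) triple, where each view is an (input_ids, labels, attention_mask) tuple of tensors) are assumptions based on this thread, not the repo's exact code:

import torch

def custom_data_collator_forget_dpo(samples):
    # Hypothetical collator for forget_loss=dpo. Each sample is assumed to be
    # an (idk, forget, retain) triple, where each view is an
    # (input_ids, labels, attention_mask) tuple of tensors.
    rets = []
    for view_index in range(3):  # 0: idk, 1: forget, 2: retain
        view = [sample[view_index] for sample in samples]
        input_ids = torch.stack([item[0] for item in view])
        labels = torch.stack([item[1] for item in view])
        attention_mask = torch.stack([item[2] for item in view])
        rets.append((input_ids, labels, attention_mask))
    # compute_loss can then unpack: idk_inputs, forget_inputs, retain_inputs = inputs
    return rets

The trainer would presumably also need to be constructed with this collator instead of the default one whenever forget_loss=dpo is selected.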
