Length Consistency in LLM Outputs with Token Length based Penalty Loss Functions #559
Conversation
loss = self.loss_fn(shift_logits, shift_labels)

# count non-pad tokens per sample; assumes the pad token id is 0
true_lengths = torch.sum(labels != 0, dim=-1).float()
pred_lengths = torch.sum(torch.argmax(logits, dim=-1) != 0, dim=-1).float()
Why exactly are we checking for 0 here? @Nischaydnk
The idea was to ignore the PAD token IDs, which we generally apply to the answer text (max_length_answer). In models like Llama2, the pad token id is 0, so I took it as the default. But yeah, the token ID may differ across tokenizers.
Yeah, this will not be the case all the time. Actually, most of the time we set it to the eos token. This needs to be set dynamically.
So you will need to check the length of the original sample either via the pad token or, better, the attention mask. And for the predicted length you will need to check the eos token, or not?
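A minimal sketch of what that dynamic length estimation could look like, assuming the attention mask and the tokenizer's eos token id are passed in (the function name and signature are illustrative, not the repository's API):

```python
import torch

def estimate_lengths(logits, attention_mask, eos_token_id):
    # true length: number of real (non-padded) tokens per sample,
    # taken from the attention mask instead of a hard-coded pad id
    true_lengths = attention_mask.sum(dim=-1).float()

    # predicted length: number of tokens before the first predicted eos;
    # samples that never predict eos count the full sequence length
    pred_ids = torch.argmax(logits, dim=-1)
    is_eos = (pred_ids == eos_token_id).long()
    pred_lengths = (is_eos.cumsum(dim=-1) == 0).sum(dim=-1).float()

    return true_lengths, pred_lengths
```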
Yes, right. Thanks for pointing that out. I will update the loss function based on the tokenizer.
Most of this should already be stored in the cfg; you can look at the get_tokenizer function.
See comment, need to handle the length estimation better
@Nischaydnk will you find time to continue working on this?
closing this for now - please re-open in future
Reopened the same PR (#556) with the correct local branch, as suggested by @psinger.
Adding support for custom loss functions aimed at improving the length consistency of responses generated by fine-tuned LLMs. The idea is to make the output lengths of LLMs more reflective of the token lengths observed in the training data. I ran several experiments using these loss functions and noticed very low deviation in model performance.
The loss functions implemented are:
LengthBasedTACE (Token Averaged Cross Entropy)
LengthBasedSACE (Sample Averaged Cross Entropy)
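For clarity, a rough sketch of the difference between the two averaging schemes, assuming the standard definitions of token- versus sample-averaged cross entropy (not the exact code in this PR):

```python
import torch
import torch.nn as nn

def token_averaged_ce(logits, labels, ignore_index=-100):
    # average the per-token loss over all non-ignored tokens in the whole batch
    return nn.CrossEntropyLoss(ignore_index=ignore_index)(
        logits.view(-1, logits.size(-1)), labels.view(-1)
    )

def sample_averaged_ce(logits, labels, ignore_index=-100):
    # average per sample first, then average the per-sample losses,
    # so short and long answers contribute equally
    per_token = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction="none")(
        logits.view(-1, logits.size(-1)), labels.view(-1)
    ).view(labels.shape)
    mask = (labels != ignore_index).float()
    per_sample = (per_token * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1.0)
    return per_sample.mean()
```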
Sharing some of the experiments I ran using these losses to compare against the original Cross Entropy loss:
Evaluation Results:
There could be some randomness involved in the eval metric, but I found a consistent decrease in LLM inference time, especially for models that score poorly and are prone to generating bad responses.
These functions use a length penalty coefficient; in my experiments I found a coefficient of 0.1 to be the most stable, so I kept it as the default. This should help close #537.
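For reference, a minimal sketch of how such a length-penalized, token-averaged cross entropy could be put together. The class name, the use of the attention mask, and the relative-deviation penalty are assumptions for illustration and may differ from the actual implementation in this PR:

```python
import torch
import torch.nn as nn

class LengthPenalizedTACE(nn.Module):
    # illustrative sketch: token-averaged cross entropy plus a length-mismatch penalty

    def __init__(self, eos_token_id, length_penalty_coefficient=0.1):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()
        self.eos_token_id = eos_token_id
        # 0.1 was reported above as the most stable coefficient
        self.alpha = length_penalty_coefficient

    def forward(self, logits, labels, attention_mask):
        # standard token-averaged cross entropy
        ce = self.loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))

        # true length from the attention mask, predicted length from the first eos
        true_lengths = attention_mask.sum(dim=-1).float()
        pred_ids = torch.argmax(logits, dim=-1)
        is_eos = (pred_ids == self.eos_token_id).long()
        pred_lengths = (is_eos.cumsum(dim=-1) == 0).sum(dim=-1).float()

        # penalize the relative deviation between predicted and true lengths
        length_penalty = torch.abs(pred_lengths - true_lengths) / true_lengths.clamp(min=1.0)
        return ce + self.alpha * length_penalty.mean()
```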