Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Length Consistency in LLM Outputs with Token Length based Penalty Loss Functions #559

Closed
wants to merge 2 commits into from

Conversation

Nischaydnk
Copy link

Reopened the same PR #556 with correct local branch as suggested by @psinger.

Adding support for custom loss functions aimed at improving the length consistency in responses generated by fientuned LLMs. Idea is to make the output lengths of LLMs more reflective of the token lengths observed in the training data. I did several experiments using the loss functions, and noticed very low deviation in performance of models.

The loss functions implemented are:

LengthBasedTACE (Token Averaged Cross Entropy)
LengthBasedSACE (Sample Averaged Cross Entropy)

Sharing some of the experiments I did using these losses to make a comparison with original Cross Entropy Loss:

Evaluation Results:

There could be some randomness involved in eval metric, but I found consistent decrease in LLMs inference time,specially the ones which scores bad & prone to generate bad responses.

Model Loss Function Time Taken (min) Eval Metric
llama13B-Chat Token Avg CE Loss 40.45 0.810
llama13B-Chat TokenLengthPenalty Token Avg 38.62 0.802
llama7B-Chat Token Avg CE Loss 12.50 0.7684
llama7B-Chat TokenLengthPenalty Token Avg 12.12 0.7484
Yi-6B-Chat Token Avg CE Loss 18.50 0.792
Yi-6B-Chat TokenLengthPenalty Token Avg 15.44 0.785
llama13B-Chat Token Avg CE Loss 78.20 0.728
llama13B-Chat TokenLengthPenalty Token Avg 76.60 0.744
Yi-6B-Chat Token Avg CE Loss 24.44 0.712
Yi-6B-Chat TokenLengthPenalty Token Avg 24.20 0.704

These functions uses a length penalty coefficient, in my experiments I found 0.1 coefficient to be most stable one, therefore I kept it as default. This should help close #537

loss = self.loss_fn(shift_logits, shift_labels)

true_lengths = torch.sum(labels != 0, dim=-1).float()
pred_lengths = torch.sum(torch.argmax(logits, dim=-1) != 0, dim=-1).float()
Copy link
Collaborator

@psinger psinger Jan 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why exactly are we checking for 0 here? @Nischaydnk

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea was to ignore the PAD token Ids which we generally apply on answers text (max_length_answer). In models like Llama2, etc. Pad token id is 0, so I took it as default. But yeah, it could be possible that token ID may differ for different tokenizers.

Copy link
Collaborator

@psinger psinger Jan 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this will not be the case all the time. Actually, most of the time we set it to eos token. This needs to be dynamically set.

So you will need to check the length of the original sample either via pad or better attention mask. And for predicted length you will need to check eos token, or not?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes right, Thanks for pointing out. I will update the loss function based on the tokenizer.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most should be in the cfg already stored, can look at the get_tokenizer function.

Copy link
Collaborator

@psinger psinger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment, need to handle the length estimation better

@psinger
Copy link
Collaborator

psinger commented Feb 7, 2024

@Nischaydnk will you find time to continue working on this?

@psinger
Copy link
Collaborator

psinger commented May 21, 2024

closing this for now - please re-open in future

@psinger psinger closed this May 21, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[FEATURE] Pack sequences in batch
2 participants