
fixing issue 46 #136

Closed
wants to merge 3 commits into from

Conversation

guyknvda
Collaborator

What does this PR do?

Apply sampling params to the logprobs of the response tokens (see issue #46)

The sampling params are applied by default. To stay consistent with the response generation process (done in text_generation_utils.py), the following parameters are taken into account (a rough sketch of the transformation follows the notes below):

  • temperature
  • top_p
  • top_k

Note that:

  1. If use_greedy is set to True (the default), generation does not change the logits, so the original logits are used to compute the log prob and the other sampling params (top_p, temperature, and top_k) are ignored.
  2. repetition_penalty is currently not taken into account, because it is also not applied during generation (a potential issue: it is only applied when compute_logprob is True).
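For reference, a minimal sketch of this transformation (not the PR's actual code; the helper name, the defaults, and the [batch, seq, vocab] logits shape are assumptions for illustration):

import torch
import torch.nn.functional as F

def apply_sampling_params_to_logits(logits, temperature=1.0, top_k=0, top_p=1.0):
    # Hypothetical helper: transform the logits the same way sampling does, so
    # that the log probs match the distribution the tokens were drawn from.
    if temperature != 1.0:
        logits = logits / temperature
    if top_k > 0:
        # Keep only the k largest logits per position.
        kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))
    if 0.0 < top_p < 1.0:
        # Nucleus filtering: keep the smallest set of tokens whose cumulative
        # probability exceeds top_p.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # shift so the top token is always kept
        remove[..., 0] = False
        remove = remove.scatter(-1, sorted_idx, remove)
        logits = logits.masked_fill(remove, float("-inf"))
    return logits

# Example: log probs of the generated tokens under the transformed distribution.
# logprobs = torch.log_softmax(
#     apply_sampling_params_to_logits(logits, temperature=0.7, top_k=50, top_p=0.9), dim=-1
# ).gather(-1, response_tokens.unsqueeze(-1)).squeeze(-1)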

Additional Information

Signed-off-by: gkoren <gkoren@nvidia.com>

@odelalleau left a comment


Sincere apologies for the late review!

I really appreciate you tackling this issue, which is not trivial. Unfortunately it's more complex than that, for two reasons:

  1. We need to handle tensor parallelism, which means that modifying the logits probably needs to be done within DistributedLogprob. It's likely going to be a bit tricky to implement, but it should be doable, at the expense of a few more steps to handle top_k / top_p. Note that it may be more efficient (and less memory intensive) to gather only the top_k logits from each rank (see the sketch after this list).
  2. We also need to modify the logits used in the loss here
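To illustrate the gathering idea from point 1, here is a rough sketch of collecting only the per-rank top_k candidates instead of all-gathering the full vocab dimension. It is not DistributedLogprob itself; the names (vocab_start_index, tp_group) are assumptions, and it assumes top_k is no larger than each rank's vocab shard.

import torch
import torch.distributed as dist

def gather_topk_logits_across_tp(local_logits, k, vocab_start_index, tp_group=None):
    # Each tensor-parallel rank holds a shard of the vocab dimension, so the
    # global top-k is contained in the union of the per-rank top-k candidates.
    local_vals, local_idx = torch.topk(local_logits, k, dim=-1)
    local_idx = local_idx + vocab_start_index  # shard-local -> global vocab ids

    world_size = dist.get_world_size(group=tp_group)
    gathered_vals = [torch.empty_like(local_vals) for _ in range(world_size)]
    gathered_idx = [torch.empty_like(local_idx) for _ in range(world_size)]
    dist.all_gather(gathered_vals, local_vals, group=tp_group)
    dist.all_gather(gathered_idx, local_idx, group=tp_group)

    all_vals = torch.cat(gathered_vals, dim=-1)  # [..., k * world_size] candidates
    all_idx = torch.cat(gathered_idx, dim=-1)
    top_vals, pos = torch.topk(all_vals, k, dim=-1)
    top_idx = torch.gather(all_idx, -1, pos)
    return top_vals, top_idx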

I also think we should add flags to control where exactly these transformations are applied. I'm actually not sure it's a good idea to apply them when computing the KL penalty term, because:

  • If we apply it to the reference policy, it may lead to infinite KL due to top_p / top_k (when we sample a token that has zero probability under the reference policy)
  • If we don't apply it to the reference policy, then we may start with a high KL penalty from the start, which could cause some issues.

I would thus suggest adding some fine-grained control over where we apply this transformation, with the following default values:

model:
  ppo:
    transform_logits_from_sampling_params:
      loss: True
      kl_penalty_actor: False
      kl_penalty_ref: ${.kl_penalty_actor}

This way we will be able to easily experiment with various configurations to see what actually works best in practice.
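As a rough sketch of how these flags might be wired (the config keys mirror the YAML above; the function, its arguments, and apply_sampling_params_to_logits from the PR description are hypothetical, not existing NeMo-Aligner code):

import torch

def logprobs_for_usage(usage, logits, tokens, cfg, sampling_params):
    # usage is one of "loss", "kl_penalty_actor", "kl_penalty_ref".
    flags = cfg["model"]["ppo"]["transform_logits_from_sampling_params"]
    if flags.get(usage, False):
        logits = apply_sampling_params_to_logits(logits, **sampling_params)
    logprobs = torch.log_softmax(logits, dim=-1)
    return logprobs.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)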

# apply the sampling params to the logits - focusing only on the generated tokens.
context_length = context_lengths.min().item()
resp_logits = logits[:, context_length - 1 :].contiguous()
if not samparams.get("use_greedy", False): # if use_greedy is True, use the logits as is


Minor suggestion: move this up two lines and write it

if samparams.get("use_greedy", False):
    return logits

which will avoid a couple of useless ops & extra indent
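Applied to the quoted hunk above, the refactor would look roughly like this (a sketch only, not verified against the full function):

if samparams.get("use_greedy", False):  # greedy generation leaves the logits untouched
    return logits
# apply the sampling params to the logits - focusing only on the generated tokens.
context_length = context_lengths.min().item()
resp_logits = logits[:, context_length - 1 :].contiguous()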


Also: I think we should add a check to skip the temperature scaling if temp == 1, and skip the top_p / top_k filtering if they are equal to 1.0 (or 0.0) / 1, respectively. This way we don't mess with the logits for no good reason.

@guyknvda
Collaborator Author

Replaced by new PR #186.

@guyknvda closed this May 28, 2024