Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(whl): add rlhf pipeline. #748

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open

Conversation

kxzxvbk
Copy link
Contributor

@kxzxvbk kxzxvbk commented Nov 6, 2023

Description

Related Issue

TODO

Check List

  • merge the latest version source branch/repo, and resolve all the conflicts
  • pass style check
  • pass all the tests

@PaParaZz1 PaParaZz1 added enhancement New feature or request algo Add new algorithm or improve old one labels Nov 6, 2023
@@ -18,6 +19,7 @@
from .model import PPOFModel
from .config import get_instance_config, get_instance_env, get_hybrid_shape
from ding.bonus.common import TrainingReturn, EvalReturn
from ..framework.middleware.collector import ChatCollector
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

merge it into ding.framework

"""
Overview:
The class of the collector running by steps, including model inference and transition \
process. Use the `__call__` method to execute the whole collection process.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why indent here


def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1):
"""
Filter a distribution of logits using nucleus (top-p) filtering
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

polish comments add add unittest

if topp > 0:
logits_sorted, inds = torch.sort(logits, dim=-1, descending=True)
mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp
mask[:, :min_topk] = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

..., :min_topk

@@ -1,4 +1,7 @@
from typing import Union, Dict, Optional

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move these modifications to a new single file: lm_vac.py


def __init__(self, config, opt, tokenizer):
super().__init__(config)
self.opt = opt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why define opt here

else:
logits = self.reward_head(output.last_hidden_state).squeeze(-1)

return (logits, )
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why return a tuple here

self._init_flag = False

def reset(self):
self.last_batch = next(self.generator)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to restrat generatore here?


class LlamaRewardModel(LlamaForCausalLM):

def __init__(self, config, opt, tokenizer):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we move the creation of tokenizer insides the constructor of RM?

@@ -0,0 +1,50 @@
from easydict import EasyDict
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move it to dizoo/chat/entry

Copy link

codecov bot commented Jan 3, 2024

Codecov Report

Attention: 252 lines in your changes are missing coverage. Please review.

Comparison is base (d7a61c2) 76.78% compared to head (f3a8245) 76.83%.

Files Patch % Lines
ding/model/template/lm_vac.py 20.00% 92 Missing ⚠️
ding/policy/ppof.py 5.74% 82 Missing ⚠️
ding/framework/middleware/collector.py 15.62% 27 Missing ⚠️
ding/rl_utils/gae.py 11.11% 16 Missing ⚠️
ding/reward_model/language_reward_model.py 31.57% 13 Missing ⚠️
ding/bonus/ppof.py 0.00% 12 Missing ⚠️
ding/bonus/config.py 0.00% 10 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #748      +/-   ##
==========================================
+ Coverage   76.78%   76.83%   +0.04%     
==========================================
  Files         671      674       +3     
  Lines       53196    53935     +739     
==========================================
+ Hits        40847    41440     +593     
- Misses      12349    12495     +146     
Flag Coverage Δ
unittests 76.83% <20.50%> (+0.04%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
algo Add new algorithm or improve old one enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants