Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(whl): add PC+MCTS code #603

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open

Conversation

kxzxvbk
Copy link
Contributor

@kxzxvbk kxzxvbk commented Mar 5, 2023

Description

Related Issue

TODO

Check List

  • merge the latest version source branch/repo, and resolve all the conflicts
  • pass style check
  • pass all the tests

@PaParaZz1 PaParaZz1 added the algo Add new algorithm or improve old one label Mar 5, 2023
learner=dict(hook=dict(save_ckpt_after_iter=1000)),
train_epoch=20,
),
eval=dict(evaluator=dict(eval_freq=40, ))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

increase eval_freq

qbert_pc_mcts_config = dict(
exp_name='pong_pc_mcts_seed0',
env=dict(
manager=dict(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove unmodified default config

@@ -245,7 +245,10 @@ def eval(
if self._cfg.figure_path is not None:
self._env.enable_save_figure(env_id, self._cfg.figure_path)
self._policy.reset([env_id])
reward = t.info['eval_episode_return']
if 'final_eval_reward' in t.info.keys():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

final_eval_reward has been renamed to eval_episode_return

def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
pass

def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

polish methods related to collect

output = {'action': output}
output = default_decollate(output)
# TODO why this bug?
output = [{'action': o['action'].item()} for o in output]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why add this

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
algo Add new algorithm or improve old one
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants