Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(zc): add MetaDiffuser and prompt-dt #771

Open
wants to merge 17 commits into
base: main
Choose a base branch
from

Conversation

Super1ce
Copy link
Contributor

Add MetaDIffusion and prompt-dt algorithm

@PaParaZz1 PaParaZz1 added the algo Add new algorithm or improve old one label Jan 31, 2024
) -> 'Policy': # noqa
"""
Overview:
Serial pipeline entry.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add more details?

# use the original batch size per gpu and increase learning rate
# correspondingly.
cfg.policy.learn.batch_size // get_world_size(),
# cfg.policy.learn.batch_size
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this line.

for epoch in range(cfg.policy.learn.train_epoch):
if get_world_size() > 1:
dataloader.sampler.set_epoch(epoch)
for i in range(cfg.policy.train_num):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"train_num"->"batch_size"?

(prompt_returns_embeddings, prompt_state_embeddings, prompt_action_embeddings), dim=1
).permute(0, 2, 1, 3).reshape(prompt_states.shape[0], 3 * prompt_seq_length, self.h_dim)

# prompt_stacked_attention_mask = torch.stack(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove these unused lines?

self.returns_condition = returns_condition
self.condition_guidance_w = condition_guidance_w

# def get_loss_weights(self, discount: int):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove these unused lines?

@@ -69,6 +80,52 @@ def n_step_guided_p_sample(

return model_mean + model_std * noise, y

def free_guidance_sample(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add class hints for all arguments, add Overview for functions and classes.


self.embed = nn.Sequential(
nn.Linear((obs_dim * 2 + action_dim + 1) * encoder_horizon, dim * 4),
Mish(),#nn.Mish(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unused code.

self._learn_model = model_wrap(self._model, wrapper_name='base')
self._learn_model.reset()

def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

data should be collated into batchsize before entering policy._forward_learn.
data type shoule be Dict[str, torch.Tensor].

if self.have_train:
if self.task_id is None:
self.task_id = [0] * self.eval_batch_size
# if data_id is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unused lines.

if self._cuda:
data = to_device(data, self._device)

p_s, p_a, p_rtg, p_t, p_mask, timesteps, states, actions, rewards, returns_to_go, \
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

data should be collated into batchsize before entering policy._forward_learn.
data type shoule be Dict[str, torch.Tensor], so that it can be assigned confirmly.

self.returns_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
#nn.Mish(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unused code line.


@DATASET_REGISTRY.register('meta_traj')
class MetaTraj(Dataset):
def __init__(self, cfg):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add notation for this class and config items.

Interaction serial evaluator class, policy interacts with env. This class evaluator algorithm
with test environment list.
Interfaces:
__init__, reset, reset_policy, reset_env, close, should_eval, eval
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

init -> __init__

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
algo Add new algorithm or improve old one
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants