Add VQ-BeT #166
Conversation
Merge from main
@jayLEE0301 thanks so much for being the first to PR a model to LeRobot! The paper for VQ-BeT was a really nice read.
So, for the review. I've left a bunch of comments (many of them nits, but some blockers), and actually decided to stop reviewing partway through. That's because I noticed there are some high-level points I can share here. So instead, please consider these high level comments as my primary review, and my inline comments as examples to support.
So, our goal is to make this code highly accessible to the community, meaning it's easy to read and understand, and is easily hackable. A side effect of aiming for these goals is usually that the code is maintainable.
With that overarching goal in mind, here are 3 high-level points:

1. Consider the `VQBeTPolicy` class as the only "public" object in the modeling file. Everything else is there for the sole purpose of `VQBeTPolicy`. This means:
   - Go minimalist. We should drop any kwargs, conditional branching, or other logic that is unused. The other functions and logic should only be as dynamic as needed to serve `VQBeTPolicy`. Rule of thumb: if it can't be activated via the configuration parameters, it can go.
   - Use the `config` instead of many kwargs. Most of the other modules can take a `config` argument and make a `self.config` (this avoids listing parameters twice, and makes one source of truth for what the params mean; no need to repeat documentation or type-hinting).
2. Consolidate code: we want to avoid too much nesting or duplication of code. Consider for example my inline comment about the MLPs. I think it's reasonable to use one class for MLPs (and it can be simpler and shorter than the 3 existing classes now). This is just an example though; there may be more opportunities for consolidation.
3. Documentation and naming: we want to make sure that everything is well understood by a first-time reader. Wear the hat of someone who has read through your paper once and enters the code via the `VQBeTPolicy` class. They should be able to traverse the submodule hierarchy, understanding what everything is as they go. And they should be able to make sense of what's happening in the forward function.
   - Above all, please make sure the `VQBeTConfig` documentation is solid.
   - Please add docstrings to classes and methods when it wouldn't be obvious what they are in relation to the main policy and paper.
   - Please separate long methods into logical blocks with comments so that one doesn't get lost along the way. (Btw: this doesn't mean separating them into smaller functions.)
   - Please make sure it's easy to follow what's happening with tensor dimensions; `einops` is also helpful for that.
   - Favor full words over abbreviations (`embd` -> `embed`) and try to match the terminology/naming in your paper.
When in doubt, please take inspiration from LeRobot's ACT and TD-MPC (Diffusion Policy is good too but may need a little more work).
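As an illustration of the "use the `config` instead of many kwargs" point above, here is a minimal sketch of the pattern. The `VQBeTConfig` fields and the `ActionHead` module are hypothetical, not the actual LeRobot API:

```python
# Hypothetical sketch of the "pass the config, not kwargs" pattern.
# Field names and modules are illustrative, not LeRobot's actual API.
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class VQBeTConfig:
    input_dim: int = 10
    hidden_dim: int = 32
    output_dim: int = 4


class ActionHead(nn.Module):
    def __init__(self, config: VQBeTConfig):
        super().__init__()
        # One source of truth: the submodule reads sizes from the shared config
        # instead of re-declaring (and re-documenting) each parameter.
        self.config = config
        self.net = nn.Sequential(
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


head = ActionHead(VQBeTConfig())
out = head(torch.zeros(2, 10))
print(out.shape)  # torch.Size([2, 4])
```

Each submodule's signature stays small, and the meaning of every parameter is documented once, on the config.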
Debugging vqbet
merge changes from HF repo
Pull main branch
Thank you for the review! I followed these high-level points for all the parts of this PR.
Checkpoint. I will continue next week.
```python
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
```
nit: perhaps a call to `self.reset()` at the bottom of the `__init__` would be more appropriate? See `self.reset()`.
Added `self.reset()` at the bottom of `__init__`.
Any chance we can also drop `self._queues = None` please? It doesn't hurt from a logic perspective, but it does potentially confuse someone who will wonder why there's a redundant line of code here.
```python
features = self.policy(observation_feature)
historical_act_pred_index = np.arange(0, n_obs_steps) * (self.config.gpt_num_obs_mode + 1) + self.config.gpt_num_obs_mode
# only extract the output tokens at the position of action query
```
Curious: If this is the case, what function do the other action tokens serve, other than to increase compute for a forward pass?
This increases computation, but can help improve overall learning performance and avoid overfitting (not always).
You can think of it as similar to predicting a longer sequence of actions in a diffusion policy compared to the actual sequence of actions to be performed.
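The token indexing being discussed here can be sketched as follows. The per-step layout assumed (a fixed number of observation tokens followed by one action token per observation step) mirrors the `historical_act_pred_index` computation above; all names and sizes are illustrative:

```python
# Illustrative sketch of picking action-query tokens out of the GPT output.
# Assumed layout per observation step: [obs_0, obs_1, act], names are hypothetical.
import numpy as np

n_obs_steps = 3
num_obs_modes = 2  # e.g. one image token + one state token per step

# Action tokens sit at positions 2, 5, 8 for the settings above.
act_token_index = np.arange(n_obs_steps) * (num_obs_modes + 1) + num_obs_modes

rng = np.random.default_rng(0)
features = rng.standard_normal((1, n_obs_steps * (num_obs_modes + 1), 8))
action_features = features[:, act_token_index]  # (1, n_obs_steps, 8)
print(act_token_index)  # [2 5 8]
```

Only the tokens at `act_token_index` are decoded into actions; the remaining positions still receive gradients through the transformer, which is the regularization effect described above.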
Cool! Mind adding that in as a comment (if it's not already mentioned in your paper)?
I added:
Behavior Transformer (BeT) and VQ-BeT are both sequence-to-sequence prediction models, mapping sequential observations to sequential actions (please refer to section 2.2 in the BeT paper https://arxiv.org/pdf/2206.11251). Thus, it predicts the historical action sequence in addition to current and future actions (predicting future actions is optional).
```python
spatial_softmax_num_keypoints: int = 32
# VQ-VAE
discretize_step: int = 3000
vqvae_groups: int = 2
```
In some places of the code, this is statically handled, meaning changing this number will break things. Can we please either remove it as a parameter or make sure the code can handle it dynamically?
One example is the `cbet_loss`.
Made the code handle `vqvae_groups` dynamically.
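A minimal sketch of what handling `vqvae_groups` dynamically could look like: one prediction head per RVQ layer built in a loop, instead of hard-coded "primary"/"secondary" heads. Sizes and names are assumptions for illustration:

```python
# Hypothetical sketch: per-RVQ-layer heads in an nn.ModuleList, so the number
# of layers (vqvae_groups) can change without touching the code.
import torch
from torch import nn

vqvae_groups = 3    # number of RVQ layers (illustrative)
vqvae_n_embed = 16  # codebook size per layer (illustrative)
gpt_dim = 8

# One classification head per RVQ layer, created in a loop.
bin_heads = nn.ModuleList(
    nn.Linear(gpt_dim, vqvae_n_embed) for _ in range(vqvae_groups)
)

x = torch.zeros(2, gpt_dim)
logits = torch.stack([head(x) for head in bin_heads], dim=1)
print(logits.shape)  # torch.Size([2, 3, 16])
```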
I'm still seeing the use of "primary" and "secondary" in the code, for example in `VQBeTOptimizer.__init__`. Am I misunderstanding something?
Merge from hf lerobot main branch
Just publishing my responses in a batch. Thanks for resolving these :D
Now moving on with the review.
```python
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
if not self.check_discretized():
    loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch["action"])
    return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
```
Thanks! I think I understand this. Can you let me know if my understanding is correct?
n_different_codes: how many of the total possible VQ codes are being used (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed`.
n_different_combinations: how many different code combinations are being used out of all possible combinations. This can be at most `vqvae_n_embed ^ vqvae_groups` (hint consider the RVQ as a decision tree).
But shouldn't `n_different_codes` max out at `vqvae_n_embed * vqvae_groups`? That's how many codes there are in total. Or are you only referring to the codes of the first RVQ layer?
Btw: I think this is a great metric to track!
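A sketch of how these two metrics could be computed from a batch of RVQ code indices. The `(batch, groups)` tensor layout is an assumption for illustration, not necessarily what the PR does:

```python
# Illustrative computation of the two codebook-usage metrics from code indices.
# Assumed layout: codes[i, g] is the index chosen by RVQ layer g for sample i.
import torch

codes = torch.tensor([[0, 1], [0, 1], [2, 1], [0, 3]])  # (batch=4, groups=2)

# Distinct code indices in use; bounded by vqvae_n_embed if layers share an
# index space, or vqvae_n_embed * vqvae_groups if counted per layer.
n_different_codes = codes.unique().numel()

# Distinct (layer-0, layer-1, ...) tuples; bounded by vqvae_n_embed ** vqvae_groups.
n_different_combinations = codes.unique(dim=0).size(0)
print(n_different_codes, n_different_combinations)  # 4 3
```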
```python
self.map_to_cbet_preds_bin: outputs probability of each code (for each layer).
    The input dimension of `self.map_to_cbet_preds_bin` is same with the output of GPT,
    and the output dimension of `self.map_to_cbet_preds_bin` is `self.config.vqvae_groups * self.config.vqvae_n_embed`, where
    `self.config.vqvae_groups` is number of RVQ layers, and
```
nit: Same as an earlier revision above, can we please remove these duplicated explanations of what these variables mean?
I removed the duplicated parts:) Thank you
```python
    }
    return loss_dict


class VQBeTOptimizer(nn.Module):
```
Can I please ask that we consolidate and simplify here? Consider where we can get away with using one optimizer instead of many. I count 4 optimizers being initialized here and I'm not sure all of them are needed. I'll let you double check, but I think we might be able to get away with 2 or even just 1 (if you no_grad the quantizer when the discretization is done).
Feel free to let me know if this is not possible. I checked briefly, but not exhaustively.
At a higher level, we have a plan to have some way of the policy code providing the optimizer and scheduler, so I think you have made a good step towards that here. Right now we have `train.py` handling this logic and that's not nice. Ideally, what I think we want here is one method in the top-level policy class, `make_optimizer`, which handles everything. That way `train.py` can just call `make_optimizer` without having to know which specific policy it is. Here, this would mean taking Karpathy's configure-optimizers logic and consolidating it into that same `make_optimizer` method. I don't think we want the optimizer creation distributed throughout various modules of the file.
Happy to get your input on all these thoughts.
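The proposed interface could be sketched as below. The toy module and single Adam group are assumptions for illustration, not the actual LeRobot API:

```python
# Hypothetical sketch: the policy owns optimizer construction, so the training
# script stays policy-agnostic. Names and hyperparameters are illustrative.
import torch
from torch import nn


class VQBeTPolicy(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)

    def make_optimizer(self) -> torch.optim.Optimizer:
        # All param-group logic (weight-decay split, per-module lr, ...)
        # would live here instead of in train.py or scattered submodules.
        return torch.optim.Adam(self.parameters(), lr=1e-4)


# train.py side: no policy-specific knowledge needed.
policy = VQBeTPolicy()
optimizer = policy.make_optimizer()
print(type(optimizer).__name__)  # Adam
```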
Thank you for the suggestions!
I removed all the redundant optimizers, and merged the phase 1 and phase 2 optimizers into one, leaving only one optimizer. I also deleted `def step` and `def zero_grad` (and put all the parameters for phase 2 in the same scheduler).
We haven't done much analysis on how this affects the stability of training at this time, but (after running two seeds) we have found that it can produce similar performance to the uploaded model(https://huggingface.co/JayLee131/vqbet_pusht) based on the best checkpoint.
Perhaps a more diverse hyperparameter search may be needed in the future.
```python
class VQBeTOptimizer(torch.optim.Adam):
    def __init__(self, policy, cfg):
        vqvae_params = (
            list(policy.vqbet.action_head.vqvae_model.encoder.parameters())
            + list(policy.vqbet.action_head.vqvae_model.decoder.parameters())
            + list(policy.vqbet.action_head.vqvae_model.vq_layer.parameters())
        )
        decay_params, no_decay_params = policy.vqbet.policy.configure_parameters()
        decay_params = (
            decay_params
            + list(policy.vqbet.rgb_encoder.parameters())
            + list(policy.vqbet.state_projector.parameters())
            + list(policy.vqbet.rgb_feature_projector.parameters())
            + [policy.vqbet._action_token]
            + list(policy.vqbet.action_head.map_to_cbet_preds_offset.parameters())
        )

        if cfg.policy.sequentially_select:
            decay_params = (
                decay_params
                + list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
                + list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
            )
        else:
            decay_params = (
                decay_params
                + list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters())
            )

        optim_groups = [
            {
                "params": decay_params,
                "weight_decay": cfg.training.adam_weight_decay,
                "lr": cfg.training.lr,
            },
            {
                "params": vqvae_params,
                "weight_decay": 0.0001,
                "lr": cfg.training.vqvae_lr,
            },
            {
                "params": no_decay_params,
                "weight_decay": 0.0,
                "lr": cfg.training.lr,
            },
        ]
        super(VQBeTOptimizer, self).__init__(
            optim_groups,
            cfg.training.lr,
            cfg.training.adam_betas,
            cfg.training.adam_eps,
        )
```
```python
else:
    self.eval()


def draw_logits_forward(self, encoding_logits):
```
Can we please add a docstring here or change the function name to something more apparent? I'm not sure what it means to draw logits forward.
Note: I think most of the function names are self-explanatory, so I really do just mean this one and `draw_code_forward`.
Thank you for your opinion:) I removed `def draw_logits_forward` since it is not used, and changed `def draw_code_forward` to `def get_embeddings_from_code`.
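For a first-time reader, a `get_embeddings_from_code`-style lookup in a residual VQ can be sketched as below. The shapes, the per-layer codebook layout, and the summation over layers are assumptions for illustration, not the PR's exact implementation:

```python
# Illustrative sketch: map discrete RVQ code indices back to codebook vectors
# and sum across residual layers to reconstruct the quantized embedding.
import torch

vqvae_groups, vqvae_n_embed, embed_dim = 2, 16, 8
torch.manual_seed(0)
codebooks = torch.randn(vqvae_groups, vqvae_n_embed, embed_dim)

def get_embeddings_from_code(codes: torch.Tensor) -> torch.Tensor:
    # codes: (batch, groups) integer indices, one per residual layer.
    per_layer = torch.stack(
        [codebooks[g, codes[:, g]] for g in range(vqvae_groups)], dim=1
    )
    # Residual VQ reconstructs by summing the layer embeddings.
    return per_layer.sum(dim=1)  # (batch, embed_dim)

emb = get_embeddings_from_code(torch.tensor([[0, 5], [3, 1]]))
print(emb.shape)  # torch.Size([2, 8])
```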
… of all parameters in phase 2 together
…k in class ResidualVQ as resigered buffer
…raw_code_forward to def get_embeddings_from_code
…ebook_vector_from_indices
Thank you for the review! I've resolved all the comments. In a high-level view,
Merge from HF main branch
What this does
Add VQ-BeT for PushT env.
How it was tested
Explain/show how you tested your changes.
Examples:
`configuration_vqbet.py` and `modeling_vqbet.py` in the `vqbet` folder.

How to checkout & try? (for the reviewer)
Examples: