Implemented CrossQ #243

Open
wants to merge 14 commits into base: master


danielpalen commented May 4, 2024

This PR implements CrossQ (https://openreview.net/pdf?id=PczQtTsTIX), a novel off-policy deep RL algorithm that carefully uses batch normalisation and removes target networks to achieve state-of-the-art sample efficiency at a much lower computational cost, as it does not require large update-to-data ratios.
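A minimal sketch of the two ideas named above, batch normalisation in the critic and no target network, assuming a plain MLP critic built from torch.nn.BatchNorm1d layers. BNCritic and crossq_critic_loss are hypothetical names, and the SAC parts (entropy term, twin critics) are left out for brevity: current and next state-action pairs go through the critic as one joint batch in training mode, the output is split, and the detached second half provides the bootstrap target.

```python
import torch as th
import torch.nn as nn
import torch.nn.functional as F


class BNCritic(nn.Module):
    """Hypothetical Q-network with a BatchNorm layer before each linear layer."""

    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.BatchNorm1d(obs_dim + act_dim),
            nn.Linear(obs_dim + act_dim, hidden),
            nn.ReLU(),
            nn.BatchNorm1d(hidden),
            nn.Linear(hidden, 1),
        )

    def forward(self, obs: th.Tensor, act: th.Tensor) -> th.Tensor:
        return self.net(th.cat([obs, act], dim=1))


def crossq_critic_loss(critic, obs, act, reward, next_obs, next_act, done, gamma=0.99):
    batch_size = obs.shape[0]
    # One joint batch: BatchNorm statistics cover both (s, a) and (s', a')
    joint_obs = th.cat([obs, next_obs], dim=0)
    joint_act = th.cat([act, next_act], dim=0)
    critic.train()
    q, next_q = th.split(critic(joint_obs, joint_act), batch_size, dim=0)
    # No target network: bootstrap from the same forward pass, without gradients
    target = reward + gamma * (1.0 - done) * next_q.detach()
    return F.mse_loss(q, target)


th.manual_seed(0)
critic = BNCritic(obs_dim=3, act_dim=1)
obs, next_obs = th.randn(8, 3), th.randn(8, 3)
act, next_act = th.randn(8, 1), th.randn(8, 1)  # next_act would come from the policy
reward, done = th.randn(8, 1), th.zeros(8, 1)
loss = crossq_critic_loss(critic, obs, act, reward, next_obs, next_act, done)
loss.backward()
```

Because the critic sees both halves in a single forward pass, each BatchNorm layer updates its running statistics once per training step.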

Description

This is a PyTorch implementation based on the original JAX implementation (https://github.com/adityab/CrossQ).
The following plot shows that its performance matches both the results reported in the original paper and the open-source SBX implementation provided by the authors (evaluated on 10 seeds).

[Plot: sbx_reproduce]

Context

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist:

  • I've read the CONTRIBUTION guide (required)
  • The functionality/performance matches that of the source (required for new training algorithms or training-related features).
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have included an example of using the feature (required for new features).
  • I have included baseline results (required for new training algorithms or training-related features).
  • I have updated the documentation accordingly.
  • I have updated the changelog accordingly (required).
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)

Note: we are using a maximum length of 127 characters per line

araffin self-requested a review May 4, 2024 19:43
danielpalen (Author)

@araffin in my initial PR it seems one code style check was failing, sorry about that. I fixed it and it passes on my machine now. I hope it goes through now :)

.. autosummary::
    :nosignatures:

    MlpPolicy
Member

Could you add at least the multi-input policy (so we can try it in combination with HER)?
Normally, only the feature extractor should need to change.

And what do you think about adding CnnPolicy?

Author

This is a good point. I looked into it but have not added it yet. If I am not mistaken, this would also require some changes to the CrossQ train() function, since concatenating and splitting the batches would then need some control flow depending on the policy used.
For simplicity's sake (for now), and since I did not have time to try and evaluate the multi-input policy, I have not added it yet.
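If observations are dictionaries, as with a multi-input policy, the extra control flow mentioned in this reply can likely be confined to the concatenation step, because the critic's Q-value output is a plain tensor either way. A minimal sketch with a hypothetical concat_obs helper; the dict keys are only example names, and whether the feature extractor needs further changes is not checked here.

```python
from typing import Dict, Union

import torch as th

Obs = Union[th.Tensor, Dict[str, th.Tensor]]


def concat_obs(obs: Obs, next_obs: Obs) -> Obs:
    """Hypothetical helper: stack current and next observations along the batch axis."""
    if isinstance(obs, dict):
        # Dict observations: concatenate each key separately
        return {key: th.cat([obs[key], next_obs[key]], dim=0) for key in obs}
    return th.cat([obs, next_obs], dim=0)


batch = {"observation": th.zeros(4, 6), "desired_goal": th.ones(4, 3)}
next_batch = {"observation": th.ones(4, 6), "desired_goal": th.zeros(4, 3)}
joint = concat_obs(batch, next_batch)

# Critic outputs are plain tensors, so splitting does not depend on the policy type
q_values = th.randn(8, 1)
q, next_q = th.split(q_values, 4, dim=0)
```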

latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)

if batch_norm:
    # If batchnorm, then we want to add torch.nn.BatchNorm1d layers before every linear layer
Member

What do you think about updating create_mlp to allow passing normalization layers/dropout?

Similar to what is done in DLR-RM/stable-baselines3#1036 and proposed in DLR-RM/stable-baselines3#1069

Author

I think this would make sense, because the way I implemented it right now is not really that nice.
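One way the suggested create_mlp extension could look, as a sketch: create_mlp_with_norm and its norm_layer/dropout parameters are hypothetical and are not the API from the linked SB3 PRs. It follows the convention of the create_mlp call above (an output_dim of -1 means no output layer) and inserts the normalisation layer before every linear layer, as the PR comment describes, with optional dropout after each activation.

```python
from typing import List, Optional, Type

import torch as th
import torch.nn as nn


def create_mlp_with_norm(
    input_dim: int,
    output_dim: int,
    net_arch: List[int],
    activation_fn: Type[nn.Module] = nn.ReLU,
    norm_layer: Optional[Type[nn.Module]] = None,
    dropout: float = 0.0,
) -> List[nn.Module]:
    """Hypothetical create_mlp variant: optional norm before each Linear, dropout after each activation."""
    modules: List[nn.Module] = []
    last_dim = input_dim
    for layer_dim in net_arch:
        if norm_layer is not None:
            modules.append(norm_layer(last_dim))
        modules.append(nn.Linear(last_dim, layer_dim))
        modules.append(activation_fn())
        if dropout > 0.0:
            modules.append(nn.Dropout(dropout))
        last_dim = layer_dim
    # A non-positive output_dim means "no output layer", as in the call above
    if output_dim > 0:
        if norm_layer is not None:
            modules.append(norm_layer(last_dim))
        modules.append(nn.Linear(last_dim, output_dim))
    return modules


net = nn.Sequential(*create_mlp_with_norm(10, 1, [64, 64], norm_layer=nn.BatchNorm1d, dropout=0.1))
out = net(th.randn(4, 10))
```

Passing the layer class rather than a flag keeps the helper usable for other normalisation types, at the cost of assuming the class takes the feature dimension as its only required argument.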


with th.no_grad():
    # Select action according to policy
    self.actor.set_training_mode(False)
Member

Is that needed? self.actor.set_training_mode(False) is already set above.
Or did you mean self.actor.set_training_mode(True)?

Author

I added more mode calls than needed. The reason was that I wanted to be very explicit about which mode is used where. I think using the wrong BN mode is one of the big gotchas and sources of error when implementing CrossQ. Since this should serve as a PyTorch reference to help others who want to implement it themselves, I think making the mode explicit helps clear up possible confusion.
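The mode gotcha described above can be seen on a single layer. A minimal sketch using a standalone torch.nn.BatchNorm1d (not the PR's critic): th.no_grad() only disables gradient tracking, so a layer in training mode still normalises with the current batch statistics and updates its running statistics, while eval mode uses the stored running statistics and leaves them unchanged.

```python
import torch as th
import torch.nn as nn

th.manual_seed(0)
bn = nn.BatchNorm1d(2)
x = th.randn(16, 2) * 3.0 + 5.0  # batch far from zero mean / unit variance

# Training mode: normalise with this batch's statistics and update running stats,
# even under no_grad (no_grad disables autograd, not BatchNorm's mode)
bn.train()
with th.no_grad():
    y_train = bn(x)
running_mean_after_train = bn.running_mean.clone()

# Eval mode: normalise with the stored running statistics, which stay unchanged
bn.eval()
with th.no_grad():
    y_eval = bn(x)
```

The same input gives clearly different outputs in the two modes, which is why an accidental mode mismatch silently changes both the Q-values and the running statistics.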

self.critic.optimizer.step()

# Compute actor loss
self.critic.set_training_mode(False)
Member

Not needed now, but maybe for later: should we deactivate only the batch norm layers? (For instance, if dropout is used, we want it to stay active there.)

Author

Same reasoning as above. But you are right: if we also want to use dropout, we should adapt this. Maybe we could add a set_bn_training_mode function, although that would be very specific to our use case.
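A possible shape for the set_bn_training_mode idea, as a minimal sketch: the function name comes from the comment above, but the implementation is hypothetical and not an SB3 API. It toggles only the standard PyTorch batch-norm classes, so dropout keeps whatever mode the parent module is in; any custom normalisation layer would have to be added to the tuple.

```python
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def set_bn_training_mode(module: nn.Module, mode: bool) -> None:
    """Hypothetical helper: switch only batch-norm layers to train (True) or eval (False)."""
    for submodule in module.modules():
        if isinstance(submodule, BN_TYPES):
            submodule.train(mode)


critic = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Dropout(p=0.5))
critic.train()                       # everything in training mode
set_bn_training_mode(critic, False)  # BatchNorm now uses running stats, Dropout stays active
```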

araffin (Member) commented May 6, 2024

Thanks a lot for the implementation =)

I'll try it later in the week, but how is it in terms of runtime? (SAC vs CrossQ in PyTorch)

danielpalen (Author) commented May 12, 2024

No worries :)

I just pushed most of the changes you requested. I'll add some more specific responses directly to the questions above.

how is it in terms of runtime? (SAC vs CrossQ in PyTorch)

It seems to be quite a bit slower than the SAC baseline (and than the JAX implementation as well).
For 4M steps on HumanoidStandup, SAC took around 12 hours, whereas CrossQ took 22 hours. I'm not sure whether some PyTorch implementation details could help with speed.
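For scale, the wall-clock figures above translate into rough throughput as follows, assuming both runs cover the full 4M environment steps on comparable hardware (an assumption; the hardware is not stated): about 93 versus 51 environment steps per second, so CrossQ is roughly 1.83x slower here.

```python
def steps_per_second(total_steps: int, hours: float) -> float:
    """Average environment steps per second over a run of the given wall-clock length."""
    return total_steps / (hours * 3600)


sac = steps_per_second(4_000_000, 12)     # ~92.6 steps/s
crossq = steps_per_second(4_000_000, 22)  # ~50.5 steps/s
slowdown = sac / crossq                   # equals 22 / 12, ~1.83
```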

araffin (Member) commented May 17, 2024

I suspect something is wrong with the current implementation (I'm currently investigating whether it is caused by my changes).
My setting:

BipedalWalker-v3:
  n_timesteps: !!float 2e5
  policy: 'MlpPolicy'
  buffer_size: 300000
  gamma: 0.98
  learning_starts: 10000
  policy_kwargs: "dict(net_arch=dict(pi=[256, 256], qf=[1024, 1024]))"

Using the RL Zoo CLI for both SBX and SB3 (see the SBX README for how to enable support):

python train.py --algo crossq --env BipedalWalker-v3 -P --verbose 0 -param n_envs:30 gradient_steps:30 -n 200000

I'm getting much better results with SBX...
I hope it is not the Adam parameters.
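For reference, the overrides in the command above imply an update-to-data ratio of about 1, the low-UTD regime the PR description says CrossQ works in. This is a back-of-the-envelope sketch assuming train_freq stays at one vectorised environment step, so each collection round adds n_envs transitions before gradient_steps updates run; utd_ratio and approx_gradient_steps are hypothetical helpers.

```python
def utd_ratio(n_envs: int, gradient_steps: int, train_freq: int = 1) -> float:
    """Gradient updates per collected transition (hypothetical helper)."""
    return gradient_steps / (train_freq * n_envs)


def approx_gradient_steps(n_timesteps: int, learning_starts: int, n_envs: int,
                          gradient_steps: int, train_freq: int = 1) -> int:
    """Rough update count: gradient_steps per collection round after learning_starts."""
    rounds = (n_timesteps - learning_starts) // (train_freq * n_envs)
    return rounds * gradient_steps


print(utd_ratio(n_envs=30, gradient_steps=30))                                # 1.0
print(approx_gradient_steps(200_000, 10_000, n_envs=30, gradient_steps=30))  # 189990
```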

danielpalen (Author)

Did you figure out what the issue is? I was at ICRA until last week, so I didn't have time, but if you haven't found it yet I can also take a look.

Before pushing my last commit, I benchmarked it, and the results looked as expected.

araffin added the "Maintainers on vacation" label May 28, 2024
Labels: Maintainers on vacation (maintainers are on vacation so they can recharge their batteries, we will be back soon ;))
Development: successfully merging this pull request may close the issue "[Feature Request] Implement CrossQ"
2 participants