
SAC jax #300

Open: wants to merge 32 commits into master
Conversation

@araffin commented Oct 23, 2022

Description

Missing: benchmark and documentation

Adapted from https://github.com/araffin/sbx
Report (3 seeds on 3 MuJoCo envs): https://wandb.ai/openrlbenchmark/cleanrl/reports/SAC-jax---VmlldzoyODM4MjU0

Types of changes

  • Bug fix
  • New feature
  • New algorithm
  • Documentation

Checklist:

  • I've read the CONTRIBUTION guide (required).
  • I have ensured pre-commit run --all-files passes (required).
  • I have updated the documentation and previewed the changes via mkdocs serve.
  • I have updated the tests accordingly (if applicable).

If you are adding new algorithms or your change could result in a performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.

  • I have contacted vwxyzjn to obtain access to the openrlbenchmark W&B team (required).
  • I have tracked applicable experiments in openrlbenchmark/cleanrl with --capture-video flag toggled on (required).
  • I have added additional documentation and previewed the changes via mkdocs serve.
    • I have explained note-worthy implementation details.
    • I have explained the logged metrics.
    • I have added links to the original paper and related papers (if applicable).
    • I have added links to the PR related to the algorithm.
    • I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
    • I have added the learning curves (in PNG format with width=500 and height=300).
    • I have added links to the tracked experiments.
    • I have updated the overview sections in the docs and the repo.
  • I have updated the tests accordingly (if applicable).

vercel bot commented Oct 23, 2022

The latest updates on your projects:

cleanrl: ✅ Ready (updated Jun 15, 2023, 6:05pm UTC)

@araffin (Author) commented Oct 23, 2022

@vwxyzjn tests fail with ModuleNotFoundError: No module named 'pygame'; not sure why it worked before...

@vwxyzjn (Owner) commented Oct 24, 2022

ModuleNotFoundError: No module named 'pygame' looks really weird... so I investigated a bit further. Instead of running poetry lock, I ran poetry add tensorflow-probability and poetry update flax, and that seems to make things work.

It turns out the culprit is the following change:

-classic_control = ["pygame (==2.1.0)"]
+classic-control = ["pygame (==2.1.0)"]

Under the hood, poetry installs pygame via the equivalent of pip install gym[classic_control], but for some reason the key of the extra was changed 😓

@araffin (Author) commented Oct 24, 2022

@vwxyzjn I think I'm done with the implementation; I added support for a constant entropy coefficient and for deterministic evaluation.
I would be happy to receive help with the documentation ;)
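For reference, a minimal sketch of what those two features often look like in JAX SAC implementations (illustrative, with made-up names like autotune and deterministic_action, not necessarily this PR's exact code):

import jax.numpy as jnp

action_dim = 6                         # e.g. HalfCheetah-v4
target_entropy = -float(action_dim)    # common heuristic: -|A|
autotune = False                       # False => constant entropy coefficient

if autotune:
    log_alpha = jnp.zeros(())  # learnable log-temperature

    def alpha_loss(log_alpha, log_prob):
        # Standard SAC temperature loss: drives the policy entropy
        # toward target_entropy (log_prob comes from the current policy).
        return (-jnp.exp(log_alpha) * (log_prob + target_entropy)).mean()
else:
    alpha = 0.2  # fixed entropy coefficient, no temperature update

# Deterministic evaluation: act with the mode of the squashed Gaussian,
# i.e. tanh(action_mean), instead of sampling.
def deterministic_action(action_mean):
    return jnp.tanh(action_mean)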

@vwxyzjn (Owner) commented Nov 21, 2022

Perhaps because in #217 I implemented my own normal distribution, I am trying to do the same for SAC...

However, if I replace

def actor_loss(params):
    dist = TanhTransformedDistribution(
        tfd.MultivariateNormalDiag(loc=action_mean, scale_diag=jnp.exp(action_logstd)),
    )
    actor_actions = dist.sample(seed=subkey)
    log_prob = dist.log_prob(actor_actions).reshape(-1, 1)

with the log probability taken from https://github.com/openai/baselines/blob/9b68103b737ac46bc201dfb3121cfa5df2127e53/baselines/common/distributions.py#L238-L241

def actor_loss(params):
    action_mean, action_logstd = actor.apply(params, observations[0:1])
    action_std = jnp.exp(action_logstd)
    actor_actions = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape)
    log_prob = -0.5 * ((actor_actions - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd
    log_prob = log_prob.sum(axis=1, keepdims=True)
    actor_actions = jnp.tanh(actor_actions)

things kind of fall apart catastrophically... I felt that implementing our own might bring greater transparency, but maybe it isn't necessary...

@vwxyzjn (Owner) commented Nov 21, 2022

Aha! I got it, it's supposed to be the following:
[image: the squashed-Gaussian log-likelihood correction from the SAC paper]

action_mean, action_logstd = actor.apply(params, observations)
action_std = jnp.exp(action_logstd)
actor_actions = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape)
# per-dimension Gaussian log-prob of the pre-tanh actions
log_prob = -0.5 * ((actor_actions - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd
actor_actions = jnp.tanh(actor_actions)
# tanh change-of-variables correction, applied per dimension...
log_prob -= jnp.log((1 - jnp.power(actor_actions, 2)) + 1e-6)
# ...then summed together with the Gaussian term
log_prob = log_prob.sum(axis=1, keepdims=True)

Interestingly, the paper seems to say our implementation should have been the following (with the summation)

log_prob -= jnp.log((1 - jnp.power(actor_actions, 2)) + 1e-6).sum(axis=-1).reshape(-1, 1)

but empirically, it doesn't perform as well... @dosssman any thoughts?
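For reference, the change-of-variables formula from the SAC paper (Haarnoja et al. 2018, Appendix C) is

$$\log \pi(a \mid s) = \log \mu(u \mid s) - \sum_{i=1}^{D} \log\left(1 - \tanh^2(u_i)\right), \qquad a = \tanh(u),$$

i.e., the per-dimension correction is summed over the D action dimensions. The snippet above is actually equivalent to that: it subtracts the correction per dimension, and the final log_prob.sum(axis=1, keepdims=True) sums both terms at once. One hypothetical explanation for the summed variant performing worse: if the pre-summed (batch, 1) correction is subtracted while log_prob still has shape (batch, D), broadcasting subtracts the correction D times before the final sum.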

@araffin (Author) commented Nov 21, 2022

Interestingly, the paper seems to say our implementation should have been the following (with the summation)

Not sure I follow the difference...

You can take a look at how we do it in SB3; I think it matches what the paper describes:
https://github.com/DLR-RM/stable-baselines3/blob/c4f54fcf047d7bf425fb6b88a3c8ed23fe375f9b/stable_baselines3/common/distributions.py#L222-L226
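For comparison, a sketch of that SB3 snippet ported to JAX (squashed_log_prob is an illustrative name; epsilon=1e-6 as in SB3):

import jax.numpy as jnp

def squashed_log_prob(gaussian_log_prob, squashed_actions, eps=1e-6):
    # gaussian_log_prob: log-prob of the pre-tanh Gaussian, already summed
    # over action dimensions, shape (batch,)
    # squashed_actions: tanh(u), shape (batch, action_dim)
    # Reducing both terms to one scalar per sample avoids any broadcasting
    # ambiguity between the Gaussian term and the correction.
    return gaussian_log_prob - jnp.log(1.0 - squashed_actions**2 + eps).sum(axis=1)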

@vwxyzjn (Owner) commented Nov 22, 2022

I tried implementing the probability distribution ourselves in 0cf0e9e, but hit a performance regression.

[image: learning curves showing the performance regression]

Looking deeper into the issue, I couldn't quite understand how TanhTransformedDistribution works. Could someone take a look at https://gist.github.com/vwxyzjn/331f896b79d3f829fdfa575be666d2d8, which generates:

manually sample actions, manually calculate log prob
  action=2.561650514602661, logprob=55.152984619140625
manually sample actions, calculate log prob from TanhTransformedDistribution
  action=2.561650514602661, logprob=nan
sample actions from `TanhTransformedDistribution`, calculate log prob from TanhTransformedDistribution
  action=2.7475833892822266, logprob=66.45195770263672
sample actions from `TanhTransformedDistribution`, manually calculate log prob
  action=2.7475833892822266, logprob=-inf

I am quite puzzled. TanhTransformedDistribution seems like quite a black box to me. Because tensorflow_probability is written in tensorflow, there is no meaningful code trace in the IDE to understand what's happening inside... And tfp's docs seem to have some issues (e.g., the "view source code on GitHub" button in https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/MultivariateNormalDiag is broken). Maybe we shouldn't use anything from tfp?
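For anyone reproducing this, a condensed sketch of the comparison using tfp.substrates.jax directly (arbitrary mean/std values; the PR's TanhTransformedDistribution wraps tfd.TransformedDistribution with a Tanh bijector, so this should exercise the same code path):

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

key = jax.random.PRNGKey(0)
mean = jnp.array([[2.0]])
std = jnp.array([[1.5]])

dist = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(loc=mean, scale_diag=std),
    bijector=tfb.Tanh(),
)
action = dist.sample(seed=key)        # squashed, lies in (-1, 1)
print(action, dist.log_prob(action))  # tfp applies its stable correction internally

# Manual path: sample pre-tanh, then correct with log(1 - tanh(u)^2).
u = mean + std * jax.random.normal(key, mean.shape)
gaussian = -0.5 * ((u - mean) / std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - jnp.log(std)
manual = gaussian.sum(axis=1) - jnp.log(1.0 - jnp.tanh(u) ** 2).sum(axis=1)
print(jnp.tanh(u), manual)  # -inf whenever tanh(u) rounds to exactly ±1 in float32

In float32, a sample near the boundary rounds to exactly ±1, so the naive 1 - tanh(u)^2 term hits zero (log → -inf), while atanh(±1) in tfp's inverse path gives ±inf (log-prob → nan), which is consistent with the mixed nan/-inf outputs above.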

@araffin (Author) commented Nov 24, 2022

@vwxyzjn run the code with JAX_ENABLE_X64=True and it will solve your issue ;) (results are still slightly different, but that's probably expected; try different random seeds)
JAX_DISABLE_JIT=1 already partially solves the issue.

I guess the answer to your question is called numerical precision ;).
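For reference, the same flag can be flipped from code, as long as it runs before any arrays are created:

import jax

# Equivalent to setting JAX_ENABLE_X64=True in the environment.
jax.config.update("jax_enable_x64", True)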

EDIT: the code for the tf Tanh bijector is here: https://github.com/tensorflow/probability/blob/bcdf53024ef9f35d81be063093ccfb3a762dab3f/tensorflow_probability/python/bijectors/tanh.py#L70-L81

  # We implicitly rely on _forward_log_det_jacobian rather than explicitly
  # implement _inverse_log_det_jacobian since directly using
  # `-tf.math.log1p(-tf.square(y))` has lower numerical precision.

  def _forward_log_det_jacobian(self, x):
    #  This formula is mathematically equivalent to
    #  `tf.log1p(-tf.square(tf.tanh(x)))`, however this code is more numerically
    #  stable.
    #  Derivation:
    #    log(1 - tanh(x)^2)
    #    = log(sech(x)^2)
    #    = 2 * log(sech(x))
    #    = 2 * log(2e^-x / (e^-2x + 1))
    #    = 2 * (log(2) - x - log(e^-2x + 1))
    #    = 2 * (log(2) - x - softplus(-2x))
    return 2. * (np.log(2.) - x - tf.math.softplus(-2. * x))
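The same trick carries over to JAX; a quick sketch (illustrative, not this PR's code) comparing the two forms:

import jax
import jax.numpy as jnp

def log_det_naive(x):
    # log(1 - tanh(x)^2): tanh(x) rounds to exactly 1.0 in float32 for
    # moderately large x, so this returns -inf there.
    return jnp.log1p(-jnp.tanh(x) ** 2)

def log_det_stable(x):
    # 2 * (log(2) - x - softplus(-2x)): mathematically identical, but
    # never forms the catastrophic 1 - tanh(x)^2 difference.
    return 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))

x = jnp.array([0.0, 5.0, 20.0])
print(log_det_naive(x))   # [ 0.  ~-8.614  -inf    ]
print(log_det_stable(x))  # [ 0.  ~-8.614  ~-38.614]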

@araffin (Author) commented Nov 28, 2022

run the code with JAX_ENABLE_X64=True and it will solve your issue ;) (results are still slightly different, but that's probably expected, try with different random seeds)

@vwxyzjn as a follow-up, if you remove the + 1e-6 in your code, you get the same results. Btw, why did you use 1e-6 and not a smaller value?

EDIT: I don't know why pre-commit fails; it works locally.

@Howuhh (Contributor) commented Nov 28, 2022

@araffin 1e-6 is used in most popular SAC PyTorch implementations; I also use it in my research for some reason (and in CORL). I think it's more a matter of reproducibility.
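To make that concrete: in float32, tanh saturates to exactly 1.0 once the pre-tanh value passes roughly 9, so without the epsilon the correction term is log(0) = -inf; with it, the correction is clipped at log(1e-6) ≈ -13.8, which also bounds its gradient. A tiny sketch:

import jax.numpy as jnp

a = jnp.tanh(jnp.float32(10.0))    # rounds to exactly 1.0 in float32
print(1.0 - a**2)                  # 0.0
print(jnp.log(1.0 - a**2))         # -inf: the failure mode
print(jnp.log(1.0 - a**2 + 1e-6))  # ~= log(1e-6) = -13.82, finite

A smaller epsilon would still keep the value finite, just with a more extreme clip; 1e-6 seems to be mostly a convention inherited across implementations, as noted above.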

@ffelten commented Apr 24, 2023

Hi, is there any update on this, or anything blocking it?

@araffin (Author) commented Jun 15, 2023

@vwxyzjn I would need your help again to update the lockfile; I tried to do it locally and poetry destroyed my conda env...
