Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PPO + JAX + EnvPool + MuJoCo #217

Open
wants to merge 29 commits into
base: master
Choose a base branch
from
Open

PPO + JAX + EnvPool + MuJoCo #217

wants to merge 29 commits into from

Conversation

vwxyzjn
Copy link
Owner

@vwxyzjn vwxyzjn commented Jun 27, 2022

Description

Types of changes

  • Bug fix
  • New feature

Checklist:

  • I've read the CONTRIBUTION guide (required).
  • I have ensured pre-commit run --all-files passes (required).
  • I have updated the documentation and previewed the changes via mkdocs serve.
  • I have updated the tests accordingly (if applicable).

If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.

  • I have contacted @vwxyzjn to obtain access to the openrlbenchmark W&B team (required).
  • I have tracked applicable experiments in openrlbenchmark/cleanrl with --capture-video flag toggled on (required).
  • I have added additional documentation and previewed the changes via mkdocs serve.
    • I have explained note-worthy implementation details.
    • I have explained the logged metrics.
    • I have added links to the original paper and related papers (if applicable).
    • I have added links to the PR related to the algorithm.
    • I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
    • I have added the learning curves (in PNG format with width=500 and height=300).
    • I have added links to the tracked experiments.
  • I have updated the tests accordingly (if applicable).

@vercel
Copy link

vercel bot commented Jun 27, 2022

The latest updates on your projects. Learn more about Vercel for Git ↗︎

Name Status Preview Updated
cleanrl ✅ Ready (Inspect) Visit Preview Jan 13, 2023 at 0:46AM (UTC)

@gitpod-io
Copy link

gitpod-io bot commented Jun 27, 2022

@vwxyzjn
Copy link
Owner Author

vwxyzjn commented Jun 27, 2022

It seems that there isn't that much benefit in PPO - the SPS metric is not a lot better, as shown below.

image

Note: there is probably a bug... that's why the sample efficiency suffers.

Maybe I was implementing PPO using the incorrect paradigm with JAX. Any thoughts on this @joaogui1 and @ikostrikov? Thanks!

@ikostrikov
Copy link

I'm not sure if

obs = obs.at[step].set(x)

is indeed in-place inside of jit. I think in this specific case it still creates a new array. I think it's truly in-place only for specific use cases. For example, when memory is donated (on TPU and GPU only). Could you double check that?

Comment on lines 329 to 333
if args.anneal_lr:
frac = 1.0 - (update - 1.0) / num_updates
lrnow = frac * args.learning_rate
agent_optimizer_state[1].hyperparams["learning_rate"] = lrnow
agent_optimizer.update(agent_params, agent_optimizer_state)
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

It turns out these 4 lines of code slow down the throughput by a half. We are going to need a better learning rate annealing paradigm probably using the official API.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my experience, there's a gain if the main for loop can be replaced with lax.fori_loop

@vwxyzjn
Copy link
Owner Author

vwxyzjn commented Jun 27, 2022

The latest commit fixes two stupid bug, we now can match the exact same performance :)

image

@vwxyzjn
Copy link
Owner Author

vwxyzjn commented Jun 27, 2022

I'm not sure if

obs = obs.at[step].set(x)

is indeed in-place inside of jit. I think in this specific case it still creates a new array. I think it's truly in-place only for specific use cases. For example, when memory is donated (on TPU and GPU only). Could you double check that?

Maybe the documentation meant if you had created an array inside the JIT the operation would be in place? I tested out

print("id(obs) before", id(obs))
obs, dones, actions, logprobs, values, action, logprob, entropy, value, key = get_action_and_value(
    next_obs, next_done, obs, dones, actions, logprobs, values, step, agent_params, key
)
print("id(obs) after", id(obs))

which gives

id(obs) before 140230683526704
id(obs) after 140230683590064

@ikostrikov
Copy link

@vwxyzjn yes, I think it's either for arrays created inside of jit or donated arguments.

@vwxyzjn vwxyzjn mentioned this pull request Jun 27, 2022
5 tasks
advantages = advantages.at[:].set(0.0) # reset advantages
next_value = critic.apply(agent_params.critic_params, next_obs).squeeze()
lastgaelam = 0
for t in reversed(range(args.num_steps)):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was looking through your codes to get some idea about how other people were writing RL algos in jax (and how far people jited things) and think this might be an issue during the first compile step. The for loop will basically be unrolled and when I tried this the compile time was very long especially if args.num_steps is big.

Ended up using jax.lax.scan and replaced the loop like this (code doesn't fit yours exactly but idea is there):

    not_dones = ~dones
    
    value_diffs = gamma * values[1:] * not_dones - values[:-1]
    deltas = rewards + value_diffs

    def body_fun(gae, t):
        gae = deltas[t] + gamma * gae_lambda * not_dones[t] * gae
        return gae, gae
    indices = jnp.arange(N)[::-1]
    gae, advantages = jax.lax.scan(body_fun, 0.0, indices,)
    advantages = advantages[::-1]

Also avoids using the .at and .set functions (of which im still not sure of what the performance is). Maybe this might be useful.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can use reverse=True in the scan so you don't have to flip it.

@vwxyzjn vwxyzjn changed the title Jax ppo envpool JAX + PPO + EnvPool + MuJoCo Jul 12, 2022
@vwxyzjn vwxyzjn changed the title JAX + PPO + EnvPool + MuJoCo PPO + JAX + EnvPool + MuJoCo Jul 12, 2022
@nico-bohlinger
Copy link

Jitting the epochs in update_ppo() results in extremely high start up times for high epoch values and doesn't provide any speed after it's finally running.
Bringing the epoch loop in the main function would fix that, like:

for _ in range(args.update_epochs):
   agent_state, loss, pg_loss, v_loss, approx_kl, key = update_ppo(agent_state, storage, key)

Comment on lines +202 to +206
envs = gym.wrappers.ClipAction(envs)
envs = gym.wrappers.NormalizeObservation(envs)
envs = gym.wrappers.TransformObservation(envs, lambda obs: np.clip(obs, -10, 10))
envs = gym.wrappers.NormalizeReward(envs)
envs = gym.wrappers.TransformReward(envs, lambda reward: np.clip(reward, -10, 10))
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is desirable to implement these in jax, which should help speed up the training progress and will allow us to use the XLA interface in the future.

@vwxyzjn vwxyzjn mentioned this pull request Oct 31, 2022
20 tasks
@vwxyzjn vwxyzjn mentioned this pull request Nov 21, 2022
20 tasks
@51616
Copy link
Collaborator

51616 commented Nov 25, 2022

I think it's worth changing to lax.scan and fori_loop. Removing the for loop within rollout increases the speed quite a bit. Significantly reduces the complication time. I can make a pull request for this (and for compute_gae and update_ppo as well). I compared the original rollout and the lax.scan implementation and got the following results:

# Original for loop
Total data collection time: 135.69225978851318 seconds
Total data collection time without compilation: 98.75351285934448 seconds
Approx. compilation time: 36.93875765800476 seconds
# with lax.scan
Total data collection time: 60.91851544380188 seconds
Total data collection time without compilation: 60.029022455215454 seconds
Approx. compilation time: 0.8895087242126465 seconds

The command used is: python cleanrl/ppo_atari_envpool_xla_jax.py --env-id Breakout-v5 --total-timesteps 500000 --num-envs 32

Note: The training code was removed as the collection time correlates with the avg_episodic_length, which depends on the random exploration and training dynamics. Removing the training part makes sure that the numbers in the test only relate to the rollout function.

@vwxyzjn
Copy link
Owner Author

vwxyzjn commented Nov 25, 2022

@51616 thanks for raising this issue. Could you share the snippet that derived these numbers?

Does lax.scan reduce the rollout time after compilation is finished? nvm I misread something. It’s interesting the rollout time after compilation is much faster, and this would be a good reason to consider using scan. Would you mind preparing the PR?

@51616
Copy link
Collaborator

51616 commented Nov 26, 2022

@vwxyzjn Here's the code

    def step_once(carry, step, env_step_fn):
        (agent_state, episode_stats, next_obs, next_done, storage, key, handle) = carry
        storage, action, key = get_action_and_value(agent_state, next_obs, next_done, storage, step, key)
        episode_stats, handle, (next_obs, reward, next_done, _) = env_step_fn(episode_stats, handle, action)
        storage = storage.replace(rewards=storage.rewards.at[step].set(reward))
        return ((agent_state, episode_stats, next_obs, next_done, storage, key, handle), None)
    
    def rollout(agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step,
                step_once_fn, max_steps):
        
        (agent_state, episode_stats, next_obs, next_done, storage, key, handle), _ = jax.lax.scan(
            step_once_fn,
            (agent_state, episode_stats, next_obs, next_done, storage, key, handle), (), max_steps)
        
        global_step += max_steps * args.num_envs
        return agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step
    
    rollout_fn = partial(rollout,
                         step_once_fn=partial(step_once, env_step_fn=step_env_wrappeed),
                         max_steps=args.num_steps)
    
    for update in range(1, args.num_updates + 1):
        update_time_start = time.time()
        agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step = rollout_fn(
            agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step
        )
        if update == 1:
            start_time_wo_compilation = time.time()
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
        print("SPS_update:", int(args.num_envs * args.num_steps / (time.time() - update_time_start)))
        writer.add_scalar(
            "charts/SPS_update", int(args.num_envs * args.num_steps / (time.time() - update_time_start)), global_step
        )
    print("Total data collection time:", time.time() - start_time, "seconds")
    print("Total data collection time without compilation:", time.time() - start_time_wo_compilation, "seconds")
    print("Approx. compilation time:", start_time_wo_compilation - start_time, "seconds")
    envs.close()
    writer.close()

I can make a PR for this. I also think we should use the output of the lax.scan as opposed to replacing the value inplace. Might look something like this

    def step_once(carry, step, env_step_fn):
        (agent_state, episode_stats, obs, done, key, handle) = carry
        action, logprob, value, key = get_action_and_value(agent_state, obs, key)
        
        episode_stats, handle, (next_obs, reward, next_done, _) = env_step_fn(episode_stats, handle, action)
        
        storage = Storage(
            obs=obs,
            actions=action,
            logprobs=logprob,
            dones=done,
            values=value,
            rewards=reward,
            returns=jnp.zeros_like(reward),
            advantages=jnp.zeros_like(reward),
        )
        
        return ((agent_state, episode_stats, next_obs, next_done, key, handle), storage)
    
    def rollout(agent_state, episode_stats, next_obs, next_done, key, handle,
                step_once_fn, max_steps):
        
        (agent_state, episode_stats, next_obs, next_done, key, handle), storage = jax.lax.scan(
            step_once_fn,
            (agent_state, episode_stats, next_obs, next_done, key, handle), (), max_steps)
        
        return agent_state, episode_stats, next_obs, next_done, key, handle, storage
    
    rollout_fn = partial(rollout,
                         step_once_fn=partial(step_once, env_step_fn=step_env_wrappeed),
                         max_steps=args.num_steps)
    
    for update in range(1, args.num_updates + 1):
        update_time_start = time.time()
        agent_state, episode_stats, next_obs, next_done, key, handle, storage = rollout_fn(
            agent_state, episode_stats, next_obs, next_done, key, handle
        )
        if update == 1:
            start_time_wo_compilation = time.time()
        global_step += args.num_steps * args.num_envs
        ...

The code is a bit cleaner and uses the output from lax.scan directly

@vwxyzjn
Copy link
Owner Author

vwxyzjn commented Jan 13, 2023

image

@pseudo-rnd-thoughts
Copy link
Collaborator

@vwxyzjn Was there any reason why this wasn't merged in the end?

@vwxyzjn
Copy link
Owner Author

vwxyzjn commented Aug 30, 2023

Nothing really. If you’d like free free to take on the PR :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

8 participants