
PPOClip grad update seems to cause inf update #5

Open · glmcdona opened this issue Oct 22, 2021 · 3 comments

Comments

glmcdona commented Oct 22, 2021

Describe the bug
Hey Kris, love your framework! I'm working with a custom environment, and your discrete-action unit test works perfectly locally. I haven't spent much time investigating this yet; I'm just creating this issue in case something jumps out at you as the problem. I plan on continuing to debug it.

During the first PPOClip update with the custom gym environment, the model weights get changed to +/-inf even though the gradients are finite.

Expected behavior

...
adv = np.random.rand(32)
grads, function_state, metrics = ppo_clip.grads_and_metrics(transition_batch, Adv=adv)
print("grads", grads)
print(ppo_clip._pi.params)
metrics_pi = ppo_clip.update(transition_batch, Adv=adv) # This is the problem
print(ppo_clip._pi.params)

Results in:

grads FlatMapping({
  'linear': FlatMapping({
              'b': DeviceArray([ 0.0477 , -0.02505, -0.05048,  0.02798], dtype=float16),
              'w': DeviceArray([[ 0.01338 , -0.01921 , -0.01038 ,  0.01622 ],
                                [ 0.02406 , -0.01683 , -0.02039 ,  0.01316 ],
                                [ 0.0332  , -0.0227  , -0.03108 ,  0.02061 ],
                                ...,
                                [ 0.02452 , -0.00956 , -0.01997 ,  0.005024],
                                [ 0.010025,  0.001724, -0.03467 ,  0.02295 ],
                                [ 0.01886 , -0.01413 , -0.01494 ,  0.01022 ]], dtype=float16),
            }),
FlatMapping({
  'linear': FlatMapping({
              'w': DeviceArray([[-1.0124e-02,  3.4389e-03,  2.9316e-03,  6.5498e-03],
                                [ 3.3302e-03, -1.7233e-03, -3.0422e-03, -1.8060e-04],
                                [-2.8908e-05, -3.3131e-03, -6.1073e-03,  6.5804e-03],
                                ...,
                                [-2.5597e-03,  7.3471e-03, -3.6221e-03, -5.6801e-03],
                                [-7.3471e-03, -3.7746e-03,  5.8746e-03,  6.1531e-03],
                                [-1.1940e-03,  6.9733e-03, -5.0507e-03,  3.4218e-03]],            dtype=float16),
              'b': DeviceArray([0., 0., 0., 0.], dtype=float16),
            }),
})
FlatMapping({
  'linear': FlatMapping({
              'b': DeviceArray([-0.001002,  0.000978,  0.001001, -0.001007], dtype=float16),
              'w': DeviceArray([[-0.01111  ,  0.004448 ,  0.00386  ,  0.00551  ],
                                [ 0.002354 , -0.0007563, -0.002048 , -0.001162 ],
                                [-0.001021 , -0.002335 , -0.005104 ,  0.005558 ],
                                ...,
                                [-0.003561 ,  0.008224 , -0.002628 ,       -inf],
                                [-0.00828  ,       -inf,  0.006874 ,  0.00515  ],
                                [-0.002203 ,  0.00804  , -0.004086 ,  0.002493 ]],            dtype=float16),
            }),

Here is the full repro script, taken from the Pong PPO example and slightly modified; it won't run elsewhere because it depends on the custom environment. The policy and value networks below are dummy examples, not the actual networks that would be used:

import os
from luxai2021.env.lux_env import LuxEnvironment, LuxEnvironmentTeam
from luxai2021.game.game import Game
from luxai2021.game.actions import *
from luxai2021.game.constants import LuxMatchConfigs_Default

from luxai2021.env.agent import Agent, AgentWithTeamModel
import numpy as np

from agent import TeamAgent

# set some env vars
os.environ.setdefault('JAX_PLATFORM_NAME', 'cpu')     # tell JAX to use the CPU
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.1'  # don't use all gpu mem
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'              # tell XLA to be quiet

import gym
import jax
import coax
import haiku as hk
import jax.numpy as jnp
from optax import adam


# the name of this script
name = 'ppo'

configs = LuxMatchConfigs_Default

player = TeamAgent(mode="train")
opponent = Agent()

env = LuxEnvironment(configs=configs,
                     learning_agent=player,
                     opponent_agent=opponent)
env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f"./data/tensorboard/{name}")

def func_pi(S, is_training):
    n_actions = 4
    out = {'logits': hk.Linear(n_actions)(hk.Flatten()(S)) }
    return out

def func_v(S, is_training):
    h = jnp.ravel(hk.Linear(1)(hk.Flatten()(S)))
    return h

'''
def func_pi(S, is_training):
    #print(env.action_space.shape)
    n_filters = 5
    n_actions = 4
    n_layers = 3

    h = hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(S)
    for layer in range(n_layers):
        h = jax.nn.relu(h + hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(h))
    
    print('h', type(h), h.shape)
    h_head = (h * S[:,:1]).reshape(h.shape[0], h.shape[1], -1).sum(-1) # torch.Size([1, N_LAYERS])
    h_head_actions = hk.Linear(n_actions)(h_head)
    print('h_head_actions', type(h_head_actions), h_head_actions.shape)
    #print(h_head_actions)

    out = {'logits': h_head_actions}
    
    return out

def func_v(S, is_training):
    n_filters = 5
    n_layers = 3

    h = hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(S)
    for layer in range(n_layers):
        h = jax.nn.relu(hk.Conv2D(n_filters, kernel_shape=3, stride=2, data_format='NCHW')(h))

    h = hk.Flatten()(h)
    h = jax.nn.relu(hk.Linear(64)(h))
    h = jnp.ravel(hk.Linear(1, w_init=jnp.zeros)(h))
    
    return h
'''


# function approximators
pi = coax.Policy(func_pi, env)
v = coax.V(func_v, env)

# target networks
pi_behavior = pi.copy()
v_targ = v.copy()

# policy regularizer (avoid premature exploitation)
entropy = coax.regularizers.EntropyRegularizer(pi, beta=0.001)

# updaters
simpletd = coax.td_learning.SimpleTD(v, v_targ, optimizer=adam(3e-4))
ppo_clip = coax.policy_objectives.PPOClip(pi, regularizer=entropy, optimizer=adam(3e-4))

# reward tracer and replay buffer
tracer = coax.reward_tracing.NStep(n=5, gamma=0.99)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)

# run episodes
max_episode_steps = 400
while env.T < 3000000:
    s = env.reset()

    for t in range(max_episode_steps):
        print(t)
        a, logp = pi_behavior(s, return_logp=True)
        s_next, r, done, info = env.step(a)

        # trace rewards and add transition to replay buffer
        tracer.add(s, a, r, done, logp)
        while tracer:
            buffer.add(tracer.pop())

        # learn
        if len(buffer) >= buffer.capacity:
            num_batches = int(4 * buffer.capacity / 32)  # 4 epochs per round
            for i in range(num_batches):
                transition_batch = buffer.sample(32)
                grads, function_state, metrics, td_error = simpletd.grads_and_metrics(transition_batch)
                metrics_v, td_error = simpletd.update(transition_batch, return_td_error=True)

                
                adv = np.random.rand(32)
                grads, function_state, metrics = ppo_clip.grads_and_metrics(transition_batch, Adv=adv)
                print("grads", grads)
                print(ppo_clip._pi.params)
                metrics_pi = ppo_clip.update(transition_batch, Adv=adv) # This is the problem
                print(ppo_clip._pi.params)
                exit()
                env.record_metrics(metrics_pi)
                env.record_metrics(metrics_v)
                

            buffer.clear()

            # sync target networks
            pi_behavior.soft_update(pi, tau=0.1)
            v_targ.soft_update(v, tau=0.1)

        if done:
            break

        s = s_next

    # generate an animated GIF to see what's going on
    if env.period(name='generate_gif', T_period=10000) and env.T > 50000:
        T = env.T - env.T % 10000  # round to 10000s
        coax.utils.generate_gif(
            env=env, policy=pi, resize_to=(320, 420),
            filepath=f"./data/gifs/{name}/T{T:08d}.gif")


royerk commented Oct 22, 2021

Hello,

To add to @glmcdona's report: I'm getting the exact same issue, but with a Box action space (if that makes any difference). After the update with the first minibatch, the networks are filled with NaNs.

I will try to replicate it with a classic gym env (by the way, I think the Pendulum-v0 used in the examples is deprecated).
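Roughly the shape of the repro I have in mind, as an untested sketch: it assumes Pendulum-v1 (in place of the deprecated Pendulum-v0), the {'mu', 'logvar'} policy head that the coax Pendulum example uses for Box action spaces, and the same old-style gym step API as the script above.

import gym
import coax
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from optax import adam

env = gym.make('Pendulum-v1')

def func_pi(S, is_training):
    # Gaussian head for the Box action space: a dict with 'mu' and 'logvar'
    h = jax.nn.relu(hk.Linear(8)(hk.Flatten()(S)))
    mu = hk.Linear(int(np.prod(env.action_space.shape)))(h)
    logvar = hk.Linear(int(np.prod(env.action_space.shape)))(h)
    return {'mu': hk.Reshape(env.action_space.shape)(mu),
            'logvar': hk.Reshape(env.action_space.shape)(logvar)}

def func_v(S, is_training):
    h = jax.nn.relu(hk.Linear(8)(hk.Flatten()(S)))
    return jnp.ravel(hk.Linear(1)(h))

pi = coax.Policy(func_pi, env)
v = coax.V(func_v, env)
pi_behavior, v_targ = pi.copy(), v.copy()

simpletd = coax.td_learning.SimpleTD(v, v_targ, optimizer=adam(3e-4))
ppo_clip = coax.policy_objectives.PPOClip(pi, optimizer=adam(3e-4))

tracer = coax.reward_tracing.NStep(n=5, gamma=0.99)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)

# fill the buffer, then do a single update and inspect the params
s = env.reset()
while len(buffer) < buffer.capacity:
    a, logp = pi_behavior(s, return_logp=True)
    s_next, r, done, info = env.step(a)
    tracer.add(s, a, r, done, logp)
    while tracer:
        buffer.add(tracer.pop())
    s = env.reset() if done else s_next

transition_batch = buffer.sample(32)
metrics_v, td_error = simpletd.update(transition_batch, return_td_error=True)
metrics_pi = ppo_clip.update(transition_batch, Adv=td_error)
print(ppo_clip._pi.params)  # check whether any leaves turned into NaN/inf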

glmcdona (Author) commented

This error only occurs with the optax adam optimizer; the workaround is to use the sgd optimizer. The error does not reproduce with TestPPOClip->test_update_discrete() or the example Pong PPO with the adam optimizer. Maybe close this issue unless a reliable repro can be created?
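For reference, the swap is just at updater construction time; optax also ships generic guards (clip_by_global_norm, apply_if_finite) that might help narrow things down, though I haven't tried them on this repro:

from optax import adam, sgd, chain, clip_by_global_norm, apply_if_finite

# workaround: sgd instead of adam avoids the inf params
ppo_clip = coax.policy_objectives.PPOClip(pi, regularizer=entropy, optimizer=sgd(3e-4))

# untested alternative: keep adam, but clip the grads and skip steps whose grads are non-finite
guarded = apply_if_finite(chain(clip_by_global_norm(1.0), adam(3e-4)), max_consecutive_errors=5)
ppo_clip = coax.policy_objectives.PPOClip(pi, regularizer=entropy, optimizer=guarded)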

KristianHolsheimer (Contributor) commented

Hi Geoff! Thanks for telling me about this one.

It's very surprising that replacing optax.adam with optax.sgd seems to help. Perhaps the adam accumulators are contaminated by a non-finite gradient somewhere?

Would it be possible to share a Colab notebook?
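In the meantime, a quick check you could run after the failing update, just as a sketch: I'm assuming here that the updater exposes its optax state as ppo_clip.optimizer_state (please correct me if the attribute is named differently), and count_nonfinite is a made-up helper.

import jax
import jax.numpy as jnp

def count_nonfinite(pytree):
    # number of array leaves in the pytree that contain any NaN/inf
    return sum(int(not jnp.all(jnp.isfinite(leaf))) for leaf in jax.tree_util.tree_leaves(pytree))

print(count_nonfinite(grads))                     # the raw gradients
print(count_nonfinite(ppo_clip.optimizer_state))  # adam's accumulators (assumed attribute name)
print(count_nonfinite(ppo_clip._pi.params))       # the parameters themselves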
