Calling JAXAgent train gets stuck if using larger image sizes (inside Ninjax) #85

Open
schneimo opened this issue Aug 29, 2023 · 3 comments

@schneimo

Hi Danijar,

I am currently trying to use higher image resolutions, such as 256x256, with Dreamer. When I simply change the resolution, e.g. for the DM Control Suite, JAX is no longer able to trace/compile the training function:

python dreamerv3/train.py --logdir logs/test --configs dmc_vision --task dmc_cartpole_swingup --env.dmc.size 256 256

Instead of raising an error, however, the program seems to get stuck at or after the point where it tries to trace the training function with JAX:

Config:
seed:                                          0                                           (int)
method:                                        name                                        (str)
task:                                          dmc_cartpole_swingup                        (str)
logdir:                                        logs/test                                   (str)
replay:                                        reverb                                      (str)
replay_size:                                   1000000.0                                   (float)
replay_online:                                 False                                       (bool)
replay_save:                                   False                                       (bool)
eval_dir:                                                                                  (str)
filter:                                        .*                                          (str)
jax.platform:                                  gpu                                         (str)
jax.jit:                                       True                                        (bool)
jax.precision:                                 float16                                     (str)
jax.prealloc:                                  True                                        (bool)
jax.debug_nans:                                False                                       (bool)
jax.logical_cpus:                              0                                           (int)
jax.debug:                                     True                                        (bool)
jax.policy_devices:                            [0]                                         (ints)
jax.train_devices:                             [0]                                         (ints)
jax.metrics_every:                             10                                          (int)
run.script:                                    train_eval                                  (str)
run.steps:                                     1250000.0                                   (float)
run.expl_until:                                0                                           (int)
run.log_every:                                 300                                         (int)
run.save_every:                                900                                         (int)
run.eval_every:                                50000.0                                     (float)
run.eval_initial:                              True                                        (bool)
run.eval_eps:                                  10                                          (int)
run.eval_samples:                              1                                           (int)
run.train_ratio:                               512.0                                       (float)
run.train_fill:                                0                                           (int)
run.eval_fill:                                 0                                           (int)
run.log_zeros:                                 False                                       (bool)
run.log_keys_video:                            [image]                                     (strs)
run.log_keys_sum:                              ^$                                          (str)
run.log_keys_mean:                             (log_entropy)                               (str)
run.log_keys_max:                              ^$                                          (str)
run.from_checkpoint:                                                                       (str)
run.sync_every:                                10                                          (int)
run.actor_addr:                                ipc:///tmp/5551                             (str)
run.actor_batch:                               32                                          (int)
envs.amount:                                   8                                           (int)
envs.parallel:                                 process                                     (str)
envs.length:                                   0                                           (int)
envs.reset:                                    True                                        (bool)
envs.restart:                                  True                                        (bool)
envs.discretize:                               0                                           (int)
envs.checks:                                   False                                       (bool)
envs.is_vec:                                   False                                       (bool)
wrapper.length:                                0                                           (int)
wrapper.reset:                                 True                                        (bool)
wrapper.discretize:                            0                                           (int)
wrapper.checks:                                False                                       (bool)
env.atari.size:                                [64, 64]                                    (ints)
env.atari.repeat:                              4                                           (int)
env.atari.sticky:                              True                                        (bool)
env.atari.gray:                                False                                       (bool)
env.atari.actions:                             all                                         (str)
env.atari.lives:                               unused                                      (str)
env.atari.noops:                               0                                           (int)
env.atari.resize:                              opencv                                      (str)
env.dmlab.size:                                [64, 64]                                    (ints)
env.dmlab.repeat:                              4                                           (int)
env.dmlab.episodic:                            True                                        (bool)
env.minecraft.size:                            [64, 64]                                    (ints)
env.minecraft.break_speed:                     100.0                                       (float)
env.dmc.size:                                  [256, 256]                                  (ints)
env.dmc.repeat:                                2                                           (int)
env.dmc.camera:                                -1                                          (int)
env.loconav.size:                              [64, 64]                                    (ints)
env.loconav.repeat:                            2                                           (int)
env.loconav.camera:                            -1                                          (int)
task_behavior:                                 Greedy                                      (str)
expl_behavior:                                 None                                        (str)
batch_size:                                    16                                          (int)
batch_length:                                  64                                          (int)
data_loaders:                                  8                                           (int)
grad_heads:                                    [decoder, reward, cont]                     (strs)
rssm.deter:                                    512                                         (int)
rssm.units:                                    512                                         (int)
rssm.stoch:                                    32                                          (int)
rssm.classes:                                  32                                          (int)
rssm.act:                                      silu                                        (str)
rssm.norm:                                     layer                                       (str)
rssm.initial:                                  learned                                     (str)
rssm.unimix:                                   0.01                                        (float)
rssm.unroll:                                   False                                       (bool)
rssm.action_clip:                              1.0                                         (float)
rssm.winit:                                    normal                                      (str)
rssm.fan:                                      avg                                         (str)
encoder.mlp_keys:                              $^                                          (str)
encoder.cnn_keys:                              image                                       (str)
encoder.act:                                   silu                                        (str)
encoder.norm:                                  layer                                       (str)
encoder.mlp_layers:                            5                                           (int)
encoder.mlp_units:                             1024                                        (int)
encoder.cnn:                                   resnet                                      (str)
encoder.cnn_depth:                             32                                          (int)
encoder.cnn_blocks:                            0                                           (int)
encoder.resize:                                stride                                      (str)
encoder.winit:                                 normal                                      (str)
encoder.fan:                                   avg                                         (str)
encoder.symlog_inputs:                         True                                        (bool)
encoder.minres:                                4                                           (int)
decoder.mlp_keys:                              $^                                          (str)
decoder.cnn_keys:                              image                                       (str)
decoder.act:                                   silu                                        (str)
decoder.norm:                                  layer                                       (str)
decoder.mlp_layers:                            5                                           (int)
decoder.mlp_units:                             1024                                        (int)
decoder.cnn:                                   resnet                                      (str)
decoder.cnn_depth:                             32                                          (int)
decoder.cnn_blocks:                            0                                           (int)
decoder.image_dist:                            mse                                         (str)
decoder.vector_dist:                           symlog_mse                                  (str)
decoder.inputs:                                [deter, stoch]                              (strs)
decoder.resize:                                stride                                      (str)
decoder.winit:                                 normal                                      (str)
decoder.fan:                                   avg                                         (str)
decoder.outscale:                              1.0                                         (float)
decoder.minres:                                4                                           (int)
decoder.cnn_sigmoid:                           False                                       (bool)
reward_head.layers:                            2                                           (int)
reward_head.units:                             512                                         (int)
reward_head.act:                               silu                                        (str)
reward_head.norm:                              layer                                       (str)
reward_head.dist:                              symlog_disc                                 (str)
reward_head.outscale:                          0.0                                         (float)
reward_head.outnorm:                           False                                       (bool)
reward_head.inputs:                            [deter, stoch]                              (strs)
reward_head.winit:                             normal                                      (str)
reward_head.fan:                               avg                                         (str)
reward_head.bins:                              255                                         (int)
cont_head.layers:                              2                                           (int)
cont_head.units:                               512                                         (int)
cont_head.act:                                 silu                                        (str)
cont_head.norm:                                layer                                       (str)
cont_head.dist:                                binary                                      (str)
cont_head.outscale:                            1.0                                         (float)
cont_head.outnorm:                             False                                       (bool)
cont_head.inputs:                              [deter, stoch]                              (strs)
cont_head.winit:                               normal                                      (str)
cont_head.fan:                                 avg                                         (str)
loss_scales.image:                             1.0                                         (float)
loss_scales.vector:                            1.0                                         (float)
loss_scales.reward:                            1.0                                         (float)
loss_scales.cont:                              1.0                                         (float)
loss_scales.dyn:                               0.5                                         (float)
loss_scales.rep:                               0.1                                         (float)
loss_scales.actor:                             1.0                                         (float)
loss_scales.critic:                            1.0                                         (float)
loss_scales.slowreg:                           1.0                                         (float)
dyn_loss.impl:                                 kl                                          (str)
dyn_loss.free:                                 1.0                                         (float)
rep_loss.impl:                                 kl                                          (str)
rep_loss.free:                                 1.0                                         (float)
model_opt.opt:                                 adam                                        (str)
model_opt.lr:                                  0.0001                                      (float)
model_opt.eps:                                 1e-08                                       (float)
model_opt.clip:                                1000.0                                      (float)
model_opt.wd:                                  0.0                                         (float)
model_opt.warmup:                              0                                           (int)
model_opt.lateclip:                            0.0                                         (float)
actor.layers:                                  2                                           (int)
actor.units:                                   512                                         (int)
actor.act:                                     silu                                        (str)
actor.norm:                                    layer                                       (str)
actor.minstd:                                  0.1                                         (float)
actor.maxstd:                                  1.0                                         (float)
actor.outscale:                                1.0                                         (float)
actor.outnorm:                                 False                                       (bool)
actor.unimix:                                  0.01                                        (float)
actor.inputs:                                  [deter, stoch]                              (strs)
actor.winit:                                   normal                                      (str)
actor.fan:                                     avg                                         (str)
actor.symlog_inputs:                           False                                       (bool)
critic.layers:                                 2                                           (int)
critic.units:                                  512                                         (int)
critic.act:                                    silu                                        (str)
critic.norm:                                   layer                                       (str)
critic.dist:                                   symlog_disc                                 (str)
critic.outscale:                               0.0                                         (float)
critic.outnorm:                                False                                       (bool)
critic.inputs:                                 [deter, stoch]                              (strs)
critic.winit:                                  normal                                      (str)
critic.fan:                                    avg                                         (str)
critic.bins:                                   255                                         (int)
critic.symlog_inputs:                          False                                       (bool)
actor_opt.opt:                                 adam                                        (str)
actor_opt.lr:                                  3e-05                                       (float)
actor_opt.eps:                                 1e-05                                       (float)
actor_opt.clip:                                100.0                                       (float)
actor_opt.wd:                                  0.0                                         (float)
actor_opt.warmup:                              0                                           (int)
actor_opt.lateclip:                            0.0                                         (float)
critic_opt.opt:                                adam                                        (str)
critic_opt.lr:                                 3e-05                                       (float)
critic_opt.eps:                                1e-05                                       (float)
critic_opt.clip:                               100.0                                       (float)
critic_opt.wd:                                 0.0                                         (float)
critic_opt.warmup:                             0                                           (int)
critic_opt.lateclip:                           0.0                                         (float)
actor_dist_disc:                               onehot                                      (str)
actor_dist_cont:                               normal                                      (str)
actor_grad_disc:                               reinforce                                   (str)
actor_grad_cont:                               backprop                                    (str)
critic_type:                                   vfunction                                   (str)
imag_horizon:                                  15                                          (int)
imag_unroll:                                   False                                       (bool)
horizon:                                       333                                         (int)
return_lambda:                                 0.95                                        (float)
critic_slowreg:                                logprob                                     (str)
slow_critic_update:                            1                                           (int)
slow_critic_fraction:                          0.02                                        (float)
retnorm.impl:                                  perc_ema                                    (str)
retnorm.decay:                                 0.99                                        (float)
retnorm.max:                                   1.0                                         (float)
retnorm.perclo:                                5.0                                         (float)
retnorm.perchi:                                95.0                                        (float)
actent:                                        0.0003                                      (float)
expl_rewards.extr:                             1.0                                         (float)
expl_rewards.disag:                            0.1                                         (float)
expl_opt.opt:                                  adam                                        (str)
expl_opt.lr:                                   0.0001                                      (float)
expl_opt.eps:                                  1e-05                                       (float)
expl_opt.clip:                                 100.0                                       (float)
expl_opt.wd:                                   0.0                                         (float)
expl_opt.warmup:                               0                                           (int)
disag_head.layers:                             2                                           (int)
disag_head.units:                              512                                         (int)
disag_head.act:                                silu                                        (str)
disag_head.norm:                               layer                                       (str)
disag_head.dist:                               mse                                         (str)
disag_head.outscale:                           1.0                                         (float)
disag_head.inputs:                             [deter, stoch, action]                      (strs)
disag_head.winit:                              normal                                      (str)
disag_head.fan:                                avg                                         (str)
disag_target:                                  [stoch]                                     (strs)
disag_models:                                  8                                           (int)
Encoder CNN shapes: {'image': (256, 256, 3)}
Encoder MLP shapes: {}
Decoder CNN shapes: {'image': (256, 256, 3)}
Decoder MLP shapes: {}
JAX devices (1): [gpu(id=0)]
Policy devices: gpu:0
Train devices:  gpu:0
Tracing train function.
Optimizer model_opt has 61,839,491 variables.
Optimizer actor_opt has 1,051,650 variables.
Optimizer critic_opt has 1,181,439 variables.
Logdir logs/test
Observation space:
  reward           Space(dtype=float32, shape=(), low=-inf, high=inf)
  is_first         Space(dtype=bool, shape=(), low=False, high=True)
  is_last          Space(dtype=bool, shape=(), low=False, high=True)
  is_terminal      Space(dtype=bool, shape=(), low=False, high=True)
  image            Space(dtype=uint8, shape=(256, 256, 3), low=0, high=255)
Action space:
  reset            Space(dtype=bool, shape=(), low=False, high=True)
  action           Space(dtype=float32, shape=(1,), low=-1.0, high=1.0)
Prefill train dataset.
[reverb/cc/platform/tfrecord_checkpointer.cc:162]  Initializing TFRecordCheckpointer in /tmp/tmpfkaj__gy.
[reverb/cc/platform/tfrecord_checkpointer.cc:567] Loading latest checkpoint from /tmp/tmpfkaj__gy
[reverb/cc/platform/default/server.cc:71] Started replay server on port 15055
Prefill eval dataset.
Found existing checkpoint.
Loading checkpoint: logs/test/checkpoint.ckpt
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (2594171) so Table table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (2594171) so Table table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (2594171) so Table table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (2594171) so Table table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (2594171) so Table table is accessed directly without gRPC.
Loaded checkpoint from 967 seconds ago.
Start training loop.
Starting evaluation at step 1560
Tracing policy function.
Tracing policy function.
Episode has 500 steps with return 161.2.
Episode has 500 steps with return 95.8.
Episode has 500 steps with return 66.3.
Episode has 500 steps with return 98.4.
Episode has 500 steps with return 117.0.
Episode has 500 steps with return 111.9.
Episode has 500 steps with return 102.0.
Episode has 500 steps with return 56.4.
Episode has 500 steps with return 88.9.
Episode has 500 steps with return 91.3.
Episode has 500 steps with return 117.6.
Episode has 500 steps with return 89.3.
Episode has 500 steps with return 158.3.
Episode has 500 steps with return 81.8.
Episode has 500 steps with return 43.6.
Episode has 500 steps with return 86.3.
Tracing policy function.
Tracing train function.

I have tested this on a V100 and an A100, with the same result on both. With smaller resolutions (e.g. 128x128 or 64x64), training works as expected.
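To get a feel for how much heavier 256x256 is than 64x64, the sketch below computes the encoder's stage shapes. It rests on an assumption about the architecture that this thread does not confirm: that with `encoder.resize: stride` each CNN stage halves the spatial resolution until it reaches `encoder.minres: 4`, and that the channel count starts at `encoder.cnn_depth: 32` and doubles per stage. `encoder_stages` is a hypothetical helper, not part of DreamerV3.

```python
import math


def encoder_stages(size, minres=4, depth=32):
  """Return a list of (resolution, channels) per stride-2 stage.

  Assumption (hypothetical): each stage halves the resolution and the
  channel count doubles after every stage, until `minres` is reached.
  """
  stages = int(math.log2(size) - math.log2(minres))
  shapes = []
  res, ch = size, depth
  for _ in range(stages):
    res //= 2              # stride-2 convolution halves height and width
    shapes.append((res, ch))
    ch *= 2                # next stage is twice as wide
  return shapes


if __name__ == '__main__':
  for size in (64, 128, 256):
    shapes = encoder_stages(size)
    res, ch = shapes[-1]
    print(f'{size}x{size}: {len(shapes)} stages, '
          f'channels {[c for _, c in shapes]}, '
          f'flattened output {res * res * ch}')
  # 64x64: 4 stages, channels [32, 64, 128, 256], flattened output 4096
  # 128x128: 5 stages, channels [32, 64, 128, 256, 512], flattened output 8192
  # 256x256: 6 stages, channels [32, 64, 128, 256, 512, 1024], flattened output 16384
```

Under this assumption, going from 64x64 to 256x256 adds two convolution stages and quadruples both the widest layer and the flattened feature size, and the mirrored decoder would grow in the same way.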

I tried to debug this, but I was not able to track the problem down inside Ninjax or JAX.

Thanks a lot for your help!

@edwhu

edwhu commented Sep 15, 2023

Sometimes tracing can take a while on older GPUs; I've waited around 10 minutes on a TitanX workstation before.

You can try making the CNN smaller to see whether that reduces compilation time. You can also increase the resolution incrementally and check whether the trace time grows.
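One way to run the incremental experiment suggested above is to time tracing/lowering and XLA compilation separately with `jax.jit(...).lower(...)` and `.compile()`, so it is clear which phase stalls. This is a sketch under stated assumptions: it uses a hypothetical toy encoder (stride-2 4x4 convolutions, 32 base channels doubling per stage down to 4x4, loosely modeled on `encoder.cnn_depth: 32` and `encoder.minres: 4` from the config) rather than the actual DreamerV3 networks. It passes `jax.ShapeDtypeStruct` placeholders, so no arrays are allocated; `make_loss`, `param_shapes`, and `time_compile` are illustrative names.

```python
import math
import time

import jax
import jax.numpy as jnp


def make_loss(stages):
  """Toy loss: a stack of stride-2 convolutions followed by a mean."""
  def loss(params, images):
    x = images
    for w in params:
      x = jax.lax.conv_general_dilated(
          x, w, window_strides=(2, 2), padding='SAME',
          dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
      x = jax.nn.silu(x)
    return jnp.mean(x ** 2)
  return loss


def param_shapes(stages, channels=3, depth=32):
  """Abstract kernel shapes (HWIO); channel count doubles per stage."""
  shapes, in_ch = [], channels
  for _ in range(stages):
    shapes.append(jax.ShapeDtypeStruct((4, 4, in_ch, depth), jnp.float32))
    in_ch, depth = depth, depth * 2
  return shapes


def time_compile(size, minres=4, batch=16):
  """Return (seconds to trace+lower, seconds to XLA-compile) for one size."""
  stages = int(math.log2(size) - math.log2(minres))
  grad_fn = jax.jit(jax.value_and_grad(make_loss(stages)))
  params = param_shapes(stages)
  images = jax.ShapeDtypeStruct((batch, size, size, 3), jnp.float32)
  t0 = time.perf_counter()
  lowered = grad_fn.lower(params, images)  # Python tracing + lowering
  t1 = time.perf_counter()
  lowered.compile()                         # XLA compilation only
  t2 = time.perf_counter()
  return t1 - t0, t2 - t1


if __name__ == '__main__':
  for size in (64, 128, 256):
    trace_s, compile_s = time_compile(size)
    print(f'{size}x{size}: lower {trace_s:.2f}s, compile {compile_s:.2f}s')
```

If lowering stays fast while compile time climbs steeply with resolution, the bottleneck is more likely XLA compilation than Python-side tracing; swapping the toy loss for the real train step would be needed to confirm that for DreamerV3 itself.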

@schneimo
Author

Thanks.

I am not sure that time and compute power are really the problem: even after 24 hours, tracing had not finished on an A100. Still, I will test how the tracing time grows with increasing image resolution and report my findings here.

@schneimo
Author

I looked into this a bit more and found that the train function of the Agent class runs to completion: when I decorate it with an additional timer, the timer is executed.
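A timing decorator like the one described above might look as follows. This is a minimal sketch; `timed` is a hypothetical helper, not part of DreamerV3 or Ninjax. One caveat: when the decorated function is traced by `jax.jit`, its Python body runs at trace time, so the printed duration reflects Python-side tracing rather than GPU execution.

```python
import functools
import time


def timed(name=None):
  """Decorator factory that prints how long each call took."""
  def decorator(fn):
    label = name or fn.__name__
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      start = time.perf_counter()
      try:
        return fn(*args, **kwargs)
      finally:
        # Runs whether fn returns or raises, so a missing print means
        # the call never finished.
        print(f'{label} took {time.perf_counter() - start:.3f}s')
    return wrapper
  return decorator


@timed()
def add(a, b):
  return a + b
```

Calling `add(2, 3)` returns `5` and prints a line like `add took 0.000s`.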

Furthermore, I narrowed the problem down a bit more, and it seems to arise in the try block of pure() inside the Ninjax module:

def pure(fun, nested=False):
  """Wrap an impure function that uses global state to explicitly pass the
  state in and out. The result is a pure function that is composable with JAX
  transformation. The pure function can be used as follows:
  `out, state = fun(state, rng, *args, **kwargs)`."""
  def purified(
      state, rng, *args, create=None, modify=None, ignore=None, **kwargs):
    context = CONTEXT.get(threading.get_ident(), None)
    if context:
      create = create if create is not None else context.create
      modify = modify if modify is not None else context.modify
      ignore = ignore if ignore is not None else context.ignore
      assert context.create or not create, 'Parent context disabled create.'
      assert context.modify or not modify, 'Parent context disabled modify.'
      assert not context.ignore or ignore, 'Parent context enabled ignore.'
    else:
      create = create if create is not None else True
      modify = modify if modify is not None else True
      ignore = ignore if ignore is not None else False
    if not isinstance(state, dict):
      raise ValueError('Must provide a dict as state.')
    if context and (not nested):
      raise RuntimeError(
          f'You are trying to call pure {fun.__name__}() inside pure '
          f'{context.name}(). Is that intentional? If you want to nest pure '
          f'functions, use pure(..., nested=True) for the inner function.')
      # raise RuntimeError(
      #     f'If you want to nest run() calls, use nested=True. ({context})')
    before = context
    try:
      name = fun.__name__
      if rng.shape == ():
        rng = jax.random.PRNGKey(rng)
      context = Context(state.copy(), rng, create, modify, ignore, [], name)
      CONTEXT[threading.get_ident()] = context
      out = fun(*args, **kwargs)
      state = dict(context)
      return out, state
    finally:
      CONTEXT[threading.get_ident()] = before
  purified.pure = True
  return purified
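To make the control flow of `pure()` easier to follow, here is a heavily simplified, standalone sketch of the same pattern. It is not Ninjax's actual API: `CONTEXT`, `get_state`, and `counter` here are hypothetical, and the rng handling, the create/modify/ignore flags, and the nesting checks are omitted. The wrapper installs a copy of the state as the thread's current context, runs the impure function, returns the updated state explicitly, and restores the previous context in `finally`. Note that the whole wrapped call, `fun(*args, **kwargs)`, runs inside the `try` block, so a hang located there may come from anything the wrapped training function does.

```python
import threading

# Hypothetical, simplified stand-in for Ninjax's per-thread context registry.
CONTEXT = {}


def get_state():
  """Return the state dict of the innermost active pure() call."""
  return CONTEXT[threading.get_ident()]


def pure(fun):
  def purified(state, *args, **kwargs):
    if not isinstance(state, dict):
      raise ValueError('Must provide a dict as state.')
    before = CONTEXT.get(threading.get_ident(), None)
    context = dict(state)  # work on a copy; the caller's dict is untouched
    try:
      CONTEXT[threading.get_ident()] = context
      out = fun(*args, **kwargs)  # impure code reads/writes get_state()
      return out, dict(context)   # hand the updated state back explicitly
    finally:
      # Always restore the previous context, even if fun() raises.
      CONTEXT[threading.get_ident()] = before
  return purified


def counter():
  """Impure function: mutates whatever state is currently installed."""
  state = get_state()
  state['count'] = state.get('count', 0) + 1
  return state['count']
```

For example, `pure(counter)({'count': 41})` returns `(42, {'count': 42})` while leaving the input dict unchanged.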

@schneimo changed the title from "'Tracing train function' gets stuck if using larger image sizes" to "Calling JAXAgent train gets stuck if using larger image sizes (inside Ninjax)" on Oct 25, 2023