Checkpointing and less regular metric collection #13

Open
jheagerty opened this issue Nov 14, 2023 · 11 comments

@jheagerty

jheagerty commented Nov 14, 2023

I know this sounds ridiculous, but I've spent ages trying to add checkpoint saving to your example/walkthrough training code and have been getting nowhere.

Similarly (since it's also something done every n steps or every epoch), I've been trying to reduce the frequency of metric collection, as it has been giving me VRAM errors on my lowly NVIDIA 3080.

Any advice / solutions would be very gratefully received.

@jheagerty
Author

And sorry, for context: the checkpointing is so that I can implement self-play, where the baseline model that controls the enemy is periodically updated to match the model being trained.
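A minimal sketch of that periodic opponent refresh done entirely under jit, by carrying both parameter sets through `jax.lax.scan` and swapping them with `jax.lax.cond`. This is not PureJaxRL code: `SYNC_EVERY`, `opponent_params`, and the toy `+ 1.0` update are hypothetical stand-ins for a real PPO step.

```python
import jax
import jax.numpy as jnp

SYNC_EVERY = 3  # hypothetical: refresh the opponent every 3 updates
NUM_UPDATES = 7


def _update_step(carry, _):
    step, params, opponent_params = carry
    params = {"w": params["w"] + 1.0}  # stand-in for a real PPO update
    step = step + 1
    opponent_params = jax.lax.cond(
        step % SYNC_EVERY == 0,
        lambda new, old: new,  # sync: opponent becomes a copy of current params
        lambda new, old: old,  # otherwise keep the frozen opponent
        params,
        opponent_params,
    )
    # Emit the opponent's weight so the refresh schedule is visible.
    return (step, params, opponent_params), opponent_params["w"][0]


@jax.jit
def train(params):
    carry = (jnp.int32(0), params, params)  # opponent starts as a copy
    return jax.lax.scan(_update_step, carry, None, NUM_UPDATES)


(_, params, opponent), history = train({"w": jnp.zeros(1)})
```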

@jheagerty
Author

jheagerty commented Nov 17, 2023

Never mind, I figured out the main thing for me: checkpointing. You have to:

  • remove the jax.jit around make_train when you call it
  • decorate def _update_step with @jax.jit
  • (these two changes combined seem to make no measurable difference to performance, but my testing was not extensive)
  • at the end of train(rng), save whatever you want
  • at the start, load only network_params from your save
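A minimal sketch of the steps above, assuming a hypothetical `update_step` in place of PureJaxRL's `_update_step`, a toy parameter dict, and `pickle` as the file format (the thread does not say which format was used):

```python
import os
import pickle
import tempfile

import jax
import jax.numpy as jnp


@jax.jit  # jit only the per-update step, not the whole make_train
def update_step(params, rng):
    # Hypothetical stand-in for _update_step: nudge params with noise.
    rng, key = jax.random.split(rng)
    noise = jax.random.normal(key, params["w"].shape)
    return {"w": params["w"] + 0.01 * noise}, rng


def train(rng, params, num_updates, ckpt_path):
    # Plain Python loop, so host-side I/O is possible between updates.
    for _ in range(num_updates):
        params, rng = update_step(params, rng)
    # Save whatever you want at the end of training.
    with open(ckpt_path, "wb") as f:
        pickle.dump(jax.device_get(params), f)
    return params


def load_params(ckpt_path):
    # Load only the network params at the start of the next run.
    with open(ckpt_path, "rb") as f:
        return pickle.load(f)


ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
final = train(jax.random.PRNGKey(0), {"w": jnp.zeros(3)}, 5, ckpt_path)
restored = load_params(ckpt_path)
```

The trade-off is that the outer loop is no longer compiled as a single program, so each update is a separate dispatch from Python.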

My one concern is that I can't tell whether the learning rate scheduler's state is preserved, but I will worry about that later.

I'll also deal with less regular metric collection later, but I've seen some tools that are likely jittable.
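For the VRAM side, one option is to reduce metrics inside each update before `jax.lax.scan` stacks them, and to send them to the host only every few updates. This sketch assumes a hypothetical `LOG_EVERY` setting and a toy per-environment metric; the host function `log_metric` is where a print or logger call would go.

```python
import jax
import jax.numpy as jnp

NUM_UPDATES = 10
LOG_EVERY = 5  # hypothetical config value: log every 5 updates
logged = []  # host-side record of what the callback received


def log_metric(step, value):
    # Runs on the host; swap this for print or a logger call.
    logged.append((int(step), float(value)))


def update_step(carry, _):
    step, x = carry
    per_env_return = x + jnp.arange(4.0)  # toy (num_envs,) metric
    mean_return = per_env_return.mean()  # reduce before scan stacks it
    jax.lax.cond(
        step % LOG_EVERY == 0,
        lambda s, m: jax.debug.callback(log_metric, s, m),  # transfer to host
        lambda s, m: None,  # skip logging on other updates
        step,
        mean_return,
    )
    return (step + 1, x + 1.0), mean_return


@jax.jit
def train():
    init = (jnp.int32(0), jnp.float32(0.0))
    _, metrics = jax.lax.scan(update_step, init, None, NUM_UPDATES)
    return metrics  # shape (NUM_UPDATES,), not (NUM_UPDATES, num_envs, ...)


metrics = train()
jax.effects_barrier()  # wait for pending callbacks before reading `logged`
```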

@Howuhh

Howuhh commented Nov 23, 2023

@jheagerty actually, I think you can easily save checkpoints under jit with callbacks such as jax.experimental.io_callback() (for example, inside _update_step to save after each update).
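A minimal sketch of that idea, assuming `pickle` files in a temporary directory (`CKPT_DIR` and `save_checkpoint` are hypothetical names). `io_callback` is given `None` as its result shape because the callback returns nothing, and the directory string is bound with `functools.partial` so that only arrays cross the jit boundary:

```python
import os
import pickle
import tempfile
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental import io_callback

CKPT_DIR = tempfile.mkdtemp()  # hypothetical checkpoint directory


def save_checkpoint(ckpt_dir, step, params):
    # Host side: receives concrete arrays, so normal file I/O works here.
    path = os.path.join(ckpt_dir, f"ckpt_{int(step)}.pkl")
    with open(path, "wb") as f:
        pickle.dump(params, f)


def _update_step(carry, _):
    step, params = carry
    params = jax.tree_util.tree_map(lambda p: p + 1.0, params)  # stand-in update
    # Save after each update; the string is bound, only arrays are traced.
    io_callback(partial(save_checkpoint, CKPT_DIR), None, step, params)
    return (step + 1, params), None


@jax.jit
def train(params):
    (_, params), _ = jax.lax.scan(_update_step, (jnp.int32(0), params), None, 3)
    return params


final = train({"w": jnp.zeros(2)})
jax.effects_barrier()  # make sure all writes have finished
```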

@luchris429
Owner

Yep! I do it the way @Howuhh is describing. If you look at the code, there is a debug callback; you can just replace the print function with your checkpointing and wandb logging.
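Roughly what that replacement could look like, assuming a toy update and hypothetical names (`SAVE_EVERY`, `CKPT_DIR`, and a `log` list standing in for `wandb.log`). Note that the `jax.debug.callback` documentation says its effects may be dropped, duplicated, or reordered under some transformations, so `io_callback` may be the safer choice for writes that must happen.

```python
import os
import pickle
import tempfile

import jax
import jax.numpy as jnp

SAVE_EVERY = 2  # hypothetical: checkpoint every 2 updates
CKPT_DIR = tempfile.mkdtemp()
log = []  # stand-in for a logger such as wandb.log


def callback(step, metric, params):
    # Host-side replacement for the print in the debug callback.
    step = int(step)
    log.append({"update": step, "mean": float(metric)})
    if step % SAVE_EVERY == 0:  # frequency check done on the host
        with open(os.path.join(CKPT_DIR, f"params_{step}.pkl"), "wb") as f:
            pickle.dump(params, f)


def _update_step(carry, _):
    step, params = carry
    params = {"w": params["w"] * 2.0}  # stand-in update
    metric = params["w"].mean()
    jax.debug.callback(callback, step, metric, params)
    return (step + 1, params), metric


@jax.jit
def train(params):
    return jax.lax.scan(_update_step, (jnp.int32(0), params), None, 4)


(_, final_params), metrics = train({"w": jnp.ones(3)})
jax.effects_barrier()  # wait for all host callbacks
```

Here every update still transfers the params to the host; only the write is skipped. Gating the callback with `jax.lax.cond` avoids the transfer as well.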

@jheagerty
Author

Thanks so much! Will look into this

@luchris429
Owner

The way you do it should be fine too, and is arguably better (though it takes more code). The optimizer state (which includes the learning rate schedule info) should be in the train_state.

@Chulabhaya

@luchris429 Hi Chris! I was wondering how you implemented restoring checkpoints with the PureJaxRL end-to-end jitting. I'm able to save checkpoints pretty easily with a debug callback function, but I can't quite figure out how to restore them. I tried putting an experimental.io_callback call in the train function, but I can't actually do anything with the string checkpoint path because JAX can't handle strings.

@luchris429
Copy link
Owner

You can try to load the runner state here!

Does it not work if you set the filename in the config?
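If the filename lives in the config that `make_train` closes over, the load can be plain Python inside `train`: that code runs once while JAX traces the function, when the path is still an ordinary string. A minimal sketch, assuming a hypothetical `RESTORE_PATH` config key and a pickled params dict:

```python
import os
import pickle
import tempfile

import jax
import jax.numpy as jnp
import numpy as np

# Write a hypothetical checkpoint so the example is self-contained.
ckpt_path = os.path.join(tempfile.mkdtemp(), "params.pkl")
with open(ckpt_path, "wb") as f:
    pickle.dump({"w": np.full(2, 5.0, dtype=np.float32)}, f)


def make_train(config):
    def train(rng):
        # Runs at trace time, so the path is a normal Python string here.
        if config.get("RESTORE_PATH"):
            with open(config["RESTORE_PATH"], "rb") as f:
                params = jax.tree_util.tree_map(jnp.asarray, pickle.load(f))
        else:
            params = {"w": jax.random.normal(rng, (2,))}

        def _update_step(params, _):
            params = {"w": params["w"] + 1.0}  # stand-in update
            return params, params["w"].mean()

        return jax.lax.scan(_update_step, params, None, config["NUM_UPDATES"])

    return train


config = {"RESTORE_PATH": ckpt_path, "NUM_UPDATES": 3}
params, metrics = jax.jit(make_train(config))(jax.random.PRNGKey(0))
```

The loaded arrays are embedded in the compiled program as constants, so for large states it may be preferable to load in Python and pass the state into the jitted function as an argument. Changing the file's contents also requires re-jitting.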

@Chulabhaya

Chulabhaya commented Mar 5, 2024

So I tried something like the code below at exactly the line you pointed out (in a modified PPO script where I split the actor/critic):

def resuming_callback(path):
    checkpointer = ocp.PyTreeCheckpointer()
    raw_restored = checkpointer.restore(path)
    return raw_restored

runner_state = (actor_state, vf_state, time_state, env_state, obsv, train_key)
if args.resume:
    raw_restored = io_callback(
        resuming_callback, runner_state, args.resume_checkpoint_path
    )

runner_state, metric = jax.lax.scan(
    _update_step, runner_state, None, args.num_iterations
)

However JAX errors out with the complaint that my args.resume_checkpoint_path is a string which is not compatible. Hence my current conundrum. Perhaps I'm setting this up wrong or using the wrong JAX callback?

@luchris429
Owner

Sorry I didn't catch this message! I hope you've figured it out.

I think you need to make sure it's a static argument, since you can't pass a string to a jitted function as a regular (traced) argument.
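A minimal sketch of the static-argument route, assuming a pickled params dict at a hypothetical path. Listing `ckpt_path` in `static_argnames` keeps it a hashable Python string at trace time; each distinct path triggers a fresh trace and compile:

```python
import os
import pickle
import tempfile
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

ckpt_path = os.path.join(tempfile.mkdtemp(), "params.pkl")  # hypothetical path
with open(ckpt_path, "wb") as f:
    pickle.dump({"w": np.arange(3, dtype=np.float32)}, f)


@partial(jax.jit, static_argnames="ckpt_path")
def resume_and_train(scale, ckpt_path):
    # ckpt_path is static, so this file read happens during tracing.
    with open(ckpt_path, "rb") as f:
        params = {k: jnp.asarray(v) for k, v in pickle.load(f).items()}
    return {"w": params["w"] * scale}  # stand-in for the training loop


out = resume_and_train(jnp.float32(2.0), ckpt_path=ckpt_path)
```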

@gzadigo

gzadigo commented Apr 21, 2024

Actually, I have tried with a fixed filename too, and it still failed. Restoring inside jit is quite difficult for me.
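One likely culprit in the snippet above is that the path is passed to `io_callback` as an argument (callback arguments must be arrays) and the restored value is never assigned back to `runner_state`. A hedged sketch of a fix, swapping orbax for `pickle` to keep it self-contained and assuming a hypothetical two-field runner state: bind the path with `functools.partial`, declare the output shapes, and use the returned state.

```python
import os
import pickle
import tempfile
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

# Fixed, hypothetical filename holding a previously saved runner state.
ckpt_path = os.path.join(tempfile.mkdtemp(), "runner_state.pkl")
with open(ckpt_path, "wb") as f:
    pickle.dump({"w": np.full(2, 7.0, dtype=np.float32),
                 "step": np.array(4, dtype=np.int32)}, f)


def load_checkpoint(path):
    # Host side: the path is a normal Python string here.
    with open(path, "rb") as f:
        return pickle.load(f)


@jax.jit
def train(runner_state):
    # io_callback needs its output shapes/dtypes up front; reuse the template.
    shapes = jax.tree_util.tree_map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), runner_state)
    # The string is bound via partial, and the result replaces runner_state.
    runner_state = io_callback(partial(load_checkpoint, ckpt_path), shapes)

    def _update_step(state, _):
        return {"w": state["w"] + 1.0, "step": state["step"] + 1}, None

    runner_state, _ = jax.lax.scan(_update_step, runner_state, None, 3)
    return runner_state


template = {"w": jnp.zeros(2), "step": jnp.int32(0)}
result = train(template)
```

The saved pytree must match the template's structure, shapes, and dtypes exactly, or `io_callback` will raise an error.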
