Feature: Control number of vmapped envs in evaluator using arch.num_envs
#1071
base: develop
Conversation
LGTM, thanks @OmaymaMahjoub for making things more flexible 🙏
n_devices = len(jax.devices())
episodes_per_device = config.arch.num_eval_episodes * eval_multiplier // n_devices
Sadly jax doesn't allow us to carry these fixed values (parallel_eval_batch_size) 😢
I am quite sure there is a better way to implement this, but we can keep it as it is at the moment and create an issue calling for cleanup and readability of the evaluator
Two small things; one would take a while though, so if there's no time we can wait till we refactor the evaluator
mava/evaluator.py
Outdated
parallel_eval_batch_size = min(config.arch.num_envs, episodes_per_device)
# Compute the number of sequential evaluation batches required per device
# to cover all episodes.
sequential_eval_batches = episodes_per_device // parallel_eval_batch_size
If you make keys of shape (sequential_eval_batches, num_vmapped_episodes)
then you don't need the extra method, you can just call:
jax.lax.scan(jax.vmap(eval_one_episode), None, eval_states)
I like this for two reasons: you don't have to calculate parallel_eval_batch_size twice, and it's also clear that you're scanning and then vmapping over episodes. But this is a big change, so if you don't have time it's ok. I think the evaluator is due for a big overhaul anyway.
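The suggested shape of the keys makes the scan-over-batches / vmap-within-batch structure explicit. A minimal sketch of the idea, where `eval_one_episode` is a toy stand-in (not Mava's actual function) and the batch sizes are assumed example values:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for Mava's eval_one_episode: runs one evaluation
# episode from a single PRNG key and returns its episode return.
def eval_one_episode(key):
    rewards = jax.random.uniform(key, (10,))  # placeholder episode rewards
    return jnp.sum(rewards)

num_eval_episodes = 32        # assumed total episodes per device
parallel_eval_batch_size = 8  # assumed arch.num_envs cap on vmapped episodes
sequential_eval_batches = num_eval_episodes // parallel_eval_batch_size

# Shape keys as (sequential_eval_batches, parallel_eval_batch_size, key_dim)
# so scan iterates over batches while vmap parallelises within each batch.
keys = jax.random.split(jax.random.PRNGKey(0), num_eval_episodes)
keys = keys.reshape(sequential_eval_batches, parallel_eval_batch_size, -1)

def scan_body(carry, batch_keys):
    # vmap runs one batch of episodes in parallel; scan sequences the batches.
    return carry, jax.vmap(eval_one_episode)(batch_keys)

_, episode_returns = jax.lax.scan(scan_body, None, keys)
print(episode_returns.shape)  # (4, 8): batches x episodes-per-batch
```

In the real evaluator the carry would hold whatever state persists across batches (here `None`), and `eval_one_episode` would build its own `eval_init` state from the key, as discussed below.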
I back this idea also 🔥
I see, so even the eval_init will be created inside the one-episode function. I will give it a try
Co-authored-by: Sasha <reallysasha@gmail.com>
What?
Modify the evaluator to limit the number of vmapped envs to arch.num_envs when the total number of evaluation episodes (arch.num_eval_episodes) exceeds this limit, instead of parallelising all arch.num_eval_episodes at once. In such cases, evaluations are conducted in sequential batches, with each batch containing arch.num_envs parallel envs.
Why?
Limiting parallel evaluations to arch.num_envs prevents out-of-memory issues by avoiding vmapping over all episodes at once.
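The batching arithmetic described above can be sketched with assumed example numbers (the config values below are hypothetical, chosen so the division is exact):

```python
# Assumed example config values, not Mava defaults.
num_eval_episodes = 256  # arch.num_eval_episodes
num_envs = 32            # arch.num_envs: cap on simultaneously vmapped envs
n_devices = 2            # len(jax.devices())

# Episodes each device is responsible for.
episodes_per_device = num_eval_episodes // n_devices  # 128

# Never vmap more episodes at once than arch.num_envs allows.
parallel_eval_batch_size = min(num_envs, episodes_per_device)  # 32

# Sequential batches per device needed to cover all its episodes.
sequential_eval_batches = episodes_per_device // parallel_eval_batch_size  # 4

print(parallel_eval_batch_size, sequential_eval_batches)
```

Note that with these integer divisions, episode counts that don't divide evenly would silently drop the remainder, which is worth keeping in mind for the evaluator refactor mentioned in the review.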