Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Control number of vmapped envs in evaluator using arch.num_envs #1071

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from

Conversation

OmaymaMahjoub
Copy link
Contributor

What?

Modify the evaluator to limit the number of vmapped envs to arch.num_envs when the total number of evaluation episodes arch.num_eval_episodes exceeds this limit (instead of parallelise all the arch.num_eval_episodes). In such cases, evaluations are conducted in sequential batches, with each batch containing arch.num_envs parallel envs.

Why?

Limiting parallel evaluations to num_envs prevents out-of-memory issues by avoiding vmap over all episodes at once.

WiemKhlifi
WiemKhlifi previously approved these changes Mar 26, 2024
Copy link
Contributor

@WiemKhlifi WiemKhlifi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @OmaymaMahjoub for making things more flexible 🙏

n_devices = len(jax.devices())
episodes_per_device = config.arch.num_eval_episodes * eval_multiplier // n_devices
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sadly jax doesn't allow us to carry these fixed values (parallel_eval_batch_size) 😢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am quite sure there is a better way to implement this, but we can keep it as it is at the moment and create an issue calling for cleaning and readibility of the evaluator

Copy link
Contributor

@sash-a sash-a left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two small things, one would take a while though so if there's no time we can wait till we refactor the evaluator

mava/evaluator.py Outdated Show resolved Hide resolved
parallel_eval_batch_size = min(config.arch.num_envs, episodes_per_device)
# Compute the number of sequential evaluation batches required per device
# to cover all episodes.
sequential_eval_batches = episodes_per_device // parallel_eval_batch_size
Copy link
Contributor

@sash-a sash-a Mar 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you make keys of shape (sequential_eval_batches, num_vmapped_episodes) then you don't need the extra method, you can just call:

jax.lax.scan(jax.vmap(eval_one_episode), None, eval_states)

I like this for two reasons, you don't have to calculate parallel_eval_batch_size twice and it is also clear that you're scanning and then vmapping over episodes. But this is a big change so if you don't have time it's ok. I think the evaluator is due for a big overhaul anyways

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I back this idea also 🔥

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see so even the eval_init will be created inside the one episode, I will give it a try

Co-authored-by: Sasha <reallysasha@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants