Question: Is there a way to vmap models and data over a simulation loop? #1411
Replies: 1 comment 4 replies
-
Hi @oliverweissl! I encourage you to look at the Colab tutorial. The "Introduction to MJX" section has an example of batching environments (look for …). For the …, see mujoco/mjx/mujoco/mjx/testspeed.py, lines 94 to 104 in e77c3cb. Another thing to note is that …
-
Hi,
I'm a student and I'm trying to use MuJoCo for simulating modular robots.
I'm looking for some help with parallelizing these simulations in MJX.
My aim is to use GPUs to parallelize simulations composed of n mjx.Model and n mjx.Data objects. In the vanilla implementation, we use multiprocessing to spawn processes on the CPU; each process takes one model and one data object and uses them in a while loop (which includes the step function) for the main simulation.
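Moving a per-process `while` loop like this into JAX needs one change before it can be jitted or vmapped: a Python `while` whose condition depends on array values cannot be traced, so the loop has to be expressed with `jax.lax.while_loop`. A minimal sketch, assuming a time-based stopping condition, with a hypothetical `step` and a toy state dict standing in for mjx.step and mjx.Data:

```python
import jax
import jax.numpy as jnp

DT = 0.01  # hypothetical fixed timestep


def step(data):
    """Hypothetical stand-in for a physics step: advance the clock and damp the velocity."""
    return {"time": data["time"] + DT, "vel": 0.99 * data["vel"]}


def simulate(data, duration):
    # A Python `while data["time"] < duration:` fails under jit/vmap because the
    # condition depends on traced array values; lax.while_loop is traceable.
    return jax.lax.while_loop(lambda d: d["time"] < duration, step, data)


init = {"time": jnp.float32(0.0), "vel": jnp.ones(3)}

# A single simulation, compiled.
single = jax.jit(simulate)(init, 0.5)

# Several simulations with different durations, batched with vmap.
durations = jnp.array([0.1, 0.25, 0.5])
batched = jax.jit(jax.vmap(simulate, in_axes=(None, 0)))(init, durations)
print(batched["time"].shape)  # (3,)
```

When the loop condition is batched, JAX keeps iterating until every simulation's condition is false and holds the already-finished ones fixed, so simulations of different lengths can share one vmapped loop.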
The Vanilla implementation:
My idea for using MJX was to pmap over the main simulation loop, like this:
However, this fails because the lists I want to map over are pytrees. I have read that batching could be one way to make this work, but I could not find any documentation on it.
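On the lists of pytrees: `jax.vmap` (and `jax.pmap`) do not iterate over a Python list element by element; they map over a leading array axis of every leaf. One common workaround is to stack the n models and n data objects into one model pytree and one data pytree whose leaves carry a leading batch axis. A minimal sketch, with hypothetical toy dicts standing in for mjx.Model and mjx.Data and a hypothetical `step` in place of mjx.step:

```python
import jax
import jax.numpy as jnp


def stack_trees(trees):
    """Stack a Python list of same-structure pytrees into one pytree with a leading batch axis."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


def step(model, data):
    """Hypothetical stand-in for mjx.step on toy dict pytrees."""
    vel = data["vel"] + model["dt"] * model["gravity"]
    return {"pos": data["pos"] + model["dt"] * vel, "vel": vel}


def rollout(model, data):
    """50 steps of one simulation, looped inside the compiled graph."""
    return jax.lax.fori_loop(0, 50, lambda _, d: step(model, d), data)


# n = 3 toy models (different gravity) and n = 3 toy data objects.
models = [{"dt": jnp.float32(0.01), "gravity": jnp.array([0.0, 0.0, -g])} for g in (9.81, 3.71, 1.62)]
datas = [{"pos": jnp.zeros(3), "vel": jnp.full(3, float(i))} for i in range(3)]

batched_models = stack_trees(models)  # dt: (3,), gravity: (3, 3)
batched_datas = stack_trees(datas)    # pos: (3, 3), vel: (3, 3)

# Default in_axes=0 maps over the leading axis of every leaf of both arguments.
final = jax.jit(jax.vmap(rollout))(batched_models, batched_datas)

# Split the batched result back into a list of per-simulation pytrees.
results = [jax.tree_util.tree_map(lambda x, i=i: x[i], final) for i in range(3)]
print(final["pos"].shape)  # (3, 3)
```

`jnp.stack` requires every leaf to have the same shape across the list, so this only works when all models share the same array shapes (and, assuming the MJX types behave like these dicts, when all their leaves are arrays). Robots with different morphologies, and therefore different numbers of bodies or joints, would not stack directly. `jax.pmap` follows the same leading-axis convention but additionally requires that axis not to exceed the number of available devices, which is why `vmap` is the usual tool for many simulations on one GPU.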
Do you have any tips on how to approach this? Also, this is my first time using JAX, so I'm still learning and scanning through the docs for something that fits my issue.
Thanks in advance for any suggestions and help!