NaN problems in mjx speed benchmarks on cpu #1616

Open
fabinsch opened this issue Apr 23, 2024 · 0 comments
Labels
bug Something isn't working

fabinsch commented Apr 23, 2024

Hello, and thanks for maintaining this library. I tried to run the testspeed.py script and encountered NaNs.

I am on Ubuntu 22.04. I created a clean environment and installed mujoco-mjx with

conda create --name mjx_test python=3.10
pip install mujoco-mjx

Then I launched the benchmark with

python testspeed.py --mjcf humanoid/humanoid.xml

and I see the following:

(mjx_test) ➜  mjx python testspeed.py --mjcf humanoid/humanoid.xml
Rolling out 1000 steps at dt = 0.005...
I0423 16:03:29.898261 140143597762368 xla_bridge.py:863] Unable to initialize backend 'cuda': 
I0423 16:03:29.898374 140143597762368 xla_bridge.py:863] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0423 16:03:29.898871 140143597762368 xla_bridge.py:863] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
W0423 16:03:29.898980 140143597762368 xla_bridge.py:901] An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
2024-04-23 16:03:53.934908: W external/xla/xla/service/cpu/onednn_matmul.cc:293] [Perf]: MatMul reference implementation being executed
2024-04-23 16:03:53.993505: W external/xla/xla/service/cpu/onednn_matmul.cc:293] [Perf]: MatMul reference implementation being executed
....
2024-04-23 16:04:45.988970: W external/xla/xla/service/cpu/onednn_matmul.cc:293] [Perf]: MatMul reference implementation being executed
result.qpos: [[[nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  ...
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]]]

Summary for 1024 parallel rollouts

 Total JIT time: 23.58 s
 Total simulation time: 52.14 s
 Total steps per second: 19640
 Total realtime factor: 98.20 x
 Total time per step: 50.92 µs

The result.qpos output above appears because I added a line print(f"result.qpos: {result.qpos}") here.
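Since the printed array is truncated, a small summary can be easier to read than the raw print. The sketch below is only an illustration: summarize_nans and the batch_axis parameter are hypothetical names, and which axis of result.qpos indexes the parallel rollouts is an assumption that may need adjusting (the printed array has three dimensions).

```python
import numpy as np

def summarize_nans(qpos, batch_axis=0):
    """Return (fraction of NaN entries, number of rollouts containing any NaN)."""
    nan_mask = np.isnan(np.asarray(qpos))
    # Move the assumed rollout axis to the front, then flatten everything else.
    moved = np.moveaxis(nan_mask, batch_axis, 0)
    per_rollout = moved.reshape(moved.shape[0], -1).any(axis=1)
    return float(nan_mask.mean()), int(per_rollout.sum())

# Toy stand-in for result.qpos: 4 rollouts, 3 coordinates each.
toy = np.array([[0.0, 1.0, 2.0],
                [np.nan, 1.0, 2.0],
                [0.0, 1.0, 2.0],
                [np.nan, np.nan, np.nan]])
frac, bad = summarize_nans(toy)
print(f"{frac:.1%} of entries are NaN, {bad} of {toy.shape[0]} rollouts affected")
```

In testspeed.py this would be called on np.asarray(result.qpos) in place of the plain print.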
The output of pip list is

(mjx_test) ➜  mjx pip list
Package             Version
------------------- --------
absl-py             2.1.0
etils               1.7.0
fsspec              2024.3.1
glfw                2.7.0
importlib_resources 6.4.0
jax                 0.4.26
jaxlib              0.4.26
ml-dtypes           0.4.0
mujoco              3.1.4
mujoco-mjx          3.1.4
numpy               1.26.4
opt-einsum          3.3.0
pip                 24.0
PyOpenGL            3.1.7
scipy               1.13.0
setuptools          69.5.1
trimesh             4.3.1
typing_extensions   4.11.0
wheel               0.43.0
zipp                3.18.1

I came across an issue mentioning the "MatMul reference implementation being executed" warning here. I also ran the benchmark on an M1 Mac, and there I see neither the MatMul reference warning nor the NaN issue.
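To compare the two machines side by side, it may help to collect the same facts on each. This is a minimal sketch using only the standard library plus an optional jax import; environment_report is a hypothetical helper name, and the package list is simply taken from the pip list output above.

```python
import platform
from importlib import metadata

def environment_report(packages=("jax", "jaxlib", "mujoco", "mujoco-mjx", "numpy")):
    """Collect platform details and installed package versions into a dict."""
    report = {"system": platform.system(), "machine": platform.machine()}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
    try:
        import jax
        # Reports which backend JAX actually selected, e.g. "cpu".
        report["jax_backend"] = jax.default_backend()
    except ImportError:
        report["jax_backend"] = "jax not importable"
    return report

for key, value in environment_report().items():
    print(f"{key}: {value}")
```

Running it on both the Ubuntu machine and the Mac gives two short listings that can be diffed directly.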

If you have any idea what is going on with MJX on Ubuntu (CPU), I would be glad to hear it.
Thanks for your time.

EDIT:
I have also checked other issues here. Trying to increase the precision via jax.config.update("jax_enable_x64", True) does not help; it fails with a TypeError instead:

(mjx_test) ➜  mjx python testspeed.py --mjcf humanoid/humanoid.xml
Rolling out 1000 steps at dt = 0.005...
I0423 17:15:45.344351 139876225578816 xla_bridge.py:863] Unable to initialize backend 'cuda': 
I0423 17:15:45.344468 139876225578816 xla_bridge.py:863] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0423 17:15:45.344967 139876225578816 xla_bridge.py:863] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
W0423 17:15:45.345096 139876225578816 xla_bridge.py:901] An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/testspeed.py", line 88, in <module>
    main()
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/testspeed.py", line 84, in main
    app.run(_main)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/testspeed.py", line 58, in _main
    jit_time, run_time, steps = mjx.benchmark(
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/_src/test_util.py", line 106, in benchmark
    jit_time, run_time = _measure(unroll, d)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/_src/test_util.py", line 41, in _measure
    compiled_fn = fn.lower(*args).compile()
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/_src/test_util.py", line 102, in unroll
    d, _ = jax.lax.scan(step, d, None, length=nstep, unroll=unroll_steps)
TypeError: Scanned function carry input and carry output must have equal types (e.g. shapes and dtypes of arrays), but they differ:
  * the input carry component d.contact.geom1 has type int64[1024,8] but the corresponding output carry component has type int32[1024,8], so the dtypes do not match

  * the input carry component d.contact.geom2 has type int64[1024,8] but the corresponding output carry component has type int32[1024,8], so the dtypes do not match

Revise the scanned function so that all output types (e.g. shapes and dtypes) match the corresponding input types.
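The TypeError above is raised while the scan is being traced (in fn.lower), so with x64 enabled the simulation never actually runs. The toy below reproduces the same class of error outside MJX. It is a sketch under the assumption that some value in the step is produced with a fixed int32 dtype while the initial carry becomes int64 under x64; bad_step and good_step are hypothetical functions, not MJX code.

```python
import jax
import jax.numpy as jnp

# Must be set before any arrays are created.
jax.config.update("jax_enable_x64", True)

def bad_step(carry, _):
    # A hard-coded int32 narrows the carry, so input and output dtypes differ.
    return (carry + 1).astype(jnp.int32), None

def good_step(carry, _):
    # Reusing the incoming dtype keeps the carry type stable across iterations.
    return (carry + 1).astype(carry.dtype), None

init = jnp.arange(3)  # int64 once x64 is enabled

try:
    jax.lax.scan(bad_step, init, None, length=4)
    mismatch_raised = False
except TypeError:
    mismatch_raised = True

final, _ = jax.lax.scan(good_step, init, None, length=4)
print(mismatch_raised, final.dtype)
```

If that reading is right, enabling x64 says nothing about the NaNs themselves; it only trips a dtype check on d.contact.geom1 and d.contact.geom2 before any step is computed.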

Setting config.update("jax_debug_nans", True) did not give me any useful information either:

Traceback (most recent call last):
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/testspeed.py", line 85, in <module>
    main()
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/testspeed.py", line 81, in main
    app.run(_main)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/testspeed.py", line 55, in _main
    jit_time, run_time, steps = mjx.benchmark(
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/_src/test_util.py", line 104, in benchmark
    jit_time, run_time = _measure(unroll, d)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/_src/test_util.py", line 44, in _measure
    result = compiled_fn(*args)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/stages.py", line 594, in __call__
    return self._call(*args, **kwargs)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/stages.py", line 591, in cpp_call_fallback
    outs, _, _ = Compiled.call(params, *args, **kwargs)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/stages.py", line 563, in call
    out_flat = params.executable.call(*args_flat)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1089, in call
    return self.unsafe_call(*args)  # pylint: disable=not-callable
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1217, in __call__
    dispatch.check_special(self.name, arrays)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/dispatch.py", line 314, in check_special
    _check_special(name, buf.dtype, buf)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/dispatch.py", line 319, in _check_special
    raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in parallel computation
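Because the whole rollout is a single compiled call, the error above only reports that a NaN appeared somewhere in it. One way to narrow this down is to step eagerly in a Python loop and check the state after every step. The sketch below uses a toy step function; first_nonfinite_step and toy_step are hypothetical names, and with MJX the step_fn would presumably be something like lambda d: mjx.step(m, d) on a single, un-batched data instance, which is an assumption and is not tested here.

```python
import jax
import jax.numpy as jnp

def first_nonfinite_step(step_fn, state, nstep):
    """Run step_fn eagerly; return the index of the first step whose output
    contains a non-finite float leaf, or None if all nstep outputs are finite."""
    for i in range(nstep):
        state = step_fn(state)
        for leaf in jax.tree_util.tree_leaves(state):
            leaf = jnp.asarray(leaf)
            is_float = jnp.issubdtype(leaf.dtype, jnp.floating)
            if is_float and not bool(jnp.isfinite(leaf).all()):
                return i
    return None

def toy_step(state):
    # Hypothetical dynamics: the sqrt argument becomes negative on the third call.
    x = state["x"] - 1.0
    return {"x": x, "y": jnp.sqrt(x)}

start = {"x": jnp.array([2.5]), "y": jnp.array([0.0])}
bad_index = first_nonfinite_step(toy_step, start, 10)
print(bad_index)
```

Knowing the first offending step makes it possible to inspect the state just before it, which the compiled scan does not expose.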
@fabinsch fabinsch added the bug Something isn't working label Apr 23, 2024