mpi4jax API version mismatch #32

Open
coreyjadams opened this issue Apr 29, 2024 · 3 comments
Labels: bug (Something isn't working)

Comments

@coreyjadams

With release 0.3.0, I am unable to get mpi4jax to run. I am using this branch from an Intel fork of mpi4jax: https://github.com/jczaja/mpi4jax/tree/jczaja/xpu-support. This is running on Argonne's Sunspot cluster with Intel Max 1550 GPUs.

I have installed intel-extension-for-openxla 0.3.0 via pip. I have oneAPI 2024.1 and agama 803.29. Here is what I see when I import jax, then import mpi4jax:
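When reporting a mismatch like this, it helps to list the exact package versions in the environment. The following is a minimal sketch that uses only the standard library's `importlib.metadata`. The distribution names in `PACKAGES` are assumptions: PyPI names can differ from import names, so adjust the list to match your install.

```python
from importlib import metadata

# Assumed PyPI distribution names; adjust to match your environment.
PACKAGES = ["jax", "jaxlib", "intel-extension-for-openxla", "mpi4jax", "mpi4py"]


def installed_versions(names=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:32s} {version or 'not installed'}")
```

The report does not include compiler or driver versions (oneAPI, agama). Those still need to be stated by hand, as above.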

```
>>> import jax
>>> jax.local_devices()
INFO: Intel Extension for OpenXLA version: 0.3.0, commit: 9a484818
Platform 'xpu' is experimental and not all JAX functionality may be correctly supported!
[xpu(id=0), xpu(id=1), xpu(id=2), xpu(id=3), xpu(id=4), xpu(id=5), xpu(id=6), xpu(id=7), xpu(id=8), xpu(id=9), xpu(id=10), xpu(id=11)]
>>> import mpi4jax
Registering b'mpi_allgather' and function <capsule object "xla._CUSTOM_CALL_TARGET" at 0x1458e2532ca0>
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/soft/datascience/jax/0.4.24/miniconda3/lib/python3.10/site-packages/mpi4jax-0+untagged.386.gcb25ca5.dirty-py3.10-linux-x86_64.egg/mpi4jax/__init__.py", line 9, in <module>
    from ._src import (  # noqa: E402
  File "/soft/datascience/jax/0.4.24/miniconda3/lib/python3.10/site-packages/mpi4jax-0+untagged.386.gcb25ca5.dirty-py3.10-linux-x86_64.egg/mpi4jax/_src/__init__.py", line 11, in <module>
    from . import xla_bridge  # noqa: E402
  File "/soft/datascience/jax/0.4.24/miniconda3/lib/python3.10/site-packages/mpi4jax-0+untagged.386.gcb25ca5.dirty-py3.10-linux-x86_64.egg/mpi4jax/_src/xla_bridge/__init__.py", line 42, in <module>
    xla_client.register_custom_call_target(name, fn, platform="SYCL")
  File "/soft/datascience/jax/0.4.24/miniconda3/lib/python3.10/site-packages/jaxlib/xla_client.py", line 588, in register_custom_call_target
    _custom_callback_handler[xla_platform_name](name, fn, xla_platform_name)
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: API version 1986225522 not supported for PJRT GPU plugin. Supported versions are 0 and 1.
>>> 
```
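The API version in this error, 1986225522, is not a plausible version number. Read as four little-endian bytes, it is the ASCII text `recv`. This hints that the version field is being filled with unrelated data, for example because the calling code and the installed `jaxlib` disagree about the call interface, rather than because a genuinely newer API version was requested. This is an interpretation, not a confirmed diagnosis. The arithmetic itself is easy to check:

```python
def int_as_bytes(value: int, width: int = 4) -> bytes:
    """Reinterpret a non-negative integer as `width` little-endian bytes."""
    return value.to_bytes(width, byteorder="little")


bogus_version = 1986225522  # value reported in the XlaRuntimeError above
print(hex(bogus_version))           # 0x76636572
print(int_as_bytes(bogus_version))  # b'recv'
```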

Do I need to target a specific api version in mpi4jax to make this work? Or, do I need to build JAX from source?

Thanks!
Corey

@Zantares
Contributor

This is an API version mismatch caused by upgrading JAX. @Dboyqiao has fixed it and will post the solution.

Zantares added the bug label on May 14, 2024.
@Dboyqiao
Contributor

@coreyjadams We have fixed this issue on the main branch (compatible with JAX 0.4.25). Until the next release, you need to build it from source; please refer to install-from-source-build for details.
In addition, XPU support has been merged into upstream mpi4jax, so you can now use the public repo directly: https://github.com/mpi4jax/mpi4jax.git

@Dboyqiao
Contributor

@coreyjadams As far as I know, scaling out with mpi4jax is still blocked by a JAX bug, which will be fixed in v0.4.28. Since intel-extension-for-openxla will not be rebased onto JAX v0.4.28 soon, could you provide more detail about the JAX bug?
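The thread mentions two JAX versions. The main-branch fix is stated to be compatible with JAX 0.4.25, and the scale-out bug is said to be fixed in v0.4.28. A small runtime check can flag a JAX install that falls on the wrong side of either. This is a sketch: the thresholds are taken from the comments above, and `parse_release` handles only the leading numeric part of a version string such as `0.4.25` or `0.4.28rc1`.

```python
from importlib import metadata

FIX_TARGET_JAX = (0, 4, 25)  # main-branch fix stated to be compatible with this JAX
SCALE_OUT_JAX = (0, 4, 28)   # JAX release said to fix the scale-out bug


def parse_release(version: str) -> tuple:
    """Parse the leading numeric release, e.g. '0.4.25.dev1' -> (0, 4, 25)."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # non-numeric component such as 'dev1'
        parts.append(int(digits))
        if len(digits) != len(piece):
            break  # stop at suffixes like '28rc1'
    return tuple(parts)


def check_jax(version: str) -> list:
    """Return human-readable warnings for the given JAX version string."""
    release = parse_release(version)
    warnings = []
    if release != FIX_TARGET_JAX:
        warnings.append(f"JAX {version}: the main-branch fix targets JAX 0.4.25")
    if release < SCALE_OUT_JAX:
        warnings.append(f"JAX {version}: multi-node scale-out reportedly needs >= 0.4.28")
    return warnings


if __name__ == "__main__":
    try:
        installed = metadata.version("jax")
    except metadata.PackageNotFoundError:
        print("jax is not installed")
    else:
        for warning in check_jax(installed):
            print("warning:", warning)
```

No single JAX version satisfies both checks. That is the same tension the comment above describes.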
