
ENH: array types: add JAX support #20085

Merged - 73 commits merged into scipy:main on May 18, 2024
Conversation

@lucascolley (Member) commented Feb 13, 2024

Reference issue

Towards gh-18867

What does this implement/fix?

First steps on JAX support. To-do:

Additional information

Can do the same for dask.array once the problems are fixed over at data-apis/array-api-compat#89.

@github-actions bot added labels: scipy.cluster, scipy._lib, Meson (items related to the introduction of Meson as the new build system for SciPy), array types (items related to array API support and input array validation, see gh-18286), enhancement (a new feature or improvement) - Feb 13, 2024
@lucascolley removed the Meson label Feb 13, 2024
@lucascolley (Member, Author) commented Feb 14, 2024

(The reason I decided to comment over on the DLPack issue is that I recall a conversation about how portability could be increased if we replace occurrences of np.asarray with {array_api_compat.numpy, np>=2.0}.from_dlpack. Clearly, portability beyond libraries which are coercible by np.asarray is very low priority at the minute, but something to consider long-term. Also, DLPack is the idiomatic way to do library interchange, rather than relying on the array-creation function asarray.)
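
For reference, the DLPack route looks like this (a small sketch, not part of this PR; np.from_dlpack needs NumPy >= 1.23):

import numpy as np
import jax.numpy as jnp

x = jnp.arange(4.0)    # a CPU JAX array
y = np.from_dlpack(x)  # interchange via the DLPack protocol, no np.asarray
                       # (treat the result as read-only: JAX buffers are immutable)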

@rgommers (Member)

Thanks for working on this, Lucas. JAX support will be very nice. And a third library with CPU support (after NumPy and PyTorch) will also be good for testing how generic our array API standard support actually is.

Okay, related to the read-only question, it looks like this is the problem you were seeing:

scipy/cluster/hierarchy.py:1038: in linkage
    result = _hierarchy.mst_single_linkage(y, n)
        method     = 'single'
        method_code = 0
        metric     = 'euclidean'
        n          = 6
        optimal_ordering = False
        xp         = <module 'jax.experimental.array_api' from '/home/rgommers/mambaforge/envs/scipy-dev-jax/lib/python3.11/site-packages/jax/experimental/array_api/__init__.py'>
        y          = array([1.48660687, 2.23606798, 1.41421356, 1.41421356, 1.41421356,
       2.28254244, 0.1       , 1.48660687, 1.48660687, 2.23606798,
       1.        , 1.        , 1.41421356, 1.41421356, 0.        ])
_hierarchy.pyx:1015: in scipy.cluster._hierarchy.mst_single_linkage
    ???
<stringsource>:663: in View.MemoryView.memoryview_cwrapper
    ???
<stringsource>:353: in View.MemoryView.memoryview.__cinit__
    ???
E   ValueError: buffer source array is read-only

The problem is that Cython doesn't accept read-only arrays when the signature is a regular memoryview. There's a long discussion about this topic in scikit-learn/scikit-learn#10624. Now that we have Cython 3 though, the fix is simple:

diff --git a/scipy/cluster/_hierarchy.pyx b/scipy/cluster/_hierarchy.pyx
index 814051df2..c59b3de6a 100644
--- a/scipy/cluster/_hierarchy.pyx
+++ b/scipy/cluster/_hierarchy.pyx
@@ -1012,7 +1012,7 @@ def nn_chain(double[:] dists, int n, int method):
     return Z_arr
 
 
-def mst_single_linkage(double[:] dists, int n):
+def mst_single_linkage(const double[:] dists, int n):
     """Perform hierarchy clustering using MST algorithm for single linkage.
 
     Parameters

This makes the tests pass (at least for this issue, I tried with the dendrogram tests only). The dists input to mst_single_linkage isn't modified in-place, so once we tell Cython that by adding const, things are happy.
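
The same constraint can be reproduced from pure Python without compiling anything (a quick illustration, not from this PR):

import numpy as np

x = np.arange(3.0)
x.setflags(write=False)        # simulate the read-only buffer a JAX array exposes

print(memoryview(x).readonly)  # True: a plain double[:] signature rejects this,
                               # while const double[:] accepts it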

@lucascolley added the Cython (issues with the internal Cython code base) label Feb 14, 2024
@lucascolley (Member, Author)
thanks! I've removed the copies and added some consts to the Cython file to get the tests to pass. Still some failures for in-place assignments with indexing but we can circle back to those once we get integration with the test skip infra.

@jakevdp (Member) left a comment
One question I have here, which is probably a question more broadly for the array API: as written, much of the JAX support added here will not work under jax.jit, because it requires converting array objects to host-side buffers, and this is not possible during tracing when the array objects are abstract. JAX has mechanisms for this (namely custom calls and/or pure_callback) but the array API doesn't seem to have much consideration for this kind of library structure. Unfortunately, I think this will severely limit the usefulness of these kinds of implementations. I wonder if the array API could consider this kind of limitation?
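
For concreteness, this is the failure mode being described (a small illustration, not code from this PR):

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def f(x):
    # x is an abstract tracer here; pulling it into a host-side
    # NumPy buffer is exactly what cannot happen during tracing
    return np.asarray(x)

f(jnp.arange(3.0))  # raises jax.errors.TracerArrayConversionError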

@rgommers (Member)

One question I have here, which is probably a question more broadly for the array API: as written, much of the JAX support added here will not work under jax.jit, because it requires converting array objects to host-side buffers,

Do you mean for testing purposes, or for library code? For the latter: we should never do device transfers like GPU->host memory under the hood. The array API standard design was careful to not include that. It wasn't even possible at all until very recently, when a way was added to do it with DLPack (for testing purposes).

If you mean "convert to numpy.ndarray before going into Cython/C/C++/Fortran code inside SciPy", then yes, that is happening. That's kinda not an array API standard issue, because it's leaving Python - and that's a very different problem. To avoid compiled code inside SciPy - which indeed won't work with any JIT compiler unless that JIT is specifically aware of the SciPy functionality being called - it'd be necessary to have either a pure Python path (slow) or a matching API inside JAX that can be called (jax.scipy has some that we should be deferring to here).

and this is not possible during tracing when the array objects are abstract. JAX has mechanisms for this (namely custom calls and/or pure_callback) but the array API doesn't seem to have much consideration for this kind of library structure. Unfortunately, I think this will severely limit the usefulness of these kinds of implementations. I wonder if the array API could consider this kind of limitation?

JIT compilers were explicitly considered, and nothing in the standard should be JIT-unfriendly, except for the few functions clearly marked as having data-dependent output shapes and the few dunder methods that are also problematic for lazy arrays.
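
For instance, unique has a data-dependent output shape, which no tracing JIT can handle without extra information (a quick illustration, not from this PR):

import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 2, 3])
print(jnp.unique(x))                                 # fine eagerly: [1 2 3]
# jax.jit(jnp.unique)(x)                             # fails: output shape depends on the data
print(jax.jit(lambda a: jnp.unique(a, size=3))(x))   # works once the size is static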

@lucascolley (Member, Author) commented Feb 27, 2024

Do you mean for testing purposes, or for library code? For the latter

If this is what you meant, x-ref the 'Dispatching Mechanism' section of gh-18286

@jakevdp (Member) commented Feb 27, 2024

I mean for actual user-level code: most of the work here will be more-or-less useless for JAX users because array conversions via dlpack cannot be done under JIT without some sort of callback mechanism.

@rgommers (Member)

Okay, I had a look at https://jax.readthedocs.io/en/latest/tutorials/external-callbacks.html and understand what you mean now. jax.pure_callback looks quite interesting indeed. I wasn't familiar with it, but it looks like that may actually solve an important puzzle in dealing with compiled code. It doesn't support GPU execution or auto-differentiation, but getting jax.jit and jax.vmap to work would be a significant step forward.

It looks fairly straightforward to support (disclaimer: I haven't tried it yet). It'd be taking this current code pattern:

# inside some Python-level scipy function with array API standard support:

x = np.asarray(x)
result = call_some_compiled_code(x)
result = xp.asarray(result)  # back to original array type

and replacing it with something like (untested):

def call_compiled_code_helper(x, xp):  # needs *args, **kwargs too
    if is_jax(x):
        result_shape_dtypes = ...  # TODO: figure out how to construct the needed PyTree here
        result = jax.pure_callback(call_some_compiled_code, result_shape_dtypes, x)
    else:
        x = np.asarray(x)
        result = call_some_compiled_code(x)
        result = xp.asarray(result)
    return result

Use of a utility function like call_compiled_code_helper may even make the code shorter and easier to understand. It seems feasible at first sight.

It's interesting that jax.pure_callback transforms JAX arrays to NumPy arrays under the hood already.
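
The mechanism is easy to exercise in isolation (a runnable sketch; host_side merely stands in for SciPy's compiled code):

import jax
import jax.numpy as jnp
import numpy as np

def host_side(x):
    # executes eagerly on the host, with a real NumPy array
    return np.sin(np.asarray(x))

@jax.jit
def f(x):
    out = jax.ShapeDtypeStruct(x.shape, x.dtype)   # the result_shape_dtypes argument
    return jax.pure_callback(host_side, out, x)

print(f(jnp.linspace(0.0, 1.0, 4)))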

@jakevdp (Member) commented Feb 27, 2024

Yeah, something like that is what I had in mind, though pure_callback is probably not the right mechanism. JAX doesn't currently have an easy pure-callback-like mechanism for executing custom kernels on device, without the round-trip to host implied by pure_callback. I wonder if this kind of thing will be an issue for other array API libraries?

@rgommers (Member)

I wonder if this kind of thing will be an issue for other array API libraries?

It is (depending on your definition of "issue") because there's no magic bullet that will do something like take some native function implemented in C/Fortran/Cython inside SciPy and make that run on GPU.

The basic state of things is:

  • functions implemented in pure Python are unproblematic, and with array API support get to run on GPU/TPU, gain autograd support, etc.
    • with a few exceptions: functions using unique and other data-dependent shapes, iterative algorithms with a stopping/branching criterion that requires eager evaluation, and functions using in-place operations
  • as soon as you hit compiled code, things get harder: everything that worked before with NumPy only will still work, but autograd and GPU execution won't

JAX doesn't currently have an easy pure-callback-like mechanism for executing custom kernels on device, without the round-trip to host implied by pure_callback.

In a generic library like SciPy it's almost impossible to support custom kernels on device. Our choices for arrays that don't live in host memory are:

  • find a matching function in the other library - e.g., we can explicitly defer to everything in jax.scipy, cupyx.scipy and torch.fft/linalg/special (a sketch of this follows below)
  • raise an exception
  • do an automatic to/from-host roundtrip (we haven't considered this a good idea before, since data transfers can be very expensive - but apparently that's what pure_callback prefers over raising)
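
A minimal sketch of the first option, using xlogy since jax.scipy.special, cupyx.scipy.special and scipy.special all provide it (the dispatch helper itself is hypothetical, not SciPy code):

def dispatch_xlogy(x, y, xp):
    # defer to a native implementation when one exists
    if xp.__name__.startswith("jax"):
        from jax.scipy.special import xlogy
        return xlogy(x, y)
    if xp.__name__.startswith("cupy"):
        from cupyx.scipy.special import xlogy
        return xlogy(x, y)
    # fall back to a host round-trip through compiled SciPy code
    import numpy as np
    from scipy.special import xlogy
    return xp.asarray(xlogy(np.asarray(x), np.asarray(y)))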

@lucascolley (Member, Author) commented Feb 28, 2024

I gave adding Dask another shot just now, but unfortunately things like float64 are missing from dask.array, which makes most of our test code fail. Perhaps we will have to change to using the wrapped namespaces throughout the tests (this is awkward because we still need to imitate an array from the unwrapped namespace being input).

x-ref dask/dask#10387 (comment)

@rgommers (Member)

I'd suggest keeping this PR focused on JAX and getting that merged first. That makes it easier to see (also in the future) what had to be done only for JAX. And if we're going to experiment a bit with jax.jit, this PR may grow quite a bit already.

@rgommers (Member)

One more in fft - not sure why it showed up only now:

diff --git a/scipy/fft/tests/test_fftlog.py b/scipy/fft/tests/test_fftlog.py
index d9652facb..146ee4588 100644
--- a/scipy/fft/tests/test_fftlog.py
+++ b/scipy/fft/tests/test_fftlog.py
@@ -104,7 +104,7 @@ def test_fht_identity(n, bias, offset, optimal, xp):
     A = fht(a, dln, mu, offset=offset, bias=bias)
     a_ = ifht(A, dln, mu, offset=offset, bias=bias)
 
-    xp_assert_close(a_, a)
+    xp_assert_close(a_, a, rtol=1.5e-7)
 
 
 def test_fht_special_cases(xp):

Failures:

_______________________________________ test_fht_identity[63--0.1--1.0-True-jax.numpy] ________________________________________
scipy/fft/tests/test_fftlog.py:107: in test_fht_identity
    xp_assert_close(a_, a)
        A          = Array([ 0.2895463 , -0.57715851,  0.04246836,  1.68784873,  0.82342754,
        1.61900088, -0.25707367,  1.18689669, ...4551,  0.17525223, -0.27103611, -0.61543826, -0.03412202,
        0.52722178, -0.19101081,  0.03544419], dtype=float64)
        a          = Array([-1.12884913e+00,  4.18400908e-01, -1.41696768e+00,  1.48817091e+00,
       -1.77498877e+00,  2.36438510e+00, -6...034215e-02,  7.61753132e-01,  1.64643447e-01,
       -7.13788349e-01, -2.35421899e+00,  2.78841166e-03], dtype=float64)
        a_         = Array([-1.12884926e+00,  4.18400957e-01, -1.41696784e+00,  1.48817108e+00,
       -1.77498897e+00,  2.36438537e+00, -6...034309e-02,  7.61753220e-01,  1.64643466e-01,
       -7.13788431e-01, -2.35421927e+00,  2.78841198e-03], dtype=float64)
        bias       = -0.1
        dln        = -0.055849123963318315
        mu         = -0.8303697203418872
        n          = 63
        offset     = -1.0145511772158133
        optimal    = True
        rng        = RandomState(MT19937) at 0x70A96AFDD940
        xp         = <module 'jax.numpy' from '/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/jax/numpy/__init__.py'>
/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/contextlib.py:81: in inner
    return func(*args, **kwds)
E   AssertionError: 
E   Not equal to tolerance rtol=5.96046e-08, atol=0
E   
E   Mismatched elements: 63 / 63 (100%)
E   Max absolute difference: 3.91288673e-07
E   Max relative difference: 1.15484016e-07
E    x: array([-1.128849e+00,  4.184010e-01, -1.416968e+00,  1.488171e+00,
E          -1.774989e+00,  2.364385e+00, -6.605323e-01,  2.006572e+00,
E           9.592554e-01,  1.139476e+00, -1.515723e+00,  7.299143e-01,...
E    y: array([-1.128849e+00,  4.184009e-01, -1.416968e+00,  1.488171e+00,
E          -1.774989e+00,  2.364385e+00, -6.605323e-01,  2.006572e+00,
E           9.592553e-01,  1.139476e+00, -1.515723e+00,  7.299142e-01,...
        args       = (<function assert_allclose.<locals>.compare at 0x70a96afd9d00>, array([-1.12884926e+00,  4.18400957e-01, -1.41696784e+...184e+00,  8.14034215e-02,  7.61753132e-01,  1.64643447e-01,
       -7.13788349e-01, -2.35421899e+00,  2.78841166e-03]))
        func       = <function assert_array_compare at 0x70a98bc167a0>
        kwds       = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=5.96046e-08, atol=0', 'verbose': True}
        self       = <contextlib._GeneratorContextManager object at 0x70a98bc31730>
_______________________________________ test_fht_identity[63--0.1--1.0-False-jax.numpy] _______________________________________
scipy/fft/tests/test_fftlog.py:107: in test_fht_identity
    xp_assert_close(a_, a)
        A          = Array([ 0.69652742, -0.64694696, -0.14236865,  1.28532617,  1.20979012,
        1.26670293,  0.51802895,  0.18218724, ...2132, -0.41327618,  0.27614662, -0.92507489,  0.01970678,
        0.25598135,  0.24367844, -0.36154562], dtype=float64)
        a          = Array([-1.12884913e+00,  4.18400908e-01, -1.41696768e+00,  1.48817091e+00,
       -1.77498877e+00,  2.36438510e+00, -6...034215e-02,  7.61753132e-01,  1.64643447e-01,
       -7.13788349e-01, -2.35421899e+00,  2.78841166e-03], dtype=float64)
        a_         = Array([-1.12884926e+00,  4.18400957e-01, -1.41696784e+00,  1.48817108e+00,
       -1.77498897e+00,  2.36438537e+00, -6...034309e-02,  7.61753220e-01,  1.64643466e-01,
       -7.13788431e-01, -2.35421927e+00,  2.78841198e-03], dtype=float64)
        bias       = -0.1
        dln        = -0.055849123963318315
        mu         = -0.8303697203418872
        n          = 63
        offset     = -1.0
        optimal    = False
        rng        = RandomState(MT19937) at 0x70A96AFDE140
        xp         = <module 'jax.numpy' from '/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/jax/numpy/__init__.py'>
/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/contextlib.py:81: in inner
    return func(*args, **kwds)
E   AssertionError: 
E   Not equal to tolerance rtol=5.96046e-08, atol=0
E   
E   Mismatched elements: 63 / 63 (100%)
E   Max absolute difference: 3.91288673e-07
E   Max relative difference: 1.15484015e-07
E    x: array([-1.128849e+00,  4.184010e-01, -1.416968e+00,  1.488171e+00,
E          -1.774989e+00,  2.364385e+00, -6.605323e-01,  2.006572e+00,
E           9.592554e-01,  1.139476e+00, -1.515723e+00,  7.299143e-01,...
E    y: array([-1.128849e+00,  4.184009e-01, -1.416968e+00,  1.488171e+00,
E          -1.774989e+00,  2.364385e+00, -6.605323e-01,  2.006572e+00,
E           9.592553e-01,  1.139476e+00, -1.515723e+00,  7.299142e-01,...
        args       = (<function assert_allclose.<locals>.compare at 0x70a96afd93a0>, array([-1.12884926e+00,  4.18400957e-01, -1.41696784e+...184e+00,  8.14034215e-02,  7.61753132e-01,  1.64643447e-01,
       -7.13788349e-01, -2.35421899e+00,  2.78841166e-03]))
        func       = <function assert_array_compare at 0x70a98bc167a0>
        kwds       = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=5.96046e-08, atol=0', 'verbose': True}
        self       = <contextlib._GeneratorContextManager object at 0x70a98bc31730>
=================================================== short test summary info ===================================================
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63-0-0.0-True-jax.numpy] - AssertionError: 
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63-0-0.0-False-jax.numpy] - AssertionError: 
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63-0-1.0-True-jax.numpy] - AssertionError: 
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63-0-1.0-False-jax.numpy] - AssertionError: 
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63-0--1.0-True-jax.numpy] - AssertionError: 
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63-0--1.0-False-jax.numpy] - AssertionError: 
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63-0.1-0.0-True-jax.numpy] - AssertionError: 
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63-0.1-0.0-False-jax.numpy] - AssertionError: 
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63-0.1-1.0-True-jax.numpy] - AssertionError: 
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63-0.1-1.0-False-jax.numpy] - AssertionError: 
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63-0.1--1.0-True-jax.numpy] - AssertionError: 
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63-0.1--1.0-False-jax.numpy] - AssertionError: 
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63--0.1-0.0-True-jax.numpy] - AssertionError: 
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63--0.1-0.0-False-jax.numpy] - AssertionError: 
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63--0.1-1.0-True-jax.numpy] - AssertionError: 
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63--0.1-1.0-False-jax.numpy] - AssertionError: 
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63--0.1--1.0-True-jax.numpy] - AssertionError: 
FAILED scipy/fft/tests/test_fftlog.py::test_fht_identity[63--0.1--1.0-False-jax.numpy] - AssertionError: 
================================================ 18 failed, 22 passed in 3.10s ==============================

@rgommers (Member)

Looks like we're basically there. I'll do some testing with PyTorch and CuPy tomorrow to check that we didn't silently break anything there - and then it's probably good to merge.

@tylerjereddy (Contributor) left a comment

I'll do some testing with PyTorch and CuPy tomorrow to check that we didn't silently break anything there

Short summary seems to be that most of the failures are already present on main rather than introduced here, from my initial checking on GPU.

For SCIPY_DEVICE=cuda python dev.py test -j 32 -b all with NVIDIA GPU:

  • latest main: 66 failed, 51999 passed, 11312 skipped, 157 xfailed, 13 xpassed in 55.91s
FAILED scipy/fft/tests/test_basic.py::TestFFTThreadSafe::test_fft[torch] - pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-33 (worker)
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case15-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case16-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case17-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case18-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case19-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_basic[p1-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case20-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case21-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/fft/tests/test_basic.py::TestFFTThreadSafe::test_ifft[cupy] - AssertionError: 
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case22-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case23-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_vectorization[shape0-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case24-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case25-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case26-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_vectorization[shape0-cupy] - AssertionError: 
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_vectorization[shape1-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_vectorization[shape1-cupy] - AssertionError: 
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case27-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_vectorization[shape2-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case28-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_vectorization[shape2-cupy] - AssertionError: 
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case29-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_vectorization[shape3-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case30-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_vectorization[shape3-cupy] - AssertionError: 
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case31-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_convergence[torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case32-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case33-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case34-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case35-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case0-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case1-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case36-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case2-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case37-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case3-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case4-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case38-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/fft/tests/test_basic.py::TestFFTThreadSafe::test_rfft[cupy] - AssertionError: 
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case5-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case6-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case39-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case7-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case40-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case8-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case41-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case42-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case9-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case43-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case44-torch] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case10-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_dtype[float16-0.622-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Half for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case11-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case12-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case13-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_nit_expected[case14-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_dtype[float16-root1-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Half for the destination and Float for the source.
FAILED scipy/special/tests/test_support_alternative_backends.py::test_support_alternative_backends[f_name_n_args15-array_api_strict] - AssertionError: 
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_dtype[float64-0.622-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/optimize/tests/test_chandrupatla.py::TestChandrupatla::test_dtype[float64-root1-torch] - RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
FAILED scipy/special/tests/test_support_alternative_backends.py::test_support_alternative_backends[f_name_n_args15-torch] - AssertionError: Scalars are not close!
FAILED scipy/stats/tests/test_stats.py::TestDescribe::test_describe_numbers[torch] - AssertionError: Tensor-likes are not equal!
FAILED scipy/fft/tests/test_basic.py::TestFFTThreadSafe::test_ihfft[cupy] - AssertionError: 
  • here: 67 failed, 51998 passed, 11316 skipped, 157 xfailed, 13 xpassed in 58.12s

Mostly looks like the chandrupatla stuff discussed above on torch + GPU, which is already on main.

With the patch below

--- a/scipy/optimize/_tstutils.py
+++ b/scipy/optimize/_tstutils.py
@@ -44,6 +44,7 @@ from random import random
 import numpy as np
 
 from scipy.optimize import _zeros_py as cc
+from scipy._lib._array_api import array_namespace
 
 # "description" refers to the original functions
 description = """
@@ -887,18 +888,21 @@ fun6.root = 0
 
 
 def fun7(x):
-    return 0 if abs(x) < 3.8e-4 else x*np.exp(-x**(-2))
+    xp = array_namespace(x)
+    return 0 if abs(x) < 3.8e-4 else x*xp.exp(-x**(-2))
 fun7.root = 0
 
 
 def fun8(x):
+    xp = array_namespace(x)
     xi = 0.61489
-    return -(3062*(1-xi)*np.exp(-x))/(xi + (1-xi)*np.exp(-x)) - 1013 + 1628/x
+    return -(3062*(1-xi)*xp.exp(-x))/(xi + (1-xi)*xp.exp(-x)) - 1013 + 1628/x
 fun8.root = 1.0375360332870405
 
 
 def fun9(x):
-    return np.exp(x) - 2 - 0.01/x**2 + .000002/x**3
+    xp = array_namespace(x)
+    return xp.exp(x) - 2 - 0.01/x**2 + .000002/x**3
 fun9.root = 0.7032048403631358
 
 # Each "chandropatla" test case has

most of the torch GPU failures become RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.

And with this additional patch:

--- a/scipy/optimize/_chandrupatla.py
+++ b/scipy/optimize/_chandrupatla.py
@@ -187,7 +187,7 @@ def _chandrupatla(func, a, b, *, args=(), xatol=None, xrtol=None,
         # If the bracket is no longer valid, report failure (unless a function
         # tolerance is met, as detected above).
         i = (xp_sign(work.f1) == xp_sign(work.f2)) & ~stop
-        NaN = xp.asarray(xp.nan)
+        NaN = xp.asarray(xp.nan, dtype=work.xmin.dtype)
         work.xmin[i], work.fmin[i], work.status[i] = NaN, NaN, eim._ESIGNERR
         stop[i] = True

it drops to 15 failures. There seem to be some typos in some of the array API-converted tests. More np->xp conversion fixes another GPU failure:

--- a/scipy/optimize/tests/test_chandrupatla.py
+++ b/scipy/optimize/tests/test_chandrupatla.py
@@ -656,11 +656,11 @@ class TestChandrupatla(TestScalarRootFinders):
         x1, x2 = bracket
         f0 = xp_minimum(xp.abs(self.f(x1, *args)), xp.abs(self.f(x2, *args)))
         res1 = _chandrupatla_root(self.f, *bracket, **kwargs)
-        xp_assert_less(np.abs(res1.fun), 1e-3*f0)
+        xp_assert_less(xp.abs(res1.fun), 1e-3*f0)
         kwargs['frtol'] = 1e-6
         res2 = _chandrupatla_root(self.f, *bracket, **kwargs)
-        xp_assert_less(np.abs(res2.fun), 1e-6*f0)
-        xp_assert_less(np.abs(res2.fun), np.abs(res1.fun))
+        xp_assert_less(xp.abs(res2.fun), 1e-6*f0)
+        xp_assert_less(xp.abs(res2.fun), xp.abs(res1.fun))

I ran out of steam there, but basically this branch doesn't seem to introduce much that isn't already broken on main in my hands.

@mdhaber (Contributor) commented May 18, 2024

@tylerjereddy would you be willing to open a PR with these patches that I can merge?

@tylerjereddy (Contributor)

ok

tylerjereddy added a commit to tylerjereddy/scipy that referenced this pull request May 18, 2024
* Addresses some of my points at:
scipy#20085 (review)
and seems to fix about 55 GPU-based array API
test failures

[skip cirrus] [skip circle]
@mdhaber (Contributor) commented May 18, 2024

I can make any remaining fixes in that PR if need be. It turned out that we forgot to add test_chandrupatla to the array API job so it looked like CI was passing, and of course it passed locally (even with torch CPU and array_api_strict).

rgommers pushed a commit that referenced this pull request May 18, 2024
Addresses some of my points at:
#20085 (review)
and seems to fix about 55 GPU-based array API test failures

Co-authored-by: Matt Haberland <mhaberla@calpoly.edu>
@@ -207,7 +205,6 @@ def test_mlab_linkage_conversion_empty(self, xp):
xp_assert_equal(from_mlab_linkage(X), X)
xp_assert_equal(to_mlab_linkage(X), X)

@skip_xp_backends(cpu_only=True)
Review comment (Member) on the diff above:

from_mlab_linkage converts with np.asarray, so I'll put these back.

@rgommers (Member)

@lucascolley FYI some of the cpu_only=True tests that pass with JAX on GPU are doing so because np.asarray(a_jax_cuda_array) works. However, it is very inefficient; the code is often not correct anyway with JAX, because the reverse xp.asarray call doesn't put data back on the GPU; and it won't work for either CuPy or PyTorch. So I'll undo such test changes.

I found that having an environment with both JAX and CuPy and testing with -b all is a nice way to test, since they have quite different constraints. PyTorch is harder to install in the same env as JAX, but if things work for both JAX and CuPy then they'll most likely work for PyTorch as well.

@lucascolley (Member, Author)

I found that having an environment with both JAX and CuPy and testing with -b all is a nice way to test, since they have quite different constraints. PyTorch is harder to install in the same env as JAX, but if things work for both JAX and CuPy then they'll most likely work for PyTorch as well.

Nice. Moving forward I will probably have one env with JAX + CuPy and another with PyTorch + CuPy + array-api-strict, and test with both. Things will be easier once I'm back with a GPU.

@rgommers (Member)

The one Windows failure is unrelated:

 FAILED scipy\optimize\tests\test_constraint_conversion.py::TestNewToOld::test_individual_constraint_objects

@rgommers (Member) left a comment

Okay, time to give it a go. Thanks a lot @lucascolley and all reviewers!

Follow-up steps include:

  1. deal with item/slice assignment, reducing/removing skips related to that
  2. look at a callback mechanism to make jax.jit work
  3. check how things look on TPU (e.g. in a Kaggle notebook, see discussion higher up)

I plan to have a look at 1 and 2 tonight.

Unrelated to JAX follow-ups:

  • a few more CuPy test failures to deal with in main
  • deal with other test failures related to nan_policy

@rgommers merged commit 7192a1c into scipy:main May 18, 2024
@lucascolley deleted the jax branch May 18, 2024 10:21
@lucascolley (Member, Author)

thanks Ralf and all reviewers for all of the help here! I plan to have a look at Dask in a few months' time, but anyone else, feel free to tackle it if you get to it before me.

A reminder that gh-19900 looks ready to me and should help eliminate some of the GPU failures Tyler was seeing. But no rush if it looks like more work is needed.

FYI @izaid , scipy._lib._array_api.scipy_namespace_for exists now.

I'm ~1 week out from finals now, so I'll not be working on any PRs for a while. See you on the other side!

@mdhaber (Contributor) commented May 18, 2024

What are the CuPy and nan_policy failures? I can probably fix them today.

@rgommers (Member) commented May 18, 2024

Good luck with your finals Lucas!

What are the CuPy and nan_policy failures? I can probably fix them today.

CuPy failures are taken care of in gh-19900. They were:

_______________________________ TestFFTThreadSafe.test_ihfft[cupy] ________________________________
scipy/fft/tests/test_basic.py:460: in test_ihfft
    self._test_mtsame(fft.ihfft, a, xp=xp)
        a          = array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]])
        self       = <scipy.fft.tests.test_basic.TestFFTThreadSafe object at 0x754efabfef60>
        xp         = <module 'cupy' from '/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/cupy/__init__.py'>
scipy/fft/tests/test_basic.py:434: in _test_mtsame
    q.get(timeout=5), expected,
        args       = (array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
   ....,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]]),)
        expected   = array([[1.-0.j, 0.-0.j, 0.-0.j, ..., 0.+0.j, 0.+0.j, 0.-0.j],
       [1.-0.j, 0.-0.j, 0.-0.j, ..., 0.+0.j, 0.+0.j, 0.-...  [1.-0.j, 0.-0.j, 0.-0.j, ..., 0.+0.j, 0.+0.j, 0.-0.j],
       [1.-0.j, 0.-0.j, 0.-0.j, ..., 0.+0.j, 0.+0.j, 0.-0.j]])
        func       = <uarray multimethod 'ihfft'>
        i          = 0
        q          = <queue.Queue object at 0x754ec6b60290>
        self       = <scipy.fft.tests.test_basic.TestFFTThreadSafe object at 0x754efabfef60>
        t          = [<Thread(Thread-354 (worker), stopped 128979476416192)>, <Thread(Thread-355 (worker), stopped 128979444958912)>, <Thre...>, <Thread(Thread-358 (worker), stopped 128979677742784)>, <Thread(Thread-359 (worker), stopped 128979667257024)>, ...]
        worker     = <function TestFFTThreadSafe._test_mtsame.<locals>.worker at 0x754ed013be20>
        xp         = <module 'cupy' from '/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/cupy/__init__.py'>
/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/queue.py:179: in get
    raise Empty
E   _queue.Empty
        block      = True
        endtime    = 145.408170417
        remaining  = -0.0001208719999965524
        self       = <queue.Queue object at 0x754ec6b60290>
        timeout    = 5

During handling of the above exception, another exception occurred:
/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/_pytest/runner.py:341: in from_call
    result: Optional[TResult] = func()
        cls        = <class '_pytest.runner.CallInfo'>
        duration   = 5.902584181999998
        excinfo    = <ExceptionInfo PytestUnhandledThreadExceptionWarning('Exception in thread Thread-368 (worker)\n\nTraceback (most recen.../cuda/cufft.pyx", line 169, in cupy.cuda.cufft.check_result\ncupy.cuda.cufft.CuFFTError: CUFFT_EXEC_FAILED\n') tblen=9>
        func       = <function call_and_report.<locals>.<lambda> at 0x754f987cff60>
        precise_start = 139.640223733
        precise_stop = 145.542807915
        reraise    = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
        result     = None
        start      = 1716057866.3666894
        stop       = 1716057872.2692754
        when       = 'call'
/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/_pytest/runner.py:241: in <lambda>
    lambda: runtest_hook(item=item, **kwds), when=when, reraise=reraise
        item       = <Function test_ihfft[cupy]>
        kwds       = {}
        runtest_hook = <HookCaller 'pytest_runtest_call'>
/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/pluggy/_hooks.py:513: in __call__
    return self._hookexec(self.name, self._hookimpls.copy(), kwargs, firstresult)
        firstresult = False
        kwargs     = {'item': <Function test_ihfft[cupy]>}
        self       = <HookCaller 'pytest_runtest_call'>
/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/pluggy/_manager.py:120: in _hookexec
    return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
        firstresult = False
        hook_name  = 'pytest_runtest_call'
        kwargs     = {'item': <Function test_ihfft[cupy]>}
        methods    = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/home/rgommers/mambaforge/envs/scipy-dev-jax-cu...=None>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x754fd2f5fbc0>>, ...]
        self       = <_pytest.config.PytestPluginManager object at 0x754fdc39a270>
/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/_pytest/threadexception.py:87: in pytest_runtest_call
    yield from thread_exception_runtest_hook()
/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/_pytest/threadexception.py:77: in thread_exception_runtest_hook
    warnings.warn(pytest.PytestUnhandledThreadExceptionWarning(msg))
E   pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-368 (worker)
E   
E   Traceback (most recent call last):
E     File "/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/threading.py", line 1073, in _bootstrap_inner
E       self.run()
E     File "/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/threading.py", line 1010, in run
E       self._target(*self._args, **self._kwargs)
E     File "/home/rgommers/code/scipy/build-install/lib/python3.12/site-packages/scipy/fft/tests/test_basic.py", line 419, in worker
E       q.put(func(*args))
E             ^^^^^^^^^^^
E     File "/home/rgommers/code/scipy/build-install/lib/python3.12/site-packages/scipy/fft/_backend.py", line 28, in __ua_function__
E       return fn(*args, **kwargs)
E              ^^^^^^^^^^^^^^^^^^^
E     File "/home/rgommers/code/scipy/build-install/lib/python3.12/site-packages/scipy/fft/_basic_backend.py", line 90, in ihfft
E       return _execute_1D('ihfft', _pocketfft.ihfft, x, n=n, axis=axis, norm=norm,
E              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E     File "/home/rgommers/code/scipy/build-install/lib/python3.12/site-packages/scipy/fft/_basic_backend.py", line 34, in _execute_1D
E       return xp_func(x, n=n, axis=axis, norm=norm)
E              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E     File "/home/rgommers/code/scipy/build-install/lib/python3.12/site-packages/scipy/_lib/array_api_compat/_internal.py", line 28, in wrapped_f
E       return f(*args, xp=xp, **kwargs)
E              ^^^^^^^^^^^^^^^^^^^^^^^^^
E     File "/home/rgommers/code/scipy/build-install/lib/python3.12/site-packages/scipy/_lib/array_api_compat/common/_fft.py", line 147, in ihfft
E       res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
E             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E     File "/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/cupy/fft/_fft.py", line 1050, in ihfft
E       return rfft(a, n, axis, _swap_direction(norm)).conj()
E              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E     File "/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/cupy/fft/_fft.py", line 840, in rfft
E       return _fft(a, (n,), (axis,), norm, cufft.CUFFT_FORWARD, 'R2C')
E              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E     File "/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/cupy/fft/_fft.py", line 248, in _fft
E       a = _exec_fft(a, direction, value_type, norm, axes[-1], overwrite_x)
E           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E     File "/home/rgommers/mambaforge/envs/scipy-dev-jax-cu12/lib/python3.12/site-packages/cupy/fft/_fft.py", line 191, in _exec_fft
E       plan.fft(a, out, direction)
E     File "cupy/cuda/cufft.pyx", line 500, in cupy.cuda.cufft.Plan1d.fft
E     File "cupy/cuda/cufft.pyx", line 520, in cupy.cuda.cufft.Plan1d._single_gpu_fft
E     File "cupy/cuda/cufft.pyx", line 1145, in cupy.cuda.cufft.execD2Z
E     File "cupy/cuda/cufft.pyx", line 169, in cupy.cuda.cufft.check_result
E   cupy.cuda.cufft.CuFFTError: CUFFT_EXEC_FAILED
        cm         = <_pytest.threadexception.catch_threading_exception object at 0x754efabff440>
        msg        = 'Exception in thread Thread-368 (worker)\n\nTraceback (most recent call last):\n  File "/home/rgommers/mambaforge/envs...File "cupy/cuda/cufft.pyx", line 169, in cupy.cuda.cufft.check_result\ncupy.cuda.cufft.CuFFTError: CUFFT_EXEC_FAILED\n'
        thread_name = 'Thread-368 (worker)'
===================================== short test summary info =====================================
FAILED scipy/fft/tests/test_basic.py::TestFFTThreadSafe::test_ifft[cupy] - pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-112 (worker)
FAILED scipy/fft/tests/test_basic.py::TestFFTThreadSafe::test_rfft[cupy] - pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-176 (worker)
FAILED scipy/fft/tests/test_basic.py::TestFFTThreadSafe::test_ihfft[cupy] - pytest.PytestUnhandledThreadExceptionWarning: Exception in thread Thread-368 (worker)

nan_policy failures I gave higher up, will circle back to those after reviewing gh-19900. EDIT: see gh-20748.

rgommers added a commit to rgommers/scipy that referenced this pull request May 19, 2024
This makes things work with JAX, at a slight readability cost.
Follow up to scipy#20085.
@rgommers (Member) commented May 19, 2024

deal with item/slice assignment, reducing/removing skips related to that

I worked some more on this, adding at_set/at_add etc. for in-place equivalents: https://github.com/scipy/scipy/compare/main...rgommers:scipy:array-types-inplace-ops?expand=1. It's a slight readability hit to replace Z[i//2,1] = -2 with Z = at_set(Z, (i//2, 1), -2), but acceptable in many places (an opt-in mode for JAX to recognize regular in-place syntax would be way better though).

Using it, some good news and some less good. I could get cluster.whiten to work with jax.jit on CPU and GPU with some minor tweaks. And it helps performance:

    whiten_jit(face).block_until_ready()  # do the JIT compilation
    face_gpu = face  # JAX defaults to GPU if that is available
    face_cpu = jax.device_put(face, jax.devices('cpu')[0])
    face_np = np.asarray(face)

    %timeit cluster.vq.whiten(face_np)   #  22 ms
    %timeit cluster.vq.whiten(face_cpu)  #   6 ms
    %timeit whiten_jit(face_cpu)         #   3 ms
    %timeit cluster.vq.whiten(face_gpu)  # 700 us
    %timeit whiten_jit(face_gpu)         # 275 us
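
(The setup isn't shown above; presumably it is something like the following, with the branch's tweaks applied - a guess, so treat the names as hypothetical:)

import jax
import jax.numpy as jnp
import numpy as np
import scipy.cluster as cluster
import scipy.datasets

# hypothetical reconstruction of the benchmark inputs
face = jnp.asarray(scipy.datasets.face(gray=True).astype(np.float32))
whiten_jit = jax.jit(cluster.vq.whiten)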

The less good news is that the at[idx].set() syntax still doesn't work for boolean indexing:

>>> import jax.numpy as jnp
>>> import jax
>>>
>>> def func(x, idx, value):
...     return x.at[idx].set(value)
...
>>> func_jit = jax.jit(func)
>>>
>>> x = jnp.arange(5)
>>> idx = x < 3
>>>
>>> func(x, idx, 99)
Array([99, 99, 99,  3,  4], dtype=int32)
>>> func_jit(x, idx, 99)
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])
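
(The where form of the same update does trace fine, which is part of why the restriction feels artificial - continuing the session above:)

>>> @jax.jit
... def func_where(x, idx, value):
...     return jnp.where(idx, value, x)
...
>>> func_where(x, idx, 99)
Array([99, 99, 99,  3,  4], dtype=int32)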

The explanation under "Boolean indexing into JAX arrays" (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError) isn't quite satisfactory. There are no dynamic shapes here, so it could work just fine. If the answer is to always use where, then it's just missing support for syntax that could be translated to jnp.where internally. So it seems like we need further branching under is_jax like this:

def at_set(
        x : Array,
        idx: Array | int | slice,
        val: Array | int | float | complex,
        *,
        xp: ModuleType | None = None,
    ) -> Array:
    """In-place update. Use only if no views are involved."""
    xp = array_namespace(x) if xp is None else xp
    if is_jax(xp):
        if xp.isdtype(idx.dtype, 'bool'):
            x = xp.where(idx, x, val)
        else:
            x = x.at[idx].set(val)
    else:
        x[idx] = val
    return x

This is slower - and I'm also not sure whether jax.jit is then going to complain about the if xp.isdtype line because of Python control flow with an array involved. EDIT: that works, with a tweak (even slower):

    if is_jax(xp):
        if hasattr(idx, 'dtype') and xp.isdtype(idx.dtype, 'bool'):
            x = xp.where(idx, x * val, x)
        else:
            x = x.at[idx].multiply(val)
    else:
        x[idx] *= val

I'll look at it some more - this may be better suited for data-apis/array-api#609.

@mdhaber (Contributor) commented May 19, 2024

I think it would be worth adding something that works for now, even if it's not great. It would avoid all the test skips and make it more obvious what capabilities we need. Once something better comes along, it will be easy to replace. It's probably better than using where, which we're tempted to use otherwise.

@rgommers (Member)

I think it would be worth adding something that works for now, even if it's not great.

Yeah maybe - I don't want to go too fast though, and add a bunch of code we may regret. Looks like the new version (I edited my comment and pushed a new commit) works though, and is still very fast with JAX.

It's probably better than using where, which we're tempted to use otherwise.

Let's make sure not to do things like that. Using where could potentially be bad for performance with numpy, which would not be helpful. Skips are better for now.

@jakevdp (Member) commented May 20, 2024

@rgommers FYI I managed to implement the scalar boolean scatter in JAX, and it will be available in the next release. Turns out we had all the necessary logic there already – I just needed to put it together! google/jax#21305

@rgommers (Member)

Great! Thanks @jakevdp. Looks like a small patch that I can try out pretty easily on top of JAX 0.4.28 - will give it a go later this week.

(note to self, since comments are hard to find in this PR: the relevant comment here is #20085 (comment))
