[Possible bug] test suite requires value-dependent integer casting #205

Open
jakevdp opened this issue Nov 14, 2023 · 1 comment

@jakevdp
Contributor

jakevdp commented Nov 14, 2023

The JAX array API I'm developing in google/jax#16099 hits a number of failures because the test suite calls APIs with integers larger than the maximum int64 value. For example:

```
In [1]: import jax

In [2]: jax.config.update('jax_enable_x64', True)

In [3]: val = 2 ** 63  # one of the values generated by hypothesis

In [4]: jax.numpy.less(val, 0)
---------------------------------------------------------------------------
OverflowError: An overflow was encountered while parsing an argument to a jitted computation, whose argument path is x1.
```

By contrast, this isn't a problem in numpy:

```
In [5]: import numpy as np

In [6]: np.less(val, 0)
Out[6]: False
```

The reason for this discrepancy is that NumPy casts Python ints in a value-dependent way: a value that fits in int64 becomes int64, while a larger one is promoted to uint64:

```
In [7]: np.array(val - 1).dtype
Out[7]: dtype('int64')

In [8]: np.array(val).dtype
Out[8]: dtype('uint64')
```
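As a rough illustration, the rule in the session above can be modeled with a small pure-Python function. This is a simplified sketch, not NumPy's implementation: `value_dependent_dtype` is a hypothetical helper, it only models the int64/uint64 choice shown above, and it simply raises for values outside both ranges (real NumPy does not necessarily behave that way there).

```python
# Inclusive bounds of the fixed-width integer dtypes involved.
INT64_MIN = -(2 ** 63)
INT64_MAX = 2 ** 63 - 1
UINT64_MAX = 2 ** 64 - 1


def value_dependent_dtype(value: int) -> str:
    """Hypothetical model of NumPy 1.x-style dtype selection for a Python int.

    The result depends on the *value*, not just the type: ints that fit in
    int64 map to int64, and non-negative ints that only fit in uint64 map
    to uint64.
    """
    if INT64_MIN <= value <= INT64_MAX:
        return "int64"
    if 0 <= value <= UINT64_MAX:
        return "uint64"
    # Simplification: this sketch raises; real NumPy may do something else.
    raise OverflowError(f"{value} does not fit in int64 or uint64")


val = 2 ** 63
print(value_dependent_dtype(val - 1))  # int64
print(value_dependent_dtype(val))      # uint64
```

The same Python type (`int`) yields two different dtypes depending only on magnitude, which is the behavior the test suite ends up depending on.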

JAX has made the deliberate decision to avoid these kinds of implicit value-dependent semantics, and raises an error in the second case:

```
In [9]: jax.numpy.array(val - 1).dtype
Out[9]: dtype('int64')

In [10]: jax.numpy.array(val).dtype
---------------------------------------------------------------------------
OverflowError: Python int 9223372036854775808 too large to convert to int64
```
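By contrast, a value-independent policy can be sketched as follows: every Python int maps to the same fixed dtype, and an out-of-range value is an error rather than a reason to pick a different dtype. `strict_int64_dtype` is a hypothetical pure-Python helper that mimics the behavior in the JAX session above; it is not part of JAX's API.

```python
INT64_MIN = -(2 ** 63)
INT64_MAX = 2 ** 63 - 1


def strict_int64_dtype(value: int) -> str:
    """Hypothetical value-independent dtype choice for a Python int.

    The dtype is always int64, whatever the value; out-of-range values
    raise instead of silently switching to another dtype.
    """
    if not INT64_MIN <= value <= INT64_MAX:
        raise OverflowError(f"Python int {value} out of range for int64")
    return "int64"


val = 2 ** 63
print(strict_int64_dtype(val - 1))  # int64
try:
    strict_int64_dtype(val)
except OverflowError as err:
    print(err)  # Python int 9223372036854775808 out of range for int64
```

The trade-off is that the output dtype can be decided from the input's type alone. One plausible motivation for this in JAX is that, under `jax.jit`, concrete values are generally not known at trace time, so dtypes cannot depend on them.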

This design decision results in the array API test failures mentioned above.

I would like to address this in the JAX array API branch, so my question is this: Is this value-dependent integer casting behavior part of the Array API specification?

  • If the answer is yes, then I can add functionality to the jax array api wrappers to handle these corner cases.
  • If the answer is no, then the fact that the test suite depends on such behavior should probably be considered a bug.

Do you have thoughts on how I should proceed? Thanks!

@rgommers
Member

> Is this value-dependent integer casting behavior part of the Array API specification?

No, it's explicitly forbidden, so the test looks buggy to me. The maximum value it generates should be capped at a lower value.
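A minimal sketch of what such a cap could look like, assuming the fix is to keep test inputs within the range of the dtype under test. `dtype_bounds` and `representable` are hypothetical helpers, not part of the test suite; the list filtering below is only for illustration. In a hypothesis-based suite, bounds like these would more naturally be passed as `min_value`/`max_value` to `hypothesis.strategies.integers`, so out-of-range values are never generated.

```python
from __future__ import annotations


def dtype_bounds(bits: int, signed: bool) -> tuple[int, int]:
    """Return the inclusive (min, max) range of a fixed-width integer dtype."""
    if signed:
        return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    return 0, 2 ** bits - 1


def representable(value: int, bits: int, signed: bool) -> bool:
    """True if `value` fits in the dtype without any value-dependent casting."""
    lo, hi = dtype_bounds(bits, signed)
    return lo <= value <= hi


# Candidate values such as a test generator might produce.
candidates = [0, -1, 2 ** 63 - 1, 2 ** 63, -(2 ** 63) - 1]
# Keep only valid int64 inputs; 2**63 and -(2**63) - 1 are dropped.
int64_safe = [v for v in candidates if representable(v, 64, signed=True)]
print(int64_safe)  # [0, -1, 9223372036854775807]
```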

Also, NumPy is getting rid of most or all of this value-dependent behavior, so the test will break with NumPy 2.0 as well.
