Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored common upcast for integral-type accumulators #20842

Merged
merged 1 commit into from May 6, 2024

Conversation

Micky774
Copy link
Collaborator

@Micky774 Micky774 commented Apr 19, 2024

Towards #20200

This PR refactors integer accumulator promotion from jax._src.numpy.reductions._reduction and applies it to _make_cumulative_reduction as well.

Temporarily disables tests for prod, sum, and trace since the 2023 API included breaking changes which are not yet accounted for in the tests repository.

@Micky774 Micky774 marked this pull request as draft April 19, 2024 20:17
@Micky774 Micky774 changed the title Refactored common integral type upcast Refactored common upcast for integral-type accumulators Apr 19, 2024
@Micky774 Micky774 force-pushed the array-api-default-promotion branch from 16b7a14 to 52e08a3 Compare April 19, 2024 20:28
@Micky774 Micky774 marked this pull request as ready for review April 19, 2024 20:28
@Micky774 Micky774 force-pushed the array-api-default-promotion branch 2 times, most recently from 1b23953 to a7255f7 Compare April 22, 2024 16:29
jax/_src/numpy/util.py Outdated Show resolved Hide resolved
@Micky774 Micky774 force-pushed the array-api-default-promotion branch 2 times, most recently from d6f2e04 to a0e30af Compare April 23, 2024 14:44
jax/_src/numpy/util.py Outdated Show resolved Hide resolved
@Micky774 Micky774 force-pushed the array-api-default-promotion branch from a0e30af to 93cda81 Compare April 24, 2024 01:47
@Micky774
Copy link
Collaborator Author

@jakevdp should be ready for review

@jakevdp jakevdp self-assigned this Apr 26, 2024
@Micky774 Micky774 force-pushed the array-api-default-promotion branch from 93cda81 to e00f7df Compare May 1, 2024 17:54
jax/_src/numpy/util.py Outdated Show resolved Hide resolved
@Micky774 Micky774 force-pushed the array-api-default-promotion branch from e00f7df to 23f146e Compare May 2, 2024 01:17
@Micky774 Micky774 marked this pull request as draft May 2, 2024 01:20
@Micky774 Micky774 force-pushed the array-api-default-promotion branch from 23f146e to 6b249f0 Compare May 2, 2024 01:28
@Micky774 Micky774 marked this pull request as ready for review May 2, 2024 01:28
@Micky774 Micky774 force-pushed the array-api-default-promotion branch from 6b249f0 to 09c4242 Compare May 2, 2024 01:32
jax/_src/dtypes.py Outdated Show resolved Hide resolved
@Micky774 Micky774 force-pushed the array-api-default-promotion branch from 09c4242 to 75758b2 Compare May 2, 2024 02:04
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels May 2, 2024
@jakevdp
Copy link
Collaborator

jakevdp commented May 2, 2024

Could you sync against the updated main branch and resolve conflicts? Thanks!

@Micky774 Micky774 force-pushed the array-api-default-promotion branch from 75758b2 to cd288ee Compare May 2, 2024 20:03
@Micky774
Copy link
Collaborator Author

Micky774 commented May 2, 2024

Synced!

@jakevdp
Copy link
Collaborator

jakevdp commented May 2, 2024

It looks like this breaks some tests in lax_numpy_reducers_test.py. You should be able to repro by running

JAX_NUM_GENERATED_CASES=90 pytest -n auto tests/lax_numpy_reducers_test.py -k testCumulativeSum

It looks like this somehow changed the behavior for integer inputs narrower than int32.

@Micky774 Micky774 force-pushed the array-api-default-promotion branch 3 times, most recently from 561bb6c to e115797 Compare May 3, 2024 20:02
@Micky774
Copy link
Collaborator Author

Micky774 commented May 3, 2024

It looks like the core issue is that we don't allow bool dtypes for reductions due to limiting add and mul primitives to non-boolean numerical types, and we adjust for this by upcasting bool to int_ before accumulation, but this differs from NumPy behavior where the data is converted to bool but addition ins implemented only as an or operation, rather than a count of True values.

For now I've explicitly disallowed the use of dtype=bool, since I think adjusting that behavior is actually a bit of a deeper change.

Update: I've set it to keep bool-->int_ upcast for accumulation and then downcast back to bool. Should match NumPy and satisfy array API dtype behavior as well.

@Micky774 Micky774 force-pushed the array-api-default-promotion branch 3 times, most recently from ce9771e to e5c093d Compare May 3, 2024 20:32
@Micky774
Copy link
Collaborator Author

Micky774 commented May 3, 2024

@jakevdp should be good for another look now

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this PR makes ``promote_integers` a public keyword for all cumulative reductions. Do we need to let users see that? Could we do this in a way that doesn't affect the public-facing API?

@Micky774 Micky774 force-pushed the array-api-default-promotion branch from e5c093d to b22b109 Compare May 3, 2024 21:09
@Micky774
Copy link
Collaborator Author

Micky774 commented May 3, 2024

Good point. I've changed it so that we make a helper reduction _cumsum_with_promotion (as opposed to the usual cumsum), preserving the API for the other cumulative reductions.

@Micky774 Micky774 force-pushed the array-api-default-promotion branch from b22b109 to 34c5163 Compare May 6, 2024 15:13
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@copybara-service copybara-service bot merged commit 3d3cb0b into google:main May 6, 2024
14 checks passed
@Micky774 Micky774 deleted the array-api-default-promotion branch May 6, 2024 21:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants