sum(keepdims=True) should work in pallas #20921

Open
voznesenskym opened this issue Apr 24, 2024 · 0 comments
Assignees: apaszke
Labels: bug (Something isn't working), pallas (Issues pertaining to Pallas (GPU or TPU))

Comments


voznesenskym commented Apr 24, 2024

Description

I am writing a Pallas kernel where one of the lines,

frobenius_sq_norm = square_norm(w_tl).sum(keepdims=True)

includes a sum() with keepdims. I don't expect sum() to work without keepdims, as that would produce a scalar. However, with keepdims=True the reduction is vec->vec, so it should work. Instead it fails with:

Cannot lower reductions to scalar. Reduce to one element vector instead, using keepdims=True.
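
For context, here's a minimal sketch of the kind of kernel involved (the shapes, block layout, and the square_norm helper are my own placeholders, not the actual code):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def square_norm(x):
    # placeholder for the real helper; an elementwise square is enough to repro
    return x * x

def kernel(w_ref, o_ref):
    w_tl = w_ref[...]
    # Fails to lower inside Pallas even with keepdims=True, because the
    # reduction is first lowered to a 0-d scalar and only then reshaped back.
    frobenius_sq_norm = square_norm(w_tl).sum(keepdims=True)
    o_ref[...] = w_tl / frobenius_sq_norm

w = jnp.ones((8, 128), dtype=jnp.float32)
out = pl.pallas_call(kernel, out_shape=jax.ShapeDtypeStruct(w.shape, w.dtype))(w)
```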

This is because the actual implementation first removes the dimension and then adds it back in.

In reductions.py in jax.numpy, we can find:

# `result` here is already the fully reduced value, so for a full reduction it
# is a 0-d scalar by the time the dimensions are re-inserted
if keepdims:
    result = lax.expand_dims(result, pos_dims)

The real fix would probably be to pass keepdims along to all the leaf locations where we actually invoke the reduction op, and ensure that the op respects it and preserves the vector, instead of reducing to a scalar and repackaging it into a vector.
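
In the meantime, one possible workaround sketch (my own suggestion, untested against the affected Pallas lowerings): reduce one axis at a time with keepdims=True, so that each intermediate reduction result is a one-element vector rather than a 0-d scalar.

```python
import jax.numpy as jnp

def sum_all_keepdims_2d(x):
    # Assumes x is 2D; each per-axis sum leaves at least one dimension alive,
    # so the underlying reduce never produces a 0-d scalar.
    s = x.sum(axis=1, keepdims=True)  # (M, N) -> (M, 1)
    s = s.sum(axis=0, keepdims=True)  # (M, 1) -> (1, 1)
    return s
```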

voznesenskym added the bug label on Apr 24, 2024
sharadmv added the pallas label on Apr 24, 2024
apaszke self-assigned this on Apr 25, 2024