New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make take_along_axis an optional callback #1440
Make take_along_axis an optional callback #1440
Conversation
I'm still yet to implement the default implementation. Hence the failing tests. |
^axis -> Value.reshape(indices, new_axis_typespec) | ||
axis -> Value.iota(state.builder, axis, new_axis_typespec) | ||
end) | ||
|> Value.concatenate(indices_rank, full_indices_typespec) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW, this is the main transformation you need to make in order to convert take_along_axis into a gather. :) If you do this transformation and call gather
without the axes option, it should be equivalent to the work done here.
ba1a17c
to
3da08f3
Compare
@josevalim, I've committed that transformation. I'm having some trouble making the optional callback work in |
0863a3c
to
3da08f3
Compare
@Benjamin-Philip you can also remove the implementations from |
Fixed and pushed manually, thank you! |
Makes take_along_axis an optional callback, using the default
implementation for BinaryBackend and EXLA, and a custom one for
Torchx.
Closes #1366.