[Functorch] vmap over index_select expands the output #115347
Labels: actionable, high priority, module: functorch (pertaining to torch.func or pytorch/functorch), triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
🐛 Describe the bug
When calling `torch.vmap` over `torch.index_select` with a batched index, the result does not match JAX's results:

Expected

PyTorch
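A hedged sketch of the semantics being compared, not the original reproducer: it assumes the self tensor is unbatched, only the index is batched, and each per-example index is 1-D. Under those assumptions, vmap over the index should match calling `index_select` once per batch element and stacking the results. The helper names `loop_reference` and `vmapped_index_select` and the example shapes are hypothetical.

```python
import torch

def loop_reference(x, dim, batched_index):
    # Ground truth: apply index_select once per batch element
    # (each row of batched_index is a 1-D index) and stack along a new dim 0.
    return torch.stack([torch.index_select(x, dim, i) for i in batched_index])

def vmapped_index_select(x, dim, batched_index):
    # Only the index is mapped over (default in_dims=0); x and dim are shared.
    return torch.vmap(lambda i: torch.index_select(x, dim, i))(batched_index)

x = torch.arange(12.0).reshape(3, 4)   # unbatched input, shape (3, 4)
idx = torch.tensor([[0, 2], [1, 1]])   # batch of 2 index vectors, each of length 2

for dim in (0, 1):
    expected = loop_reference(x, dim, idx)
    actual = vmapped_index_select(x, dim, idx)
    # Report the shapes and whether vmap agrees with the loop.
    print(dim, tuple(expected.shape), tuple(actual.shape), torch.equal(expected, actual))
```

The loop is slow but independent of the vmap batching rule, so a shape or value disagreement between the two on a given build reproduces the kind of mismatch described above.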
Versions
PyTorch `2.2.0.dev20231128`
cc @ezyang @gchanan @zou3519 @kadeng @Chillee @samdow @kshitij12345 @janeyx99