I've been experimenting with Nx this past week, and since I'm working on an M1 Mac, I've been switching between the `:cpu` and `:mps` devices in Torchx to compare their relative performance.
I noticed that when I use some of the indexed functions, such as `Nx.indexed_add` and `Nx.indexed_put`, with the `:mps` device, I get the following error:
```
** (RuntimeError) Torchx: All inputs of where should have same/compatible number of dims in NIF.where/3
    (torchx 0.5.1) lib/torchx.ex:445: Torchx.unwrap!/1
    (torchx 0.5.1) lib/torchx.ex:448: Torchx.unwrap_tensor!/2
    (torchx 0.5.1) lib/torchx/backend.ex:1403: Torchx.Backend.select/4
    (torchx 0.5.1) lib/torchx/backend.ex:536: Torchx.Backend.as_torchx_linear_indices/2
    (torchx 0.5.1) lib/torchx/backend.ex:504: Torchx.Backend.indexed/5
```
Note: This works fine when using the `:cpu` device; only the `:mps` device seems to cause these issues.