I've been converting a model from PyTorch to tinygrad, and ran into an issue with einsums (defined in Tensor.py).
The einsum "...id, ...jd -> ...ij" executes fine in PyTorch but throws an error (below) when executed in tinygrad on the exact same input data. I've included a minimal example below the error for debugging.
```python
import tinygrad
import torch
import numpy as np

# Define the input tensors
inp1_np = np.random.randn(16, 29, 256).astype(np.float32)
inp2_np = np.random.randn(16, 29, 256).astype(np.float32)
# Convert the input tensors to torch tensors
inp1 = torch.tensor(inp1_np)
inp2 = torch.tensor(inp2_np)
out = torch.einsum("...id, ...jd -> ...ij", inp1, inp2)
# Evaluate the output tensor
out_np = out.numpy()
# Convert the input tensors to tinygrad tensors
inp1_tg = tinygrad.Tensor(inp1_np)
inp2_tg = tinygrad.Tensor(inp2_np)
# Perform the einsum operation using tinygrad - will throw an error!
out_tg = tinygrad.Tensor.einsum("...id, ...jd -> ...ij", inp1_tg, inp2_tg)
# Evaluate the output tensor
out_tg_np = out_tg.numpy()
```
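To pin down what the correct output should be, here is a minimal cross-check that uses NumPy as the reference, assuming NumPy's `...` semantics are the intended ones (they agree with PyTorch for this formula). It shows that the einsum is a batched `a @ b^T` over the last two axes, with `...` covering the leading batch dimension.

```python
import numpy as np

# Same shapes as in the report: (batch=16, i or j=29, d=256)
rng = np.random.default_rng(0)
a = rng.standard_normal((16, 29, 256), dtype=np.float32)
b = rng.standard_normal((16, 29, 256), dtype=np.float32)

# NumPy's einsum reads "..." as "whatever leading dims remain" (here: the 16)
ref = np.einsum("...id,...jd->...ij", a, b)

# The same contraction spelled out: a batched a @ b^T over the last two axes
alt = a @ np.swapaxes(b, -1, -2)

print(ref.shape)  # (16, 29, 29)
print(np.allclose(ref, alt, rtol=1e-4, atol=1e-3))
```

If that reading is right, one untested workaround in tinygrad would be to skip einsum for this case and write `inp1_tg @ inp2_tg.transpose(-1, -2)` directly.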
I looked at the locals when the error appeared, and it looks like each `.` in `...` is being interpreted as a separate index, rather than `...` being treated as a single ellipsis. I might be wrong, but I think `...` is supposed to stand for an arbitrary number of (broadcast) dimensions, not for specific indices.
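One way to illustrate the expected ellipsis behaviour is to expand `...` into explicit index letters before the formula reaches einsum. The sketch below uses a hypothetical helper, `expand_ellipsis` (not part of tinygrad), and assumes NumPy-style rules: the `...` dimensions are right-aligned across operands and only explicit `->` formulas are handled.

```python
import string


def expand_ellipsis(formula: str, *ranks: int) -> str:
    """Rewrite an einsum formula so every '...' becomes explicit index letters.

    Hypothetical helper (not part of tinygrad). `ranks` holds the number of
    dimensions of each operand; the '...' dims are right-aligned, as in NumPy
    broadcasting. Only explicit formulas (containing '->') are supported.
    """
    formula = formula.replace(" ", "")
    if "->" not in formula:
        raise ValueError("only explicit '->' formulas are supported")
    lhs, rhs = formula.split("->")
    inputs = lhs.split(",")
    if len(inputs) != len(ranks):
        raise ValueError("need one rank per operand")

    # Pick index letters that the formula does not already use
    used = set(formula) - set(".,->")
    free = [c for c in string.ascii_letters if c not in used]

    # How many dimensions each operand's '...' has to cover
    n_ell = []
    for spec, rank in zip(inputs, ranks):
        n = rank - len(spec.replace("...", "")) if "..." in spec else 0
        if n < 0:
            raise ValueError(f"operand {spec!r} has too few dimensions")
        n_ell.append(n)

    width = max(n_ell, default=0)
    ell = "".join(free[:width])
    # Right-align: an operand with fewer ellipsis dims takes the last letters
    new_inputs = [spec.replace("...", ell[width - n:]) for spec, n in zip(inputs, n_ell)]
    return ",".join(new_inputs) + "->" + rhs.replace("...", ell)


print(expand_ellipsis("...id, ...jd -> ...ij", 3, 3))  # aid,ajd->aij
```

For the 3-D inputs in the report this yields `aid,ajd->aij`, which could in principle be passed as `tinygrad.Tensor.einsum(expand_ellipsis(formula, len(inp1_tg.shape), len(inp2_tg.shape)), inp1_tg, inp2_tg)`; whether tinygrad's einsum accepts that form has not been verified here.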
Let me know if you need any more information!