[torch.export] RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! #126674
Labels
module: export
oncall: pt2
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Comments
Can you please share the error message and what was the issue when you tried applying the suggestion in the comment?
**xmfan** added and then removed the `needs reproduction` label (Someone else needs to try reproducing the issue given the instructions. No action needed from user) — May 20, 2024
Sure. Thanks for looking at this.
The error message when I tried to apply the suggestion is the same.
This is the code I used to try the suggestion:

```python
import torch

ep = torch.export.load('retina.pt2')
gm = ep.module()
gm(torch.rand(1, 3, 800, 1216))  # success

# Retarget any device kwargs in the graph to cuda
for node in ep.graph.nodes:
    if "device" in node.kwargs:
        kwargs = node.kwargs.copy()
        kwargs["device"] = "cuda"
        node.kwargs = kwargs

# Move state dict tensors to cuda
for k, v in ep.state_dict.items():
    if isinstance(v, torch.nn.Parameter):
        ep._state_dict[k] = torch.nn.Parameter(v.cuda())
    else:
        ep._state_dict[k] = v.cuda()

gm = ep.module()
gm(torch.rand(1, 3, 800, 1216).cuda())  # failed
```

Also note that the link to this model is provided at the above issue.
**mlazos** added the `triaged` label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) — May 22, 2024
🐛 Describe the bug
The exported model fails to run inference on CUDA.
retina.pt2
This may be related to #121761, but the solution provided in this comment doesn't work.
@angelayi Could you have a look at this? Thank you.
(Sorry for not providing the export code for the model right now; it's a bit complicated.)
Versions
cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang