
[dynamo] Handle inplace op aliasing errors #126474

Open

williamwen42 opened this issue May 16, 2024 · 1 comment
Labels
module: correctness (silent), module: dynamo, oncall: pt2, triaged

Comments

williamwen42 (Member) commented May 16, 2024

Discovered with #126341.

Original repro command: PYTORCH_TEST_WITH_DYNAMO=1 pytest test/test_torch.py::TestTorchDeviceTypeCPU::test_ternary_op_mem_overlap_cpu_float64.

test_binary_ufuncs.py::TestBinaryUfuncsCPU::test_binary_op_mem_overlap_cpu_float64 and test_unary_ufuncs.py::TestUnaryUfuncsCPU::test_unary_out_op_mem_overlap_cpu_float64 also fail for similar reasons.

Repro:

import torch

def f(x, y):
    try:
        torch.addcmul(x[:-1], y, y, out=x[1:])
    except:
        x = x + 1
    return x

opt_f = torch.compile(f, backend="eager")
inp_x = torch.randn(5)
inp_y = torch.randn(4)
print(f(inp_x, inp_y))  # prints something
print(opt_f(inp_x, inp_y))  # errors!

Output:

$ TORCH_LOGS="graph_code" python playground3.py 
tensor([ 2.5932,  1.2259,  1.6720, -0.5317,  0.1922])
V0516 16:11:50.134000 140543640467264 torch/_dynamo/output_graph.py:1277] [0/0] [__graph_code] TRACED GRAPH
V0516 16:11:50.134000 140543640467264 torch/_dynamo/output_graph.py:1277] [0/0] [__graph_code]  ===== __compiled_fn_1 =====
V0516 16:11:50.134000 140543640467264 torch/_dynamo/output_graph.py:1277] [0/0] [__graph_code]  /data/users/williamwen/pytorch2/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0516 16:11:50.134000 140543640467264 torch/_dynamo/output_graph.py:1277] [0/0] [__graph_code]     def forward(self, L_x_: "f32[5]", L_y_: "f32[4]"):
V0516 16:11:50.134000 140543640467264 torch/_dynamo/output_graph.py:1277] [0/0] [__graph_code]         l_x_ = L_x_
V0516 16:11:50.134000 140543640467264 torch/_dynamo/output_graph.py:1277] [0/0] [__graph_code]         l_y_ = L_y_
V0516 16:11:50.134000 140543640467264 torch/_dynamo/output_graph.py:1277] [0/0] [__graph_code]         
V0516 16:11:50.134000 140543640467264 torch/_dynamo/output_graph.py:1277] [0/0] [__graph_code]         # File: /data/users/williamwen/pytorch2/playground3.py:5 in f, code: torch.addcmul(x[:-1], y, y, out=x[1:])
V0516 16:11:50.134000 140543640467264 torch/_dynamo/output_graph.py:1277] [0/0] [__graph_code]         getitem: "f32[4]" = l_x_[slice(None, -1, None)]
V0516 16:11:50.134000 140543640467264 torch/_dynamo/output_graph.py:1277] [0/0] [__graph_code]         getitem_1: "f32[4]" = l_x_[slice(1, None, None)]
V0516 16:11:50.134000 140543640467264 torch/_dynamo/output_graph.py:1277] [0/0] [__graph_code]         addcmul: "f32[4]" = torch.addcmul(getitem, l_y_, l_y_, out = getitem_1);  getitem = l_y_ = getitem_1 = None
V0516 16:11:50.134000 140543640467264 torch/_dynamo/output_graph.py:1277] [0/0] [__graph_code]         return (l_x_,)
V0516 16:11:50.134000 140543640467264 torch/_dynamo/output_graph.py:1277] [0/0] [__graph_code]         
V0516 16:11:50.134000 140543640467264 torch/_dynamo/output_graph.py:1277] [0/0] [__graph_code] 
Traceback (most recent call last):
  File "/data/users/williamwen/pytorch2/playground3.py", line 14, in <module>
    print(opt_f(inp_x, inp_y))  # errors!
  File "/data/users/williamwen/pytorch2/torch/_dynamo/eval_frame.py", line 414, in _fn
    return fn(*args, **kwargs)
  File "/data/users/williamwen/pytorch2/playground3.py", line 3, in f
    def f(x, y):
  File "/data/users/williamwen/pytorch2/torch/_dynamo/eval_frame.py", line 548, in _fn
    return fn(*args, **kwargs)
  File "<eval_with_key>.1", line 9, in forward
RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation.

Ideally, this error should be caught during tracing. Also, the x = x + 1 line in the except branch is skipped by dynamo: the traced graph never raises, so the overlap error only surfaces when the compiled graph runs (see the sketch below).
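
For illustration only (this is not the actual dynamo fix, and the helper below is hypothetical): a minimal sketch of the kind of partial-overlap check that could flag this at trace time, assuming contiguous 1-D views of a single storage. Eager's own memory-overlap check, which produces the RuntimeError above, handles arbitrary strided tensors.

import torch

# Hypothetical helper: do two views of the same storage partially overlap?
# Assumes both are contiguous 1-D slices, unlike the general eager check.
def views_partially_overlap(a, b):
    if a.untyped_storage().data_ptr() != b.untyped_storage().data_ptr():
        return False  # different storages never overlap
    a_start, a_end = a.storage_offset(), a.storage_offset() + a.numel()
    b_start, b_end = b.storage_offset(), b.storage_offset() + b.numel()
    full_alias = (a_start, a_end) == (b_start, b_end)
    return not full_alias and a_start < b_end and b_start < a_end

x = torch.randn(5)
print(views_partially_overlap(x[:-1], x[1:]))  # True -> should error during tracing
print(views_partially_overlap(x[:2], x[3:]))   # False -> disjoint views are fine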

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng

williamwen42 added the triaged, oncall: pt2, and module: dynamo labels on May 16, 2024, and the module: correctness (silent) label on May 28, 2024
williamwen42 (Member, Author) commented:

If we run with the inductor backend, e.g.

import torch

def f(x, y):
    try:
        torch.addcmul(x[:-1], y, y, out=x[1:])
    except:
        x = x + 1
    return x

opt_f = torch.compile(f, backend="inductor")
inp_x = torch.randn(5).cuda()
inp_y = torch.randn(4).cuda()
print(f(inp_x.clone(), inp_y.clone()))
print(f(inp_x.clone(), inp_y.clone()))
print(opt_f(inp_x.clone(), inp_y.clone()))

we actually do not error out, but the compiled result is incorrect. Eager hits the overlap error and takes the except branch (returning x + 1), while the compiled function silently performs the aliased addcmul write into x[1:]:

tensor([1.4259, 2.6361, 0.4578, 2.3654, 1.9228], device='cuda:0')
tensor([1.4259, 2.6361, 0.4578, 2.3654, 1.9228], device='cuda:0')
tensor([ 0.4259,  0.4260,  2.2543, -0.5243,  4.2885], device='cuda:0')
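
Not part of the report, but for reference: if the intended semantics are the non-overlapping ones suggested by eager's error message, cloning the aliasing input avoids the issue. A quick sketch (f_no_alias is a made-up name; expected, though not verified here, to give matching eager and compiled results):

import torch

def f_no_alias(x, y):
    # clone() breaks the aliasing between the read view and the out= view
    torch.addcmul(x[:-1].clone(), y, y, out=x[1:])
    return x

inp_x = torch.randn(5)
inp_y = torch.randn(4)
print(f_no_alias(inp_x.clone(), inp_y))
print(torch.compile(f_no_alias, backend="eager")(inp_x.clone(), inp_y))  # expected to match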
