Skip to content

Commit

Permalink
Delete refactored function, move changes over (#126407)
Browse files Browse the repository at this point in the history
Oops, in #125610 I moved this function to runtime_wrappers.py, but forgot to delete the old one. #126234 then modified it which would do nothing, so I'm applying the change correctly now and deleting the function as I intended.

Pull Request resolved: #126407
Approved by: https://github.com/eellison
  • Loading branch information
jamesjwu authored and pytorchmergebot committed May 17, 2024
1 parent ab307a8 commit 078e530
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 29 deletions.
20 changes: 0 additions & 20 deletions torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,26 +83,6 @@ def _force_contiguous(x):
return x


def _compute_output_meta_with_inductor_strides(fw_module, fwd_output_strides):
out = [n.meta["val"] for n in (list(fw_module.graph.nodes)[-1].args[0])]
# will only be set for inductor
if not fwd_output_strides:
return out

from torch.fx.experimental.symbolic_shapes import statically_known_true

for i in range(len(out)):
if not isinstance(out[i], Tensor):
continue
if all(
statically_known_true(s1 == s2)
for s1, s2 in zip(out[i].stride(), fwd_output_strides[i])
):
continue
out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i])
return out


# See Note [Tangents must be contiguous, Part 2]
def coerce_runtime_tangent(x, metadata_tensor):
if not isinstance(x, torch.Tensor):
Expand Down
21 changes: 12 additions & 9 deletions torch/_functorch/_aot_autograd/runtime_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,15 +486,18 @@ def _compute_output_meta_with_inductor_strides(self):
fwd_output_strides = self.fwd_output_strides
if not fwd_output_strides:
return out
with TracingContext.get().fake_mode.shape_env.suppress_guards():
for i in range(len(out)):
if not isinstance(out[i], Tensor):
continue
if all(
s1 == s2 for s1, s2 in zip(out[i].stride(), fwd_output_strides[i])
):
continue
out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i])

from torch.fx.experimental.symbolic_shapes import statically_known_true

for i in range(len(out)):
if not isinstance(out[i], Tensor):
continue
if all(
statically_known_true(s1 == s2)
for s1, s2 in zip(out[i].stride(), fwd_output_strides[i])
):
continue
out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i])
return out

# To be called post compile
Expand Down

0 comments on commit 078e530

Please sign in to comment.