Add propagate_real_tensors mode for unbacked (#125115)
Summary:
A common complaint when working with data-dependent code in PyTorch is that it's hard to tell how far you are from the finish line: every time a GuardOnDataDependentSymNode error is hit, you have to somehow fix or work around it before you can see the next one.

This PR adds a new mode, `torch._functorch.config.fake_tensor_propagate_real_tensors`, which modifies fake tensors to also propagate real tensors. This means that when we try to guard on a data-dependent SymNode, we can actually produce a real result. We also emit a warning for each such guard, which you should consult to figure out where the crux points are.
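As a rough sketch of how the mode might be exercised (the toy function and inputs are illustrative, not from this PR; `capture_scalar_outputs` is an existing Dynamo flag needed so `.item()` is traced rather than graph-breaking):

```python
import torch
import torch._dynamo.config
import torch._functorch.config

torch._dynamo.config.capture_scalar_outputs = True  # trace .item() instead of breaking
torch._functorch.config.fake_tensor_propagate_real_tensors = True  # the new mode

@torch.compile
def f(x):
    n = x.sum().item()  # unbacked SymInt: its value is data dependent
    if n > 2:           # would normally raise GuardOnDataDependentSymNode
        return x + 1
    return x - 1

# With the mode on, the guard is (unsoundly) resolved against the real value,
# and a propagate_real_tensors warning like the ones below is logged instead.
print(f(torch.ones(3, dtype=torch.int64)))
```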

I ran this on vision_maskrcnn. In the baseline (without this mode), the model has 27 graph breaks, resulting in 40 graphs. With the mode on, it has only 11 graph breaks, resulting in 15 graphs (the remaining graph breaks are due to missing support for item() on float tensors and a few other missing Dynamo features). You get a list of the guards that would otherwise have errored, like this:

```
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> False
```

Potential later follow-ups:

* Improve the warning messages (in particular, they should include user frames)
* GC real tensors once tracing no longer needs them. Right now this mode uses a LOT of memory: it is as if GC were broken and every intermediate tensor were kept live

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

X-link: pytorch/pytorch#125115
Approved by: https://github.com/IvanKobzarev

Reviewed By: kit1980

Differential Revision: D56915030

Pulled By: ezyang

fbshipit-source-id: f04687107bbd3e35e7bdba45998f75be6388debf
ezyang authored and facebook-github-bot committed May 3, 2024
1 parent 416b8b5 · commit a8879af
Showing 2 changed files with 15 additions and 2 deletions.
userbenchmark/dynamo/dynamobench/_dynamo/testing.py (4 additions, 2 deletions)

```diff
@@ -311,7 +311,9 @@ def _fn(*args, **kwargs):
     return _fn


-def make_test_cls_with_patches(cls, cls_prefix, fn_suffix, *patches, xfail_prop=None):
+def make_test_cls_with_patches(
+    cls, cls_prefix, fn_suffix, *patches, xfail_prop=None, decorator=lambda x: x
+):
     DummyTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
     DummyTestClass.__qualname__ = DummyTestClass.__name__

@@ -326,7 +328,7 @@ def make_test_cls_with_patches(
             new_fn.__name__ = new_name
             if xfail_prop is not None and hasattr(fn, xfail_prop):
                 new_fn = unittest.expectedFailure(new_fn)
-            setattr(DummyTestClass, new_name, new_fn)
+            setattr(DummyTestClass, new_name, decorator(new_fn))
             # NB: Doesn't handle slots correctly, but whatever
         elif not hasattr(DummyTestClass, name):
            setattr(DummyTestClass, name, getattr(cls, name))
```
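For reference, a hypothetical use of the new `decorator` hook (a sketch: the `(module, attr, value)` patch-tuple format is assumed from the `_make_fn_with_patches` helper, and the base test class and logging decorator are illustrative, not part of this commit; the same helper also exists upstream as `torch._dynamo.testing.make_test_cls_with_patches`):

```python
import functools
import unittest

import torch
import torch._functorch.config
from torch._dynamo.testing import make_test_cls_with_patches


class ExistingTests(unittest.TestCase):  # illustrative base class
    def test_sum(self):
        self.assertEqual(torch.ones(2).sum().item(), 2.0)


def log_test(fn):  # illustrative decorator applied to each generated test
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        print(f"running {fn.__name__}")
        return fn(*args, **kwargs)
    return wrapper


# Generates PropagateRealTensorsExistingTests: each test runs with the new
# config patched in and is wrapped by log_test via the decorator hook.
PropagateRealTensorsExistingTests = make_test_cls_with_patches(
    ExistingTests,
    "PropagateRealTensors",      # prefix for the generated class name
    "_propagate_real_tensors",   # suffix appended to each test method name
    (torch._functorch.config, "fake_tensor_propagate_real_tensors", True),
    decorator=log_test,
)
```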
userbenchmark/dynamo/dynamobench/common.py (11 additions, 0 deletions)

```diff
@@ -75,6 +75,7 @@
     graph_break_reasons,
     maybe_enable_compiled_autograd,
 )
+import torch._functorch.config
 from torch._functorch.aot_autograd import set_model_name
 from torch._inductor import config as inductor_config, metrics
 from torch._subclasses.fake_tensor import FakeTensorMode

@@ -3155,6 +3156,11 @@ def get_example_inputs(self):
         action="store_true",
         help="Runs a dynamic shapes version of the benchmark, if available.",
     )
+    parser.add_argument(
+        "--propagate-real-tensors",
+        action="store_true",
+        help="Capture as much data-dependent compute as possible by unsoundly propagating real tensors",
+    )
     parser.add_argument(
         "--dynamic-batch-only",
         action="store_true",

@@ -3603,6 +3609,11 @@ def run(runner, args, original_dir=None):
     if args.dynamic_shapes:
         if not args.dynamic_batch_only:
             torch._dynamo.config.assume_static_by_default = False
+    if args.propagate_real_tensors:
+        # TODO: Separate flag for data dependent
+        torch._dynamo.config.capture_scalar_outputs = True
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+        torch._functorch.config.fake_tensor_propagate_real_tensors = True
     if args.specialize_int:
         torch._dynamo.config.specialize_int = True
     if args.ci:
```
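For scripts that do not go through this harness, the same three flags can presumably be flipped directly; a sketch with an illustrative data-dependent-shape function (the flag names are taken from the diff above, the function and inputs are not from this commit):

```python
import torch
import torch._dynamo.config
import torch._functorch.config

# The trio of flags the harness sets for --propagate-real-tensors.
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._functorch.config.fake_tensor_propagate_real_tensors = True

@torch.compile
def f(x):
    nz = torch.nonzero(x)    # output shape is data dependent (unbacked size)
    return x[nz.flatten()]   # index with the data-dependent result

print(f(torch.tensor([0.0, 1.0, 0.0, 2.0])))
```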
