Add propagate_real_tensors mode for unbacked (#125115)
A common complaint when working with data-dependent code in PyTorch is that it's hard to tell how far you are from the finish line: every time a GuardOnDataDependentSymNode error is hit, you have to somehow fix or work around it before you can see the next one.

This PR adds a new mode, `torch._functorch.config.fake_tensor_propagate_real_tensors`, which modifies fake tensors to also propagate real tensors. This means that when we would otherwise guard on a data-dependent SymNode, we can consult the real value and produce a concrete result. We also emit a warning for each such guard, which you should consult to figure out where the crux points are.
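
For reference, here is a minimal sketch of turning the mode on outside the benchmark harness; the flags mirror what the new `--propagate-real-tensors` benchmark option sets below, and the function `f` is purely illustrative, not part of this PR:

```python
import torch
import torch._dynamo
import torch._functorch.config

# Mirror what --propagate-real-tensors sets in benchmarks/dynamo/common.py.
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._functorch.config.fake_tensor_propagate_real_tensors = True

@torch.compile(backend="eager", fullgraph=True)
def f(box):
    # h and w are unbacked SymInts; guards on them would normally raise
    # GuardOnDataDependentSymNode, but with this mode they consult the
    # real values and log a propagate_real_tensors warning instead.
    h, w = box.tolist()
    return torch.zeros(h, w)

f(torch.tensor([3, 4]))
```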

I ran this on vision_maskrcnn. In the baseline (without this mode), the model has 27 graph breaks, resulting in 40 graphs. With this mode on, the model has only 11 graph breaks, resulting in 15 graphs (the remaining graph breaks are due to missing functionality for item() on a float tensor and a few other missing Dynamo features). You get a list of things that would have errored, like this:

```
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> False
```

Potential later follow-ups:

* Improve the warning messages (in particular, they should include user frames)
* GC real tensors once tracing no longer needs them. Right now, this mode uses a LOT of memory: it is as if GC were broken and every intermediate tensor were kept live

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #125115
Approved by: https://github.com/IvanKobzarev
ezyang authored and petrex committed May 3, 2024
1 parent 99a7df1 commit 67b06ad
Showing 11 changed files with 367 additions and 61 deletions.
11 changes: 11 additions & 0 deletions benchmarks/dynamo/common.py
@@ -75,6 +75,7 @@
graph_break_reasons,
maybe_enable_compiled_autograd,
)
import torch._functorch.config
from torch._functorch.aot_autograd import set_model_name
from torch._inductor import config as inductor_config, metrics
from torch._subclasses.fake_tensor import FakeTensorMode
@@ -3155,6 +3156,11 @@ def get_example_inputs(self):
action="store_true",
help="Runs a dynamic shapes version of the benchmark, if available.",
)
parser.add_argument(
"--propagate-real-tensors",
action="store_true",
help="Capture as much data dependent as you can by unsoundly propagating real tensors",
)
parser.add_argument(
"--dynamic-batch-only",
action="store_true",
@@ -3603,6 +3609,11 @@ def run(runner, args, original_dir=None):
if args.dynamic_shapes:
if not args.dynamic_batch_only:
torch._dynamo.config.assume_static_by_default = False
if args.propagate_real_tensors:
# TODO: Separate flag for data dependent
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._functorch.config.fake_tensor_propagate_real_tensors = True
if args.specialize_int:
torch._dynamo.config.specialize_int = True
if args.ci:
17 changes: 17 additions & 0 deletions test/dynamo/test_misc.py
@@ -10525,6 +10525,23 @@ def fn(x, d):
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
fn(torch.randn(4), d)

@unittest.skipIf(not TEST_CUDA, "requires cuda")
@torch._dynamo.config.patch(
capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
)
@torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True)
def test_interpolate_propagate_real_tensors(self):
@torch.compile(backend="eager", fullgraph=True)
def f(mask, box):
# u0, u1 = mask.tolist()
mask = torch.randn(1, 1, 30, 30, device="cuda")
h, w = box.tolist()
return torch.nn.functional.interpolate(
mask, (h, w), mode="bilinear", align_corners=False
)

f(torch.tensor([30, 30], device="cuda"), torch.tensor([68, 32], device="cuda"))

def test_custom_iter_dict(self):
class ReversedDict(dict):
def __iter__(self):
6 changes: 6 additions & 0 deletions test/test_dynamic_shapes.py
@@ -512,6 +512,12 @@ def test_data_dependent_guard(self):
s0 = shape_env.create_unbacked_symint()
self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))

def test_data_dependent_guard_propagate_real_tensors(self):
shape_env = ShapeEnv()
s0 = shape_env.create_unbacked_symint()
shape_env.set_unbacked_var_to_val(s0.node.expr, 0)
self.assertEqual(bool(s0 == 0), True)

def test_expect_true_basic(self):
shape_env = ShapeEnv()
i0 = shape_env.create_unbacked_symint()
106 changes: 82 additions & 24 deletions test/test_fake_tensor.py
@@ -6,6 +6,7 @@
instantiate_parametrized_tests, TemporaryFileName)
import torch
import torch._dynamo
from torch._dynamo.testing import make_test_cls_with_patches
import itertools
import numpy as np
from torch.testing._internal.jit_utils import RUN_CUDA
@@ -53,6 +54,10 @@
torch._dynamo.config.fake_tensor_cache_enabled = True
torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True

def expectedFailurePropagateRealTensors(fn):
fn._expected_failure_propagate_real_tensors = True
return fn

class FakeTensorTest(TestCase):
def checkType(self, t, device_str, size):
self.assertTrue(isinstance(t, FakeTensor))
@@ -83,18 +88,22 @@ def test_basic(self):
def test_custom_op_fallback(self):
from torch.library import Library, impl

test_lib = Library("my_test_op", "DEF") # noqa: TOR901
test_lib.define('foo(Tensor self) -> Tensor')
try:
test_lib = Library("my_test_op", "DEF") # noqa: TOR901
test_lib.define('foo(Tensor self) -> Tensor')

@impl(test_lib, 'foo', 'CPU')
def foo_impl(self):
return self.cos()
@impl(test_lib, 'foo', 'CPU')
def foo_impl(self):
return self.cos()

x = torch.empty(2, 2, device="cpu")
with self.assertRaisesRegex(UnsupportedOperatorException, "my_test_op.foo.default"):
with FakeTensorMode(allow_fallback_kernels=True) as mode:
x = mode.from_tensor(x)
torch.ops.my_test_op.foo(x)
x = torch.empty(2, 2, device="cpu")
with self.assertRaisesRegex(UnsupportedOperatorException, "my_test_op.foo.default"):
with FakeTensorMode(allow_fallback_kernels=True) as mode:
x = mode.from_tensor(x)
torch.ops.my_test_op.foo(x)

finally:
test_lib._destroy()

def test_parameter_instantiation(self):
with FakeTensorMode():
@@ -207,6 +216,8 @@ def test_fake_dispatch_keys(self):
FileCheck().check("CPU").check("AutocastCPU").run(torch._C._dispatch_key_set(y))
FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y))

# TODO: functorch support for propagate real tensors
@expectedFailurePropagateRealTensors
def test_batch_tensor(self):
x = torch.rand((3, 4, 5))
b = _add_batch_dim(x, 0, 0)
@@ -392,10 +403,10 @@ def test_out_multi_device(self):
x = torch.rand([4])
y = torch.rand([4], device="cuda")

with self.assertRaisesRegex(Exception, "found two different devices"):
with self.assertRaisesRegex(Exception, "found.+two.+devices"):
torch.sin(x, out=y)

with self.assertRaisesRegex(Exception, "found two different devices"):
with self.assertRaisesRegex(Exception, "found.+two.+devices"):
x.add_(y)


@@ -559,6 +570,8 @@ def test_tolist(self):
x = torch.rand([10])
x.tolist()

# Propagate real tensors doesn't work with fake-on-fake
@expectedFailurePropagateRealTensors
def test_same_shape_env_preserved(self):
shape_env = ShapeEnv()
mode1 = FakeTensorMode(shape_env=shape_env)
@@ -578,6 +591,9 @@ def test_same_shape_env_preserved(self):
self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env)
self.assertEqual(str(t2.size(0)), str(t1.size(0)))

# TODO: Support NJT. There's also some funny business with dynamic shapes
# which would need to be dealt with as well
@expectedFailurePropagateRealTensors
def test_jagged_fake_to_fake_preserved(self):
from torch.nested._internal.nested_tensor import jagged_from_list

@@ -736,7 +752,9 @@ def test_aten_index_multi_device(self):
x2 = torch.rand(4, 4, device="cuda")
i1 = torch.tensor([0, 1], device="cuda")
i2 = torch.tensor([0, 1], device="cpu")
r1 = torch.ops.aten.index(x1, i1)
# NB: This one does not work: cuda indices not allowed on cpu
# tensor
# r1 = torch.ops.aten.index(x1, i1)
r2 = torch.ops.aten.index(x2, i2)

y1 = torch.rand(4, device="cpu")
@@ -745,7 +763,7 @@
j2 = torch.tensor([2], device="cpu")
r3 = torch.ops.aten.index_put.default(x1, j1, y1)
r4 = torch.ops.aten.index_put.default(x2, j2, y2)
self.checkType(r1, "cpu", ())
# self.checkType(r1, "cpu", ())
self.checkType(r2, "cuda", ())
self.checkType(r3, "cpu", (4, 4))
self.checkType(r4, "cuda", (4, 4))
@@ -774,6 +792,9 @@ def test__adaptive_avg_pool2d_backward(self):
grad_in = torch.ops.aten._adaptive_avg_pool2d_backward(grad_out, inp)
self.assertTrue(torch._prims_common.suggest_memory_format(grad_in) == torch.channels_last)

# Propagate real tensors doesn't work when original input arguments are
# fake
@expectedFailurePropagateRealTensors
def test_export_numpy(self):
class MyNumpyModel(torch.nn.Module):
def forward(self, input):
@@ -801,6 +822,26 @@ def f(x):
self.assertEqual(r.size(), [3])


instantiate_parametrized_tests(FakeTensorTest)


def make_propagate_real_tensors_cls(cls):
cls = make_test_cls_with_patches(
cls,
"PropagateRealTensors",
"_propagate_real_tensors",
(torch._functorch.config, "fake_tensor_propagate_real_tensors", True),
xfail_prop="_expected_failure_propagate_real_tensors",
decorator=skipIfTorchDynamo("propagate_real_tensors affects Dynamo"),
)
cls.__file__ = __file__
cls.__module__ = __name__
globals()[cls.__name__] = cls


make_propagate_real_tensors_cls(FakeTensorTest)


class FakeTensorConstHandling(TestCase):
def assertConst(self, *args):
for arg in args:
@@ -891,6 +932,10 @@ def test_constant_propagate_through_functions(self):
y = torch.div(4, 4, rounding_mode='trunc')
self.assertConst(y)


make_propagate_real_tensors_cls(FakeTensorConstHandling)


def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
return maybe_contained_type.isSubtypeOf(type) or any(
contains_type(e, maybe_contained_type) for e in type.containedTypes()
@@ -907,6 +952,11 @@ def test_fake(self, device, dtype, op):
optests.fake_check(op, args, kwargs)


make_propagate_real_tensors_cls(FakeTensorOpInfoTest)
instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=("cpu", "cuda"))
instantiate_device_type_tests(PropagateRealTensorsFakeTensorOpInfoTest, globals(), only_for=("cpu",)) # noqa: F821


class FakeTensorConverterTest(TestCase):
def test_memoized_conversion_to_meta(self):
x = torch.rand(2, 2, 2)
@@ -1018,16 +1068,17 @@ def test_no_ref_cycle(self):
assert y_weak() is None


make_propagate_real_tensors_cls(FakeTensorConverterTest)


class FakeTensorOperatorInvariants(TestCase):
@staticmethod
def get_aten_op(schema):
def get_aten_op(self, schema):
namespace, name = schema.name.split("::")
overload = schema.overload_name if schema.overload_name else "default"
assert namespace == "aten"
return getattr(getattr(torch.ops.aten, name), overload)

@staticmethod
def get_all_aten_schemas():
def get_all_aten_schemas(self):
for schema in torch._C._jit_get_all_schemas():
namespace = schema.name.split("::")[0]
if namespace != "aten":
@@ -1178,6 +1229,10 @@ def forward(self, arg1, arg2, arg3):

# IMPORTANT!!! Always run even if CUDA is not available
def test_fake_cuda_no_init(self):
# Skip this test, we will try to run CUDA operations to real prop so
# it clearly will not work on CPU runner
if torch._functorch.config.fake_tensor_propagate_real_tensors:
return
with FakeTensorMode():
torch.empty(10, device='cuda')
torch.ones(10, device='cuda')
@@ -1236,6 +1291,9 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
self.assertEqual(mode.count, 0)


make_propagate_real_tensors_cls(FakeTensorOperatorInvariants)


class FakeTensorPropTest(TestCase):
def test_fake_tensor_prop_on_nn_module(self):
class ToyNnModuleWithParameters(torch.nn.Module):
@@ -1294,6 +1352,7 @@ def to_fake_tensor(x):
self.assertTrue(failed)


@expectedFailurePropagateRealTensors # Propagate real tensors doesn't work with fake-on-fake
def test_fake_tensor_prop_on_nn_module_with_optional_args(self):
class OptionalArgumentInBetween(torch.nn.Module):
def __init__(self):
@@ -1321,9 +1380,11 @@ def forward(self, value, another_value=None, another_optional_value=None):
FakeTensorProp(graph_model, fake_mode).propagate(value, None, another_optional_value)


@expectedFailurePropagateRealTensors # TODO: not sure about this one, kinda strange
def test_unbacked_shape_realloc(self):
def f(x):
return x.nonzero()

shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)
with fake_mode:
@@ -1368,6 +1429,9 @@ def forward(self, x):
torch.load(state_dict_file, map_location="cpu") # scenario 2


make_propagate_real_tensors_cls(FakeTensorPropTest)


class FakeTensorSerialization(TestCase):
def test_serialization(self):
x = torch.tensor([0], device="cpu")
@@ -1706,11 +1770,5 @@ def test_inference_mode(self):
extract_tensor_metadata(res4),
)


instantiate_parametrized_tests(FakeTensorTest)

only_for = ("cpu", "cuda")
instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=only_for)

if __name__ == "__main__":
run_tests()
17 changes: 17 additions & 0 deletions test/test_proxy_tensor.py
@@ -26,6 +26,7 @@
from torch.utils._pytree import tree_map
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
from torch import nn
import torch._functorch.config
import re

import functools
@@ -1544,6 +1545,22 @@ def f(a):

make_fx(f, tracing_mode="symbolic")(torch.randn(4))

@torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True)
def test_invalidate_nonzero_propagate_real_tensors(self):
def f(a):
b = a.clone()
x = b.nonzero()
x1 = b.nonzero()
x2 = b.nonzero()
assert x1.shape[0] == x2.shape[0]
b.normal_()
y = b.nonzero()
# Because you're not actually going to generate exactly zero with
# normal_ lol
assert x1.shape[0] == y.shape[0]

make_fx(f, tracing_mode="symbolic")(torch.randn(4))

def test_sqrt_size(self):
def f(a):
return a / a.size(-1) ** 0.5
6 changes: 4 additions & 2 deletions torch/_dynamo/testing.py
@@ -311,7 +311,9 @@ def _fn(*args, **kwargs):
return _fn


def make_test_cls_with_patches(cls, cls_prefix, fn_suffix, *patches, xfail_prop=None):
def make_test_cls_with_patches(
cls, cls_prefix, fn_suffix, *patches, xfail_prop=None, decorator=lambda x: x
):
DummyTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
DummyTestClass.__qualname__ = DummyTestClass.__name__

@@ -326,7 +328,7 @@ def make_test_cls_with_patches(cls, cls_prefix, fn_suffix, *patches, xfail_prop=
new_fn.__name__ = new_name
if xfail_prop is not None and hasattr(fn, xfail_prop):
new_fn = unittest.expectedFailure(new_fn)
setattr(DummyTestClass, new_name, new_fn)
setattr(DummyTestClass, new_name, decorator(new_fn))
# NB: Doesn't handle slots correctly, but whatever
elif not hasattr(DummyTestClass, name):
setattr(DummyTestClass, name, getattr(cls, name))
