Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add propagate_real_tensors mode for unbacked (#125115)
A common complaint when working with data-dependent code in PyTorch is that it's hard to tell how far you are from the finish line: every time a GuardOnDataDependentSymNode error is hit, you have to somehow fix or workaround it to see the next one. This PR adds a new mode `torch._functorch.config.fake_tensor_propagate_real_tensors` which modifies fake tensors to also propagate real tensors. This means that when we try to guard on a data-dependent SymNode, we can actually produce a real result. We also produce a warning which you should consult to figure out what the crux points are. I ran this on vision_maskrcnn. In the baseline (without this mode), the model has 27 graph breaks, resulting in 40 graphs. With this mode on, the model has only 11 graph breaks, resulting in 15 graphs (the remaining graph breaks are due to missing functionality for item() on float tensor and some other Dynamo missing features.) You get a list of things that would have errored like this: ``` WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> False WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> False WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> True WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> False WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> False ``` Potential later follow ups: * Improve the warning messages (in particular, should provide user frames) * GC real tensors when they are no longer needed by tracing. Right now, this will use A LOT of memory, equal to as if your GC was broken and every intermediate tensor was kept live Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: #125115 Approved by: https://github.com/IvanKobzarev
- Loading branch information
Showing
11 changed files
with
367 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.