Add propagate_real_tensors mode for unbacked #125115
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125115
✅ You can merge normally (1 unrelated failure). As of commit 8bb7f84 with merge base c1a3fcf, the one failing job was FLAKY and likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI.
```python
def go(t, real_t):
    if isinstance(t, FakeTensor):
        # NB: unconditionally overwrite
        t.real_tensor = real_t
```
In the constructor there is an assert that real_tensor is not a FakeTensor.
So isn't this breaking that invariant?
No, real_t is a real tensor :)
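For context, a minimal sketch of the invariant being discussed (the helper name `attach_real_tensor` is illustrative, not the PR's actual code): the `real_tensor` stored on a `FakeTensor` must itself be a real, non-fake tensor, so the unconditional overwrite above stays consistent with the constructor's check.

```python
from torch._subclasses.fake_tensor import FakeTensor

def attach_real_tensor(t, real_t):
    # Illustrative only: mirrors the constructor's invariant that
    # real_tensor is never itself a FakeTensor.
    assert not isinstance(real_t, FakeTensor)
    if isinstance(t, FakeTensor):
        # NB: unconditionally overwrite
        t.real_tensor = real_t
```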
I think it would also be good to have a separate dynamo.config.use_propagate_real_tensors for data-dependent solving.
What I wrote in chat:
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #125115
Approved by: https://github.com/IvanKobzarev
Reviewed By: kit1980
Differential Revision: D56915030
Pulled By: ezyang
Stack from ghstack (oldest at bottom):
A common complaint when working with data-dependent code in PyTorch is that it's hard to tell how far you are from the finish line: every time a GuardOnDataDependentSymNode error is hit, you have to somehow fix or work around it to see the next one.

This PR adds a new mode, `torch._functorch.config.fake_tensor_propagate_real_tensors`, which modifies fake tensors to also propagate real tensors. This means that when we try to guard on a data-dependent SymNode, we can actually produce a real result. We also produce a warning, which you should consult to figure out what the crux points are.

I ran this on vision_maskrcnn. In the baseline (without this mode), the model has 27 graph breaks, resulting in 40 graphs. With this mode on, the model has only 11 graph breaks, resulting in 15 graphs (the remaining graph breaks are due to missing functionality for item() on float tensors and some other missing Dynamo features). You get a list of things that would have errored, like this:

```
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> False
```

Potential later follow ups:

* Improve the warning messages (in particular, they should provide user frames).
* GC real tensors when they are no longer needed by tracing. Right now, this will use A LOT of memory, as if your GC were broken and every intermediate tensor were kept live.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
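For readers who want to try the mode, here is a minimal sketch of how it could be enabled with torch.compile. The toy function and inputs are illustrative, not from the PR's tests, and setting `capture_scalar_outputs` is an assumption made here only so that `.item()` yields an unbacked SymInt rather than graph-breaking.

```python
import torch
import torch._dynamo.config
import torch._functorch.config

# Flag added by this PR: fake tensors also carry and propagate real tensors.
torch._functorch.config.fake_tensor_propagate_real_tensors = True
# Assumed setup so .item() produces an unbacked SymInt instead of a graph break.
torch._dynamo.config.capture_scalar_outputs = True

def f(x):
    n = x.sum().item()   # data-dependent (unbacked) integer
    if n > 0:            # would normally hit GuardOnDataDependentSymNode
        return x + 1
    return x - 1

compiled = torch.compile(f)
print(compiled(torch.ones(4, dtype=torch.int64)))
# Instead of erroring, tracing consults the real value and logs lines like:
# WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(u0 > 0) -> True
```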
cc @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng