Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add propagate_real_tensors mode for unbacked #125115

Closed
wants to merge 12 commits into from

Conversation

ezyang
Copy link
Contributor

@ezyang ezyang commented Apr 28, 2024

Stack from ghstack (oldest at bottom):

A common complaint when working with data-dependent code in PyTorch is that it's hard to tell how far you are from the finish line: every time a GuardOnDataDependentSymNode error is hit, you have to somehow fix or workaround it to see the next one.

This PR adds a new mode torch._functorch.config.fake_tensor_propagate_real_tensors which modifies fake tensors to also propagate real tensors. This means that when we try to guard on a data-dependent SymNode, we can actually produce a real result. We also produce a warning which you should consult to figure out what the crux points are.

I ran this on vision_maskrcnn. In the baseline (without this mode), the model has 27 graph breaks, resulting in 40 graphs. With this mode on, the model has only 11 graph breaks, resulting in 15 graphs (the remaining graph breaks are due to missing functionality for item() on float tensor and some other Dynamo missing features.) You get a list of things that would have errored like this:

WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True                                                
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True                                             
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True                                             
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False                                            
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True                                                
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True                                             
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True                                             
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False                                            
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True                                                
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True                                             
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True                                             
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> False

Potential later follow ups:

  • Improve the warning messages (in particular, should provide user frames)
  • GC real tensors when they are no longer needed by tracing. Right now, this will use A LOT of memory, equal to as if your GC was broken and every intermediate tensor was kept live

Signed-off-by: Edward Z. Yang ezyang@meta.com

cc @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng

[ghstack-poisoned]
Copy link

pytorch-bot bot commented Apr 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125115

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 8bb7f84 with merge base c1a3fcf (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ezyang added a commit that referenced this pull request Apr 28, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 5ac57facea69da09303b9708e0ed85ef567b5df5
Pull Request resolved: #125115
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Apr 29, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 8bb80b5522e80b619bdf436bd3087e8cc8514203
Pull Request resolved: #125115
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Apr 29, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 681ec87050dcac63dcdd001347f6bf2d9a55965e
Pull Request resolved: #125115
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Apr 29, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: c3a8e2925ff4f6f3b72de884b5dc472d342e1b46
Pull Request resolved: #125115
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Apr 29, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 5309d4133d565d2a0dde661c195ba52373d1e9d2
Pull Request resolved: #125115
@albanD albanD removed their request for review April 29, 2024 15:33
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Apr 29, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 8e5355ff3d0f8f2807de3158fab14318ab8435db
Pull Request resolved: #125115
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Apr 29, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 1f97e037a55711b1db7ce2f805d6b24be3ef8f5c
Pull Request resolved: #125115
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Apr 29, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 2eeb18f9e64aaa13b1178678907d1536290a10dc
Pull Request resolved: #125115
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Apr 30, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 286bdada4b86fe1273015a0e16ab4c630e2a07fd
Pull Request resolved: #125115
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Apr 30, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 614a79b06a398a71666725dc3101b4331f081cab
Pull Request resolved: #125115
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Apr 30, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 91e4a445db2aeb871b014b541aac6f322691370c
Pull Request resolved: #125115
[ghstack-poisoned]
ezyang added a commit that referenced this pull request May 1, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: d90e89f0eabad52bb912c2a511b0e854e7809b9c
Pull Request resolved: #125115
@ezyang ezyang added the ciflow/trunk Trigger trunk jobs on your pull request label May 1, 2024
def go(t, real_t):
if isinstance(t, FakeTensor):
# NB: unconditionally overwrite
t.real_tensor = real_t
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In constructor there is assert that real_tensor is not a FakeTensor.
So here it is breaking this invariant?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, real_t is a real tensor :)

@IvanKobzarev
Copy link
Contributor

I think will be also good to have a separate dynamo.config.use_propagate_real_tensors for data dep solving.
To separate propagation and data dep solving to be able to use real_tensor_propagation for other functionality independently.

@ezyang
Copy link
Contributor Author

ezyang commented May 2, 2024

What I wrote in chat:

The current design for real tensor prop doesn't really work for anything else.

I know you were thinking of real tensor prop as a solution for other problems you have, but I don't really understand how these can work

@ezyang
Copy link
Contributor Author

ezyang commented May 2, 2024

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorch-bot bot pushed a commit that referenced this pull request May 3, 2024
A common complaint when working with data-dependent code in PyTorch is that it's hard to tell how far you are from the finish line: every time a GuardOnDataDependentSymNode error is hit, you have to somehow fix or workaround it to see the next one.

This PR adds a new mode `torch._functorch.config.fake_tensor_propagate_real_tensors` which modifies fake tensors to also propagate real tensors. This means that when we try to guard on a data-dependent SymNode, we can actually produce a real result. We also produce a warning which you should consult to figure out what the crux points are.

I ran this on vision_maskrcnn. In the baseline (without this mode), the model has 27 graph breaks, resulting in 40 graphs. With this mode on, the model has only 11 graph breaks, resulting in 15 graphs (the remaining graph breaks are due to missing functionality for item() on float tensor and some other Dynamo missing features.) You get a list of things that would have errored like this:

```
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> False
```

Potential later follow ups:

* Improve the warning messages (in particular, should provide user frames)
* GC real tensors when they are no longer needed by tracing. Right now, this will use A LOT of memory, equal to as if your GC was broken and every intermediate tensor was kept live

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #125115
Approved by: https://github.com/IvanKobzarev
facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request May 3, 2024
Summary:
A common complaint when working with data-dependent code in PyTorch is that it's hard to tell how far you are from the finish line: every time a GuardOnDataDependentSymNode error is hit, you have to somehow fix or workaround it to see the next one.

This PR adds a new mode `torch._functorch.config.fake_tensor_propagate_real_tensors` which modifies fake tensors to also propagate real tensors. This means that when we try to guard on a data-dependent SymNode, we can actually produce a real result. We also produce a warning which you should consult to figure out what the crux points are.

I ran this on vision_maskrcnn. In the baseline (without this mode), the model has 27 graph breaks, resulting in 40 graphs. With this mode on, the model has only 11 graph breaks, resulting in 15 graphs (the remaining graph breaks are due to missing functionality for item() on float tensor and some other Dynamo missing features.) You get a list of things that would have errored like this:

```
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> False
```

Potential later follow ups:

* Improve the warning messages (in particular, should provide user frames)
* GC real tensors when they are no longer needed by tracing. Right now, this will use A LOT of memory, equal to as if your GC was broken and every intermediate tensor was kept live

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

X-link: pytorch/pytorch#125115
Approved by: https://github.com/IvanKobzarev

Reviewed By: kit1980

Differential Revision: D56915030

Pulled By: ezyang

fbshipit-source-id: f04687107bbd3e35e7bdba45998f75be6388debf
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants