
RegionProposalNetwork can't be AOTInductor compiled with dynamic batch size #8285

Open
rbavery opened this issue Mar 3, 2024 · 3 comments

rbavery commented Mar 3, 2024

🐛 Describe the bug

This is a cross-post of pytorch/pytorch#121036.

Just raising it here to notify the maintainers that I'm going to take a crack at fixing the RegionProposalNetwork and potentially other modules to be either traceable, AOTInductor compileable, or both. Are there any current efforts in this direction I should be aware of?

For AOTInductor I think this will at least involve changing the AnchorGenerator so that the method that currently mutates an anchor attribute instead returns the anchor values.

To support tracing, my plan is to address each TracerWarning (see below). First I'll be looking to remove the iteration over tensors in ImageList that prevents the model from generalizing after tracing.
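To illustrate the iteration problem, here is a minimal sketch (the function names are hypothetical, not the torchvision code): `torch.jit.trace` unrolls a Python loop over a tensor, so the number of iterations is frozen to the example batch size. If inputs arrive as a single batched `[N, C, H, W]` tensor with a uniform image size, the sizes can instead be read once from the shape.

```python
import torch

# Trace-unfriendly: iterating over the batch dimension is unrolled at trace
# time, so the iteration count is baked in from the example input.
def image_sizes_unrolled(images: torch.Tensor):
    return [(img.shape[1], img.shape[2]) for img in images]

# Trace-friendlier sketch (an assumption, not the torchvision fix): read H and
# W once from the batched tensor's shape instead of iterating over images.
def image_sizes_batched(images: torch.Tensor):
    n, _, h, w = images.shape
    return [(h, w)] * n
```

Both return the same sizes eagerly; the difference only matters for how the trace is recorded.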

/opt/workspace/./satlas-src/satlas/model/model.py:438: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  image_sizes = [(image.shape[1], image.shape[2]) for image in images]
/opt/conda/lib/python3.10/site-packages/torchvision/ops/boxes.py:166: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
/opt/conda/lib/python3.10/site-packages/torchvision/ops/boxes.py:168: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
/opt/conda/lib/python3.10/site-packages/torch/__init__.py:1560: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert condition, message
/opt/workspace/./satlas-src/satlas/model/model.py:537: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  losses = {'base': torch.tensor(0, device=device, dtype=torch.float32)}
/opt/workspace/./satlas-src/satlas/model/model.py:850: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  losses = torch.tensor(0, device=batch_tensor.device, dtype=torch.float32)
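The `boxes.py` warnings above come from wrapping a value in `torch.tensor(...)` inside the clipping op (during tracing, `width`/`height` can arrive as tensors, which triggers the copy-construct path). One possible rewrite, shown here as a sketch rather than the actual torchvision change, is to clamp against the scalar directly and skip the intermediate tensor; note the scalar is still a trace-time constant, so this silences the warning without making the shape dynamic.

```python
import torch

boxes = torch.tensor([[10.0, 20.0, 500.0, 600.0]])  # x1, y1, x2, y2
width, height = 224, 224

# Warned pattern: re-wrapping width in torch.tensor(...) on every call.
boxes_x_old = torch.min(boxes[:, 0::2], torch.tensor(width, dtype=boxes.dtype))

# Sketch of an alternative: clamp against the value directly, avoiding the
# intermediate tensor construction entirely.
boxes_x = boxes[:, 0::2].clamp(max=width)
boxes_y = boxes[:, 1::2].clamp(max=height)
```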

Versions

I'm using the nightlies, see pytorch/pytorch#121036

@NicolasHug
Member

Thanks for the report @rbavery

Are there any current efforts in this direction I should be aware of?

No, we haven't been looking at the RPN's support for torch.compile yet.

I think this will at least involve changing the AnchorGenerator

Just note that a lot of this code is public, and changing its behaviour, e.g. the expected input/output, would technically break backward compatibility. So adding AOT support while still preserving BC may not be a trivial task.

Beyond the RPN, what model specifically are you interested in tracing?

@rbavery
Author

rbavery commented Mar 4, 2024

Got it. I initially went with supporting TorchScript scripting since it seemed easier and would only require adding type annotations. I've made edits to this model, which uses a Swin Transformer backbone, an FPN, and a Faster R-CNN head:

https://github.com/allenai/satlas/blob/main/configs/satlas_explorer_marine_infrastructure.txt
https://github.com/allenai/satlas/blob/main/satlas/model/model.py

So far I have addressed TorchScript scripting issues with type annotations in the Satlas model source.

The first issue I hit in torchvision is here:

RuntimeError: 
Module 'GeneralizedRCNNTransform' has no attribute 'image_mean' (This attribute exists on the Python module, but we failed to convert Python type: 'list' to a TorchScript type. List trace inputs must have elements. Its type was inferred; try adding a type annotation for the attribute.):
  File "/opt/conda/lib/python3.10/site-packages/torchvision/models/detection/transform.py", line 167
            )
        dtype, device = image.dtype, image.device
        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
                               ~~~~~~~~~~~~~~~ <--- HERE
        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
        return (image - mean[:, None, None]) / std[:, None, None]
'GeneralizedRCNNTransform.normalize' is being compiled since it was called from 'GeneralizedRCNNTransform.forward'
  File "/opt/conda/lib/python3.10/site-packages/torchvision/models/detection/transform.py", line 141
            if image.dim() != 3:
                raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}")
            image = self.normalize(image)
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            image, target_index = self.resize(image, target_index)
            images[i] = image

I've had some trouble addressing this with typing: since the class attribute is already typed, I'm not sure how to make TorchScript understand that this attribute can be either a List[float] or None. I might need to make code modifications. If it's helpful, I'll try to do so in a way that preserves backwards compatibility and keeps the tests passing, and open a PR.
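For what it's worth, the usual pattern for this kind of error is an explicit class-level annotation plus a `None` check that lets TorchScript refine the Optional. A minimal sketch with a hypothetical module (not the torchvision class itself):

```python
from typing import List, Optional

import torch
from torch import nn


class Normalize(nn.Module):
    # Class-level annotation so TorchScript knows the attribute's type even
    # when it is None or an empty list at scripting time.
    image_mean: Optional[List[float]]

    def __init__(self, image_mean: Optional[List[float]] = None):
        super().__init__()
        self.image_mean = image_mean

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        mean = self.image_mean
        if mean is None:
            # TorchScript refines Optional[List[float]] -> List[float] below.
            return image
        m = torch.as_tensor(mean, dtype=image.dtype, device=image.device)
        return image - m[:, None, None]


scripted = torch.jit.script(Normalize([1.0, 1.0, 1.0]))
```

Whether this applies cleanly to GeneralizedRCNNTransform depends on how the attribute is assigned, but the annotation-plus-refinement pattern is what the scripting frontend expects.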

@rbavery
Author

rbavery commented Mar 13, 2024

I was able to get TorchScript scripting to work by refactoring the Satlas source code, mostly by adding typing, removing control flow in some spots, and replacing complex Python data structures containing tensors with plain tensors. Inference on dynamic batches appears to work without error. No changes to torchvision were needed.
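The "plain tensors instead of Python containers" refactor can be sketched like this (hypothetical function names, assuming uniform image sizes): a list of per-image tensors is stacked into one batched tensor, so the scripted graph sees a single typed value rather than a Python-level collection.

```python
from typing import List

import torch


# Container-heavy style: a Python list of per-image tensors flows through
# the model, which TorchScript must type and iterate over.
def forward_list(images: List[torch.Tensor]) -> torch.Tensor:
    return torch.stack([img * 2.0 for img in images])


# Plain-tensor style: one batched [N, C, H, W] tensor in, one tensor out.
def forward_batched(images: torch.Tensor) -> torch.Tensor:
    return images * 2.0


batch = torch.ones(2, 3, 4, 4)
```

Both produce the same result; the second form is easier for scripting and export tooling to handle.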

But AOTInductor unfortunately does not. I made some progress by forking torchvision and trying to remove the use of ImageList and other Python data structures, remove control flow (often making very strong assumptions about the input data), replace Python indexing with torch.narrow, etc. But I still ran into unbacked symint issues when the NMS step is applied in the RPN, and I wasn't sure how to get around them: NMS is inherently data-dependent and can't be made data-independent. If it's helpful, I tried to document how I made progress in pytorch/pytorch#121036.
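For readers unfamiliar with the torch.narrow substitution mentioned above, here is a minimal sketch: `torch.narrow(dim, start, length)` expresses a slice as an explicit op with a fixed length, which export tooling can reason about more easily than Python slicing syntax in some cases.

```python
import torch

x = torch.arange(12).reshape(3, 4)

# Python slicing of the first dimension:
a = x[1:3]

# Equivalent view via torch.narrow(dim, start, length):
b = x.narrow(0, 1, 2)
```

Both return the same two rows as a view of `x`; the narrow form makes the start and length explicit operands of a single op.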

Both export methods were relatively painful. I'm hoping that AOTInductor comes up with a solution for handling data-dependent shapes, or makes it easier to write code that handles them. I realize it is still early days for AOTInductor, but documentation would go a long way. I'd be happy to contribute, but I still feel fairly new to the process of handling data-dependent shapes.

@rbavery rbavery changed the title RegionProposalNetwork can't be traced or AOTInductor compiled with dynamic batch size RegionProposalNetwork can't be AOTInductor compiled with dynamic batch size Mar 13, 2024