
Conversion to torchscript or ONNX #24

Open

drewm1980 opened this issue Jan 21, 2021 · 12 comments

@drewm1980

drewm1980 commented Jan 21, 2021

I'm working on optimizing my model inference, trying conversion to torchscript as a first step. When I call torch.jit.script() on my model, I hit:

name = '_weights_ranges', item = {('irrep_0,0', 'regular'): (0, 288)}

    def infer_type(name, item):
        # The forward function from Module is special; never use this annotations; we
        # need to infer type directly using JIT.  I originally wanted to write
        # this test as isinstance(class_annotations[name], Callable) but
        # isinstance on typing things doesn't seem to work: isinstance(list, Callable)
        # is also true!
        if name in class_annotations and class_annotations[name] != torch.nn.Module.__annotations__["forward"]:
            attr_type = torch.jit.annotations.ann_to_type(class_annotations[name], _jit_internal.fake_range())
        elif isinstance(item, torch.jit.Attribute):
            attr_type = torch.jit.annotations.ann_to_type(item.type, _jit_internal.fake_range())
        else:
>           attr_type = torch._C._jit_try_infer_type(item)
E           RuntimeError: Cannot create dict for key type '(str, str)', only int, float, Tensor and string keys are supported

This pytorch code resides here:
https://github.com/pytorch/pytorch/blob/22902b9242853a4ce319e7c5c4a1c94bc00ccb7a/torch/jit/_recursive.py#L126

Torch can't trace through a Dict[Tuple[str,str],_], which is used here:

coefficients = weights[self._weights_ranges[io_pair][0]:self._weights_ranges[io_pair][1]]
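For reference, here is a minimal repro sketch (my own toy module, not e2cnn code) of the limitation: torch.jit.script() cannot infer a type for a module attribute that is a dict keyed by tuples. The attribute name and values just mirror e2cnn's _weights_ranges.

# Toy module, not e2cnn code; with torch 1.6, scripting it raises the error quoted above.
import torch

class TupleKeyedDict(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # mirrors e2cnn's _weights_ranges: {(input repr, output repr): (start, end)}
        self._weights_ranges = {('irrep_0,0', 'regular'): (0, 288)}

    def forward(self, weights: torch.Tensor) -> torch.Tensor:
        start, end = self._weights_ranges[('irrep_0,0', 'regular')]
        return weights[start:end]

try:
    torch.jit.script(TupleKeyedDict())
except RuntimeError as e:
    print(e)  # Cannot create dict for key type '(str, str)', only int, float, Tensor and string keys are supported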

My goal is to get the model to run as fast as possible on NVIDIA hardware, probably using tensorrt. Is there another known-good conversion path?

The above error was with torch 1.6.0, e2cnn v0.1.

@drewm1980
Author

Update: all of the above was with torch.jit.script(), which is actually a sort of transpiler rather than a tracer.

I am progressing with torch.jit.trace().
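A minimal sketch of what I mean (a toy Conv2d standing in for my actual U-Net): trace() records the operations executed on an example input, so it sidesteps the attribute type inference that script() choked on.

import torch

model = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).eval()   # placeholder for the real model
example_input = torch.randn(1, 3, 32, 32)

traced = torch.jit.trace(model, example_input)   # records the ops run on example_input
traced.save("model_traced.pt")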

drewm1980 changed the title from "Conversion to torchscript fails" to "Conversion to torchscript or ONNX" on Jan 22, 2021
@Gabri95
Collaborator

Gabri95 commented Jan 22, 2021

Hi @drewm1980

Unfortunately, I don't have much experience with torchscript and torch.jit, so I am not sure how to solve that problem.

However, I recommend checking the .export() method: https://quva-lab.github.io/e2cnn/api/e2cnn.nn.html
This allows you to convert most equivariant networks into a simpler, pure PyTorch model in a single call!
Then, you should be able to work with the converted model as you would normally in PyTorch.

Unfortunately, not all equivariant modules support this functionality yet, but the most commonly used ones do.
If you specifically need one that is not implemented yet, feel free to open a pull request or let me know, and I will try to implement it.
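For example, something like this (illustrative group and field types, not your model):

import torch
from e2cnn import gspaces, nn as enn

gspace = gspaces.Rot2dOnR2(N=8)
feat_in = enn.FieldType(gspace, [gspace.trivial_repr] * 3)
feat_out = enn.FieldType(gspace, [gspace.regular_repr] * 8)

equivariant = enn.SequentialModule(
    enn.R2Conv(feat_in, feat_out, kernel_size=3),
    enn.ReLU(feat_out),
)

equivariant.eval()            # export is meant for trained models in eval mode
plain = equivariant.export()  # a pure PyTorch module (here, a torch.nn.Sequential)

x = torch.randn(1, 3, 32, 32)
y = plain(x)                  # plain tensors in and out, no GeometricTensor wrapping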

Let me know if this helps!

Gabriele

@drewm1980
Author

Thanks for pointing that out; I'll see if I can adapt my code to use it (I currently inherit from torch.nn.Module). One obstacle might be that my code uses torch.nn.ModuleList. There are no hits for "ModuleList" in the e2cnn repo, so I'm guessing there is no equivalent.

@Gabri95
Collaborator

Gabri95 commented Jan 22, 2021

I see; anyway, if your torch.nn.Module contains some EquivariantModules, you could just manually call the .export() method of the equivariant submodules.
For instance, if you have a torch.nn.ModuleList which contains EquivariantModules, you could replace every module in the list with module.export() after training.
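Something along these lines (the small ModuleList here is just illustrative):

import torch
from e2cnn import gspaces, nn as enn

gspace = gspaces.Rot2dOnR2(N=4)
ft = enn.FieldType(gspace, [gspace.regular_repr] * 2)

blocks = torch.nn.ModuleList([
    enn.R2Conv(ft, ft, kernel_size=3, padding=1),
    enn.ReLU(ft),
])

# after training: swap each EquivariantModule for its exported, pure-PyTorch counterpart
for i, module in enumerate(blocks):
    if isinstance(module, enn.EquivariantModule):
        blocks[i] = module.eval().export()

print([type(m).__name__ for m in blocks])  # expected: ['Conv2d', 'ReLU']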

@drewm1980
Author

My Unet is structured as a module that depends on two other modules. They all call operators in their forward() methods that don't have Module subclass equivalents, and they're also statically typed. To do this "right" it seems like I would need to define non-e2cnn versions of each of those, along with ports of the forward() functions with correct static type signatures. As you suggest, I could try recursively export()'ing all of the owned modules in place. The dynamic type signatures will deviate from the static ones if I do that, but maybe something is possible with generics.

Is there some way to build a DAG of modules in pytorch that I'm missing, or is subclassing (along with the consequences for composability I'm hitting here) really the only way?

Are there any inference time computations or abstractions that you're certain wouldn't just get pruned out by the model optimizer anyway? If it's probably going to boil down to the same inference network, I'll skip .export() for now and just hope for the best with torchscript tracing and the tensorrt compiler.

@Gabri95
Collaborator

Gabri95 commented Jan 28, 2021

Are there any inference time computations or abstractions that you're certain wouldn't just get pruned out by the model optimizer anyway?

After setting .eval() mode, my code should be somewhat optimized so that useless computations are skipped, but the library code still carries a lot of additional structure and data (and many asserts I used to implement a manual form of static typing, e.g. to ensure the tensors passed to a module have the right FieldType) which you may not want to have at deployment.

Unfortunately, I don't know enough about how torchscript works to be able to answer your question properly :(

I would still recommend using the .export() method. That seems the cleanest and safest option to me.
I understand this could be problematic if you are using an equivariant module inside a normal one, since you would have to adapt the .forward() method of the larger model to not wrap tensors in GeometricTensors.

I don't see a simple solution to this for the moment though.

Let me know if you find some nicer solutions

Best,
Gabriele

P.S.: In some future release, I am thinking of relaxing the strongly typed structure of the equivariant modules, such that they can accept both PyTorch tensors and GeometricTensors, so that one doesn't necessarily need to wrap them.
In this case, I would assume the user is passing a tensor which transforms as the module expects, but then I cannot automatically check the equivariance. This approach is probably more "pythonic".
It would also allow you to replace only the equivariant modules in your model with .export(), without the need to adapt the forward() pass of the model.

@drewm1980
Author

I'm back working on getting my model into tensorrt... Currently trying torch.jit.trace() followed by torch.onnx.export().

My current blocker is that ONNX supports einsum starting in opset 12, but tensorrt only supports up to opset 11. My own code isn't calling einsum, so I need to look into how hard it would be to convert the einsum calls in e2cnn into opset 11 operators.
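For context, this is the pipeline I'm attempting (a toy Conv2d stands in for my actual model):

import torch

model = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).eval()   # placeholder for the real network
example_input = torch.randn(1, 3, 32, 32)

traced = torch.jit.trace(model, example_input)
torch.onnx.export(
    traced,
    example_input,
    "model_opset11.onnx",
    opset_version=11,          # pinned to 11 because that's the highest opset tensorrt accepts
    input_names=["input"],
    output_names=["output"],
)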

@drewm1980
Author

Still working on tracing which einsum calls in e2cnn I'm actually hitting, but these seem like some likely candidates:

data = torch.einsum("oi,bihw->bohw", (rho, self.tensor.contiguous())).contiguous()

return torch.einsum('boi...,kb->koi...', self.sampled_basis, weights) #.transpose(1, 2).contiguous()
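If it comes to that, the first einsum only contracts a single index, so it should be expressible with opset-11-friendly ops (MatMul + Reshape). A quick sketch of what I mean (not e2cnn code; change_of_basis is my own name for it):

import torch

def change_of_basis(rho: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # equivalent to torch.einsum("oi,bihw->bohw", rho, x)
    b, i, h, w = x.shape
    o = rho.shape[0]
    y = torch.matmul(rho, x.reshape(b, i, h * w))  # (o, i) @ (b, i, h*w) -> (b, o, h*w)
    return y.reshape(b, o, h, w)

rho = torch.randn(5, 4)
x = torch.randn(2, 4, 8, 8)
assert torch.allclose(change_of_basis(rho, x), torch.einsum("oi,bihw->bohw", rho, x), atol=1e-5)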

@drewm1980
Author

TensorRT's list of supported ONNX operators is here:
https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html

@Gabri95
Collaborator

Gabri95 commented Mar 29, 2021

Hi @drewm1980

Nice to hear from you again!

I think your problem is the einsum in

return torch.einsum('boi...,kb->koi...', self.sampled_basis, weights) #.transpose(1, 2).contiguous()

The one inside GeometricTensor is not usually called inside a neural network.

Anyway, if you use .export(), all usages of einsum (at least inside R2Conv) should disappear, since einsum is only used to compute the filters from the learnable weights, and you don't need to do this anymore at inference time.
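A quick way to see this (illustrative field types): exporting an R2Conv should give you back a standard convolution with the filter already expanded, so nothing einsum-related is left to trace.

import torch
from e2cnn import gspaces, nn as enn

gspace = gspaces.Rot2dOnR2(N=8)
ft_in = enn.FieldType(gspace, [gspace.trivial_repr] * 3)
ft_out = enn.FieldType(gspace, [gspace.regular_repr] * 4)

conv = enn.R2Conv(ft_in, ft_out, kernel_size=5).eval()
exported = conv.export()
print(type(exported))   # expected: torch.nn.Conv2d, with the expanded filter baked in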

Regarding .export(), I am not planning to remove it, sorry if my last message was not clear.

Best,
Gabriele

@drewm1980
Author

drewm1980 commented Mar 29, 2021

I was actually looking at the code for the wrong branch; your comment was fine :) Thanks for the confirmation that I'm probably going in the right direction!

Cheers,
Andrew
