
"PyTorch convert function for op '_convolution_mode' not implemented" #1363

Closed
lauriebyrum opened this issue Dec 7, 2021 · 3 comments 路 Fixed by #1365
Comments

@lauriebyrum

🌱 Describe your Feature Request

PyTorch has added 'same' padding for convolutions, but if you try to convert such a model to Core ML, you get the error "PyTorch convert function for op '_convolution_mode' not implemented".

There is a workaround discussed in recent comments on pytorch/pytorch#3867, but it would be great if we didn't have to change our model code for this, especially since the TensorFlow → Core ML path already supports this padding.
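The workaround amounts to replacing `padding="same"` with explicit integer padding before tracing, so the traced graph contains the ordinary convolution op instead of `_convolution_mode`. A minimal sketch of the arithmetic (an assumption on my part, covering only stride 1, which PyTorch requires for `padding='same'` anyway): by the standard convolution output-size formula, a padding of `kernel_size // 2` preserves the spatial size for odd kernel sizes.

```python
def conv_out_size(in_size, kernel_size, stride=1, padding=0, dilation=1):
    # Standard convolution output-size formula.
    return (in_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# With stride 1 and an odd kernel, padding = kernel_size // 2 keeps the
# spatial size unchanged, i.e. it reproduces 'same' padding.
for k in (1, 3, 5, 7):
    assert conv_out_size(16, k, padding=k // 2) == 16
```

So, under those assumptions, `nn.Conv2d(1, 1, 3, padding="same")` can be rewritten as `nn.Conv2d(1, 1, 3, padding=1)` before tracing; even kernel sizes would need asymmetric padding and are not covered by this rewrite.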

@lauriebyrum (Author)

A little more of the story: I had a model that I trained in TensorFlow over a year ago. At that time, I used tfcoreml to convert it and got a lovely, fast-running model. Recently, I tried to convert a new variant of the model to Core ML and found it to be 2x as slow, because coremltools isn't handling the conversion to HWC layout. Since I had already mostly moved my model to PyTorch, I have been trying to get the year-old architecture to convert from PyTorch to Core ML and run as fast as the tfcoreml-produced model did. I have had to work around this '_convolution_mode' issue just to convert at all. Today I used a giant sledgehammer and forced the PyTorch → Core ML conversion to respect 'same' padding instead of working around it, and finally my model runs fast. Without the sledgehammer, it runs 100 ms slower!

Long story short: there are workarounds for this issue, but they come at a very high cost. Please implement this!

@lauriebyrum (Author)

We tried adding "_convolution_mode" to the alias list here: https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/ops.py#L561. Then it complained about having strings as input; I think it was expecting integer padding values:

ValueError: Op "input.305" (op_type: conv) Input pad="input.305_pad_0" expects integer tensor but got tensor[2,str]
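That error suggests the alias approach would additionally need a shim that translates the string padding mode carried by '_convolution_mode' into the integer pad amounts the existing conv op expects. A hypothetical sketch of that translation (illustrative names only, not the actual coremltools API, and covering only the symmetric stride-1 case; even kernels need asymmetric padding):

```python
def resolve_padding(mode, kernel_sizes, dilations=None):
    # Map a string padding mode to per-dimension integer pad amounts.
    dilations = dilations or [1] * len(kernel_sizes)
    if mode == "valid":
        return [0] * len(kernel_sizes)
    if mode == "same":
        # Half the total receptive-field overhang per dim; assumes stride 1,
        # which PyTorch itself requires for padding='same'.
        return [d * (k - 1) // 2 for k, d in zip(kernel_sizes, dilations)]
    raise ValueError(f"unknown padding mode: {mode!r}")
```

For example, `resolve_padding("same", [3, 3])` yields `[1, 1]`, which is the integer form the conv op's pad input would accept.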

@TobyRoseman added the "missing layer type" (Unable to convert a layer type from the relevant framework) and "PyTorch (traced)" labels on Dec 9, 2021
@TobyRoseman (Collaborator)

I can reproduce this problem with the following code:

```python
import coremltools as ct
import torch
from torch import nn

model = nn.Conv2d(1, 1, 3, padding="same")
x = torch.randn(1, 1, 3, 3)
traced_model = torch.jit.trace(model.eval(), x)
ct_model = ct.convert(traced_model, inputs=[ct.TensorType(shape=x.shape)])
```

@lauriebyrum - thanks for reporting this issue. This does seem like a high-priority issue. I'll look into getting at least some level of support for PyTorch convolutions with 'same' padding.

@TobyRoseman self-assigned this on Dec 9, 2021
@TobyRoseman added the "triaged" (Reviewed and examined, release has been assigned if applicable) label on Dec 9, 2021