NCNN: Unsupported unsqueeze axes ! #118

Open
MHGL opened this issue Jul 8, 2021 · 1 comment

Comments

Contributor

MHGL commented Jul 8, 2021

🐛 Bug

I get this error while converting the YOLOv5 Focus module to ncnn.

To Reproduce

Steps to reproduce the behavior:

  1. code example
import torch

# init module
class Focus(torch.nn.Module):
    def __init__(self):
        super(Focus, self).__init__()
        ...

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        x = torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)
        return x

torch_model = Focus().eval()

# torch.onnx.export
torch.onnx.export(torch_model,
        torch.randn(1, 3, 224, 224),
        "./tmp.onnx",
        input_names=["inputs"],
        output_names=["outputs"],
        dynamic_axes={"inputs": {0: "batch", 2: "height", 3: "width"}, "outputs": {0: "batch", 1: "class", 2: "height", 3: "width"}},
        opset_version=11,
        export_params=True)

# onnx simplify
import os
import onnx
from onnxsim import simplify

onnx_file = os.path.join(os.getcwd(), "tmp.onnx")
model_op, check_ok = simplify(onnx_file, 
        check_n=3, 
        perform_optimization=True, 
        skip_fuse_bn=True,  
        skip_shape_inference=False, 
        input_shapes={"inputs": (1, 3, 224, 224)}, 
        skipped_optimizers=None, 
        dynamic_input_shape={"inputs": {0: "batch", 2: "height", 3: "width"}, "outputs": {0: "batch", 1: "class", 2: "height", 3: "width"}})
onnx.save(model_op, "./tmp.onnx")

# onnx -> ncnn
# !!!
# you should build onnx2ncnn binary file first
os.system("/bin/onnx2ncnn {} tmp.params tmp.bin".format(onnx_file))
  2. stack traces
Checking 0/3...
Checking 1/3...
Checking 2/3...
Unsupported slice step !
Unsupported slice step !
Unsupported slice step !
Unsupported slice step !
Unsupported slice step !
Unsupported slice step !
Unsupported slice step !
Unsupported slice step !
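
The repro script calls the converter through `os.system`, so the warnings above only appear on the console and the script carries on regardless. A minimal sketch of a stricter wrapper follows, assuming the converter lives at `/bin/onnx2ncnn` as in the report. The helper names `unsupported_messages` and `convert` are hypothetical, and scanning both output streams is an assumption, since the report does not say which stream the messages are written to.

```python
import subprocess


def unsupported_messages(log_text):
    """Return the converter log lines that start with 'Unsupported'."""
    return [line.strip() for line in log_text.splitlines()
            if line.strip().startswith("Unsupported")]


def convert(onnx_path, param_path, bin_path, onnx2ncnn="/bin/onnx2ncnn"):
    """Run onnx2ncnn and raise if it fails or reports unsupported operators."""
    result = subprocess.run(
        [onnx2ncnn, onnx_path, param_path, bin_path],
        capture_output=True, text=True,
    )
    # Assumption: warnings may land on either stream, so scan both.
    problems = unsupported_messages(result.stdout + "\n" + result.stderr)
    if result.returncode != 0 or problems:
        raise RuntimeError("onnx2ncnn reported problems: {}".format(problems))


# A shortened copy of the log from this report.
sample_log = "Checking 2/3...\n" + "Unsupported slice step !\n" * 8
found = unsupported_messages(sample_log)
```

With the log from this report, `found` holds eight entries, so `convert("tmp.onnx", "tmp.params", "tmp.bin")` would raise instead of silently writing a model that ncnn cannot run correctly.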

Expected behavior

Environment

  • PyTorch Version: 1.9.0
  • OS (e.g., MacOS, Linux): Ubuntu20.04 LTS
  • How you install python (anaconda, virtualenv, system): miniconda
  • python version (e.g. 3.7): 3.8.5
  • any other relevant information:
    • gpu: GeForce GTX 1650
    • driver: Driver Version: 460.80
    • CUDA: CUDA Version: 11.2
Contributor Author

MHGL commented Jul 8, 2021

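Workaround: rewrite Focus so that it uses stride-2 grouped convolutions with one-hot 2x2 kernels instead of strided slices, which trigger the "Unsupported slice step !" messages in onnx2ncnn.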

import torch
from torch.nn import functional as F

class Focus(torch.nn.Module):
    def __init__(self):
        super(Focus, self).__init__()
        ...

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        # Each one-hot 2x2 kernel, applied per channel (groups=3) with stride 2,
        # picks one pixel of every 2x2 patch, replacing one strided slice.
        # The filters are hard-coded for a 3-channel input.
        # part-1: same as x[..., ::2, ::2]
        gain = torch.tensor([[1, 0], [0, 0]])
        filters = torch.zeros(3, 1, 2, 2) + gain
        x1 = F.conv2d(x, filters, stride=2, groups=3)
        # part-2: same as x[..., 1::2, ::2]
        gain = torch.tensor([[0, 0], [1, 0]])
        filters = torch.zeros(3, 1, 2, 2) + gain
        x2 = F.conv2d(x, filters, stride=2, groups=3)
        # part-3: same as x[..., ::2, 1::2]
        gain = torch.tensor([[0, 1], [0, 0]])
        filters = torch.zeros(3, 1, 2, 2) + gain
        x3 = F.conv2d(x, filters, stride=2, groups=3)
        # part-4: same as x[..., 1::2, 1::2]
        gain = torch.tensor([[0, 0], [0, 1]])
        filters = torch.zeros(3, 1, 2, 2) + gain
        x4 = F.conv2d(x, filters, stride=2, groups=3)
        return torch.cat([x1, x2, x3, x4], 1)
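
To see why the rewrite reproduces the original slicing, here is a dependency-free sketch on a single 4x4 channel. It uses plain Python lists rather than torch, and the helpers `strided_slices` and `conv2x2_stride2` are hypothetical names introduced only for this illustration. The kernels are listed in the same order as part-1 to part-4 above.

```python
def strided_slices(img):
    """The four stride-2 slices of the original Focus, for one channel.

    img is a list of rows (H x W, both even). The order matches the
    torch.cat call: [::2, ::2], [1::2, ::2], [::2, 1::2], [1::2, 1::2].
    """
    return [
        [row[0::2] for row in img[0::2]],
        [row[0::2] for row in img[1::2]],
        [row[1::2] for row in img[0::2]],
        [row[1::2] for row in img[1::2]],
    ]


def conv2x2_stride2(img, kernel):
    """Unpadded 2x2 cross-correlation with stride 2 on one channel."""
    h, w = len(img), len(img[0])
    return [
        [
            sum(kernel[i][j] * img[r + i][c + j]
                for i in range(2) for j in range(2))
            for c in range(0, w - 1, 2)
        ]
        for r in range(0, h - 1, 2)
    ]


# One-hot kernels in the same order as part-1 .. part-4 of the rewrite.
KERNELS = [
    [[1, 0], [0, 0]],
    [[0, 0], [1, 0]],
    [[0, 1], [0, 0]],
    [[0, 0], [0, 1]],
]

img = [[r * 4 + c for c in range(4)] for r in range(4)]  # values 0..15
via_slices = strided_slices(img)
via_convs = [conv2x2_stride2(img, k) for k in KERNELS]
assert via_slices == via_convs
```

Each kernel has a single 1, so the weighted sum over a 2x2 patch collapses to one pixel, and the stride of 2 steps over the patches exactly as the `::2` and `1::2` slices do. The same comparison could be made in torch by feeding one random input to both versions of `Focus` and comparing the outputs.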
