Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ONNX export compatible with OpenCV #462

Closed
wants to merge 1 commit into from

Conversation

dkurt
Copy link

@dkurt dkurt commented Mar 28, 2020

This PR has changes which let to export trained model to ONNX format compatible with OpenCV (PR opencv/opencv#16925 is required)

Convert to ONNX

python3 export_to_onnx.py --model ssd300_mAP_77.43_v2.pth

Run with OpenCV

import numpy as np
import cv2 as cv

net = cv.dnn.readNet('ssd.onnx')

img = cv.imread('example.jpg')
rows, cols = img.shape[0:2]
inp = cv.dnn.blobFromImage(img, 1.0, (300, 300), mean=(123, 117, 104), swapRB=True)

net.setInput(inp)
out = net.forward()

for detection in out[0,0,:,:]:
    score = float(detection[2])
    if score > 0.5:
        xmin, ymin, xmax, ymax = [int(v) for v in detection[3:] * [cols, rows, cols, rows]]
        cv.rectangle(img, (xmin, ymin), (xmax, ymax), (23, 230, 210), thickness=2)

cv.imshow('Object detection', img)
cv.waitKey()

Or run with OpenVINO

python3 /opt/intel/openvino/deployment_tools/model_optimizer/mo_onnx.py \
  --input_model ssd.onnx \
  --mean_values [123,117,104] \
  --reverse_input_channels \
  --input_shape [1,3,300,300]

res

@luoduo21
Copy link

@dkurt Hi, do you know how to export the model with prior box node. If I add the prior box node, the error will appear:
output of traced region did not have observable data dependence with trace inputs; this probably indicates your program cannot be understood by the tracer.

@dkurt
Copy link
Author

dkurt commented Mar 30, 2020

Hi, @luoduo21, can you please show the exact source branch version to reproduce? In this PR priors are stored in the constant node.

@luoduo21
Copy link

luoduo21 commented Mar 31, 2020

@dkurt
self.priors = self.priorbox.forward(), volatile=True)
output = ( self.priors, loc.view(loc.size(0), -1, 4), # loc preds self.softmax(conf.view(-1, self.num_classes)), # conf preds )

the export code:

torch.onnx.export(net, example, model_path, verbose=True, input_names=['input'], output_names=['priors','boxes', 'scores'],export_params=True)

Maybe onnx doesn't support exporting constant node?

@dkurt
Copy link
Author

dkurt commented Apr 1, 2020

@luoduo21, Can you please share why you need to export it as output? I mean do you want to run SSD and proposed branch not works for your model? You don't have to export raw priors+boxes+scores but you can export post-processed outputs from Detect layer (like in this PR).

@luoduo21
Copy link

luoduo21 commented Apr 2, 2020

@dkurt I want to convert the onnx model to the ncnn model. I tried to add a Detect layer to export the post-processing directly,Although it can be successful,there are some problems in the process of onnx - > ncnn. So I want to export the three nodes directly。

@LucasVandroux
Copy link

@dkurt Thanks a lot for providing the code to convert the SSD model to the ONNX format. Unfortunately, I am not able to run your script export_to_onnx.py. Could you please help me understand what I am missing out?

Environment

  • Ubuntu (WSL): 18.04
  • PyTorch: 1.5.0+cpu
  • ONNX: 1.7.0

Issues

  1. Impossible to find the coco_labels.txt file. Fixed using Corrected path to coco_labels #400
  2. ONNX: RuntimeError: No Op registered for DetectionOutput with domain_version of 9 (see full output below). I didn't manage to solve that one but it looks that it is coming from the changes you did in ./layers/functions/detection.py in the symbolic function.
Traceback (most recent call last):
  File "export_to_onnx.py", line 43, in <module>
    save_model(net, input, args.output)
  File "export_to_onnx.py", line 24, in save_model
    onnx_model_pb = export_to_string(model, input)
  File "export_to_onnx.py", line 19, in export_to_string
    torch.onnx.export(model, inputs, f, export_params=True, opset_version=version)
  File "/mnt/c/Users/lucas/Documents/ssd.pytorch/.env/lib/python3.6/site-packages/torch/onnx/__init__.py", line 168, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/mnt/c/Users/lucas/Documents/ssd.pytorch/.env/lib/python3.6/site-packages/torch/onnx/utils.py", line 69, in export
    use_external_data_format=use_external_data_format)
  File "/mnt/c/Users/lucas/Documents/ssd.pytorch/.env/lib/python3.6/site-packages/torch/onnx/utils.py", line 515, in _export
    _check_onnx_proto(proto)
RuntimeError: No Op registered for DetectionOutput with domain_version of 9

==> Context: Bad node spec: input: "262" input: "281" input: "282" output: "283" name: "DetectionOutput_177" op_type: "DetectionOutput" attribute { name: "background_label_id" i: 0 type: INT } attribute { name: "code_type" s: "CENTER_SIZE" type: STRING } attribute { name: "confidence_threshold" f: 0.01 type: FLOAT } attribute { name: "keep_top_k" i: 200 type: INT } attribute { name: "nms_threshold" f: 0.45 type: FLOAT } attribute { name: "num_classes" i: 21 type: INT } attribute { name: "share_location" i: 1 type: INT } attribute { name: "top_k" i: 200 type: INT } attribute { name: "variance_encoded_in_target" i: 0 type: INT }

@dkurt
Copy link
Author

dkurt commented Jun 9, 2020

@LucasVandroux, thanks for feedback! That looks very strage. Can you please try ONNX of version 1.6.0?

@LucasVandroux
Copy link

LucasVandroux commented Jun 9, 2020

@dkurt Thanks for your fast answer! I have the same problem with ONNX 1.6.0. Here is the full output:

python export_to_onnx.py --model ../models/ssd300_mAP_77.43_v2.pth

/mnt/c/Users/lucas/Documents/ssd.pytorch/ssd.py:34: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.
  self.priors = Variable(self.priorbox.forward(), volatile=True)
/mnt/c/Users/lucas/Documents/ssd.pytorch/layers/functions/detection.py:41: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  prior_data = prior_data[:prior_data.shape[0] // 2]
/mnt/c/Users/lucas/Documents/ssd.pytorch/layers/functions/detection.py:50: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  for i in range(num):
/mnt/c/Users/lucas/Documents/ssd.pytorch/layers/functions/detection.py:58: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if scores.size(0) == 0:
Traceback (most recent call last):
  File "export_to_onnx.py", line 43, in <module>
    save_model(net, input, args.output)
  File "export_to_onnx.py", line 24, in save_model
    onnx_model_pb = export_to_string(model, input)
  File "export_to_onnx.py", line 19, in export_to_string
    torch.onnx.export(model, inputs, f, export_params=True, opset_version=version)
  File "/mnt/c/Users/lucas/Documents/ssd.pytorch/.env/lib/python3.6/site-packages/torch/onnx/__init__.py", line 168, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/mnt/c/Users/lucas/Documents/ssd.pytorch/.env/lib/python3.6/site-packages/torch/onnx/utils.py", line 69, in export
    use_external_data_format=use_external_data_format)
  File "/mnt/c/Users/lucas/Documents/ssd.pytorch/.env/lib/python3.6/site-packages/torch/onnx/utils.py", line 515, in _export
    _check_onnx_proto(proto)
RuntimeError: No Op registered for DetectionOutput with domain_version of 9

==> Context: Bad node spec: input: "262" input: "281" input: "282" output: "283" name: "DetectionOutput_177" op_type: "DetectionOutput" attribute { name: "background_label_id" i: 0 type: INT } attribute { name: "code_type" s: "CENTER_SIZE" type: STRING } attribute { name: "confidence_threshold" f: 0.01 type: FLOAT } attribute { name: "keep_top_k" i: 200 type: INT } attribute { name: "nms_threshold" f: 0.45 type: FLOAT } attribute { name: "num_classes" i: 21 type: INT } attribute { name: "share_location" i: 1 type: INT } attribute { name: "top_k" i: 200 type: INT } attribute { name: "variance_encoded_in_target" i: 0 type: INT }

The full list of python packages I am using is as follow:

pip list

future (0.18.2)
numpy (1.18.5)
onnx (1.6.0)
opencv-python (4.2.0.34)
Pillow (7.1.2)
pip (9.0.1)
pkg-resources (0.0.0)
protobuf (3.12.2)
setuptools (39.0.1)
six (1.15.0)
torch (1.5.0+cpu)
torchvision (0.6.0+cpu)
typing-extensions (3.7.4.2)

And here the end of my git log:

git log

commit 1c02c5f68a325a3f8e5ec27341c778b4cf0d59af (HEAD -> pr/462)
Author: Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Date:   Sat Mar 28 18:06:52 2020 +0300

    ONNX export compatible with OpenCV

commit 5b0b77faa955c1917b0c710d770739ba8fbff9b7 (origin/master, origin/HEAD, master)
Author: Degroot <alexdg@amazon.com>
Date:   Mon Mar 11 14:43:06 2019 -0700

    Update .dim() to .size(0) in eval.py for compatibility with latest .dim() functionality

commit ad98ca4f83ebd7e690825b8bb067419030bda992
Author: Miguel Morales <mimoralea@gmail.com>
Date:   Mon Mar 11 15:35:33 2019 -0600

    Update README.md (#235)

    Link to webyclip.com is broken

Would some other information help you as well?

@dkurt
Copy link
Author

dkurt commented Jun 9, 2020

So, I tried the following steps:

git clone https://github.com/dkurt/ssd.pytorch --branch opencv_support --depth 1
cd ssd.pytorch
wget https://s3.amazonaws.com/amdegroot-models/ssd300_mAP_77.43_v2.pth
python3 export_to_onnx.py --model ssd300_mAP_77.43_v2.pth

with torch==1.4.0 and onnx==1.6.0

It reports some warnings but produces .onnx file

Also tried torch=1.5.0 and can confirm that it throws

RuntimeError: No Op registered for DetectionOutput with domain_version of 9

==> Context: Bad node spec: input: "262" input: "281" input: "282" output: "283" name: "DetectionOutput_177" op_type: "DetectionOutput" attribute { name: "background_label_id" i: 0 type: INT } attribute { name: "code_type" s: "CENTER_SIZE" type: STRING } attribute { name: "confidence_threshold" f: 0.01 type: FLOAT } attribute { name: "keep_top_k" i: 200 type: INT } attribute { name: "nms_threshold" f: 0.45 type: FLOAT } attribute { name: "num_classes" i: 21 type: INT } attribute { name: "share_location" i: 1 type: INT } attribute { name: "top_k" i: 200 type: INT } attribute { name: "variance_encoded_in_target" i: 0 type: INT }

Thanks for note! So, for current state or PR you may try torch 1.4.0 only

@LucasVandroux
Copy link

@dkurt thank you, the problem is indeed coming from the pytorch version and one needs to use torch==1.4.0. I also needed to apply the fix from #400 .

@LucasVandroux
Copy link

@dkurt thanks again for providing the code to convert the ONNX model to OpenVINO. I used it, and it worked perfectly with openvino_2020.2.120.

For future developers who would like to use this model in their application, please note that the output label index corresponds to VOC and not to COCO.

@LucasVandroux
Copy link

For those who might be interested in getting directly the ONNX file and the OpenVINO files, I uploaded them here: link.

@TamarVolcani
Copy link

where is the file export_to_onnx.py? I don't see it in the repository...

@LucasVandroux
Copy link

LucasVandroux commented Dec 20, 2020

where is the file export_to_onnx.py? I don't see it in the repository...

@TamarVolcani as this merge request hasn't yet been merged into the original repository, you will not find it there but in the fork of the repository: https://github.com/dkurt/ssd.pytorch/blob/opencv_support/export_to_onnx.py

@shiyuetianqiang
Copy link

So, I tried the following steps:

git clone https://github.com/dkurt/ssd.pytorch --branch opencv_support --depth 1
cd ssd.pytorch
wget https://s3.amazonaws.com/amdegroot-models/ssd300_mAP_77.43_v2.pth
python3 export_to_onnx.py --model ssd300_mAP_77.43_v2.pth

with torch==1.4.0 and onnx==1.6.0

It reports some warnings but produces .onnx file

Also tried torch=1.5.0 and can confirm that it throws

RuntimeError: No Op registered for DetectionOutput with domain_version of 9

==> Context: Bad node spec: input: "262" input: "281" input: "282" output: "283" name: "DetectionOutput_177" op_type: "DetectionOutput" attribute { name: "background_label_id" i: 0 type: INT } attribute { name: "code_type" s: "CENTER_SIZE" type: STRING } attribute { name: "confidence_threshold" f: 0.01 type: FLOAT } attribute { name: "keep_top_k" i: 200 type: INT } attribute { name: "nms_threshold" f: 0.45 type: FLOAT } attribute { name: "num_classes" i: 21 type: INT } attribute { name: "share_location" i: 1 type: INT } attribute { name: "top_k" i: 200 type: INT } attribute { name: "variance_encoded_in_target" i: 0 type: INT }

Thanks for note! So, for current state or PR you may try torch 1.4.0 only

Hi, thanks for your work for converting the SSD to onnx model.
Follow your suggestions, I also adjust my corresponding softwares version, and it does work for converting proposed model like ssd300_mAP_77.43_v2.pth.
But when I employ your repo "https://github.com/dkurt/ssd.pytorch/tree/opencv_support" to train a new ssd model, and employ the convert file, it fails.
The error is shown as follows:
" File "/home/llq/Desktop/ssd/2ssd_another/layers/box_utils.py", line 187, in nms
keep = scores.new(scores.size(0)).zero_().long()
TypeError: expected Float (got Long)"
I have tried to fix it, however, I failed, could you please provide some advice?

@abuvaneswari
Copy link

Hello,

Thank you for the script to convert the model and the instructions. I am able to do the conversion to ONNX (opsets 9, 10 and 11) as well as to Intel IR formats.

However, I am not able to inference the converted ONNX models (opsets 9, 10, 11) with onnxruntime. I tried both 1.4.0 and 1.6.0 versions of onnxruntime.

Here is the error message that I get:
File "/home/buvana/base_ort.py", line 92, in
m = ONNXNet(model)
File "/home/buvana/base_ort.py", line 10, in init
self.ort_session_f = ort.InferenceSession(onnx_path_f)
File "/home/buvana/venv/lib64/python3.6/site-packages/onnxruntime/capi/session.py", line 158, in init
self._load_model(providers or [])
File "/home/buvana/venv/lib64/python3.6/site-packages/onnxruntime/capi/session.py", line 177, in _load_model
self._sess.load_model(providers)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. Error in Node: : No Op registered for DetectionOutput with domain_version of 10

Please provide your suggestions.

@dkurt
Copy link
Author

dkurt commented Jan 13, 2021

@abuvaneswari, Proposed solution won't work with ONNXRuntime. It's only for OpenCV and OpenVINO. Both of them can import ONNX model directly.

@mohadese-yousefi
Copy link

So, I tried the following steps:

git clone https://github.com/dkurt/ssd.pytorch --branch opencv_support --depth 1
cd ssd.pytorch
wget https://s3.amazonaws.com/amdegroot-models/ssd300_mAP_77.43_v2.pth
python3 export_to_onnx.py --model ssd300_mAP_77.43_v2.pth

with torch==1.4.0 and onnx==1.6.0
It reports some warnings but produces .onnx file
Also tried torch=1.5.0 and can confirm that it throws

RuntimeError: No Op registered for DetectionOutput with domain_version of 9

==> Context: Bad node spec: input: "262" input: "281" input: "282" output: "283" name: "DetectionOutput_177" op_type: "DetectionOutput" attribute { name: "background_label_id" i: 0 type: INT } attribute { name: "code_type" s: "CENTER_SIZE" type: STRING } attribute { name: "confidence_threshold" f: 0.01 type: FLOAT } attribute { name: "keep_top_k" i: 200 type: INT } attribute { name: "nms_threshold" f: 0.45 type: FLOAT } attribute { name: "num_classes" i: 21 type: INT } attribute { name: "share_location" i: 1 type: INT } attribute { name: "top_k" i: 200 type: INT } attribute { name: "variance_encoded_in_target" i: 0 type: INT }

Thanks for note! So, for current state or PR you may try torch 1.4.0 only

Hi, thanks for your work for converting the SSD to onnx model.
Follow your suggestions, I also adjust my corresponding softwares version, and it does work for converting proposed model like ssd300_mAP_77.43_v2.pth.
But when I employ your repo "https://github.com/dkurt/ssd.pytorch/tree/opencv_support" to train a new ssd model, and employ the convert file, it fails.
The error is shown as follows:
" File "/home/llq/Desktop/ssd/2ssd_another/layers/box_utils.py", line 187, in nms
keep = scores.new(scores.size(0)).zero_().long()
TypeError: expected Float (got Long)"
I have tried to fix it, however, I failed, could you please provide some advice?

I converted this line:
keep = scores.new(scores.size(0)).zero_().long()
to
keep = torch.zeros(scores.size(), dtype=torch.int64)
and solved.

@dkurt dkurt closed this Apr 20, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants