How do I convert the ONNX model for the RV1106, with outputs like the 1-3-85-80-80 shape given in the example? #11160

Closed · wants to merge 1 commit
8 changes: 8 additions & 0 deletions README.md
@@ -1,3 +1,11 @@
#### Get model optimized for RKNN
To export a detection/segmentation model optimized for RKNPU, please refer to [README_rkopt.md](./README_rkopt.md).

---
<br>

<div align="center">
<p>
<a align="center" href="https://ultralytics.com/yolov5" target="_blank">
43 changes: 43 additions & 0 deletions README_rkopt.md
@@ -0,0 +1,43 @@
# YOLOv5 - RKNN optimize

## Source

Based on https://github.com/ultralytics/yolov5 (v7.0), commit 915bbf294bb74c859f0b41f1c23bc395014ea679



## What is different

With inference results unchanged, the following optimizations were applied:

- Optimize the Focus/SPPF blocks for better performance with identical results (for the SPPF part, see the pooling sketch below)
- Change the output nodes and remove post-processing from the model (the in-model post-processing block is unfriendly to quantization)
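As a quick sanity check of the SPPF/SPP rewrite, the pooling identity it relies on can be verified directly (a minimal sketch, assuming PyTorch is installed): a stride-1 k×k max pool with odd k equals floor(k/2) stacked stride-1 3×3 max pools, so the pooling can be restructured without changing the results.

```
import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 80, 80)

# original SPPF-style pooling: one 5x5 max pool, stride 1, padding 2
ref = F.max_pool2d(x, kernel_size=5, stride=1, padding=2)

# RKNN-friendly form: floor(5/2) = 2 stacked 3x3 max pools, stride 1, padding 1
stacked = F.max_pool2d(F.max_pool2d(x, 3, 1, 1), 3, 1, 1)

print(torch.equal(ref, stacked))  # True
```

Stacked small pools map to operators the RKNN toolchain handles well, which is why the rewritten SPP/SPPF in models/common.py uses repeated 3x3 max pools.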



With inference results changed, the following optimization was applied:

- Use ReLU instead of SiLU as the activation layer (only takes effect when training a new model)



## How to use

```
# for detection model
python export.py --rknpu --weight yolov5s.pt

# for segmentation model
python export.py --rknpu --weight yolov5s-seg.pt
```

- 'yolov5s.pt' / 'yolov5s-seg.pt' can be replaced with the path to your own model
- A file named 'RK_anchors.txt' will be generated; it is used later in the post-processing stage (see the loader sketch below)
- **NOTICE: Please call export.py with --rknpu; do not change the default rknpu value in export.py.**
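The generated anchors can be read back for post-processing as shown below (a minimal sketch based on how export.py writes the file: 18 values, one per line, already multiplied by the per-level stride; 'RK_anchors.txt' is assumed to be in the working directory):

```
import numpy as np

# RK_anchors.txt: 9 anchors x (w, h) in pixels, one value per line
values = np.loadtxt('RK_anchors.txt')   # shape: (18,)
anchors = values.reshape(3, 3, 2)       # (detection scale, anchor, w/h), following the export write order
print(anchors)
```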



## Deploy demo

Please refer to https://github.com/airockchip/rknn_model_zoo
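If you need to convert the exported ONNX model into a .rknn file yourself (e.g. for RV1106-class boards), this is normally done with the rknn-toolkit2 Python API. The following is a minimal sketch only; the normalization values, target platform, and the file names 'yolov5s.onnx' / 'dataset.txt' are assumptions for illustration, so follow the rknn_model_zoo demo above for the exact, supported settings.

```
from rknn.api import RKNN

rknn = RKNN(verbose=True)

# mean/std and target platform are assumptions -- adjust them to your model and board
rknn.config(mean_values=[[0, 0, 0]], std_values=[[255, 255, 255]],
            target_platform='rv1106')

rknn.load_onnx(model='yolov5s.onnx')                        # model exported by export.py
rknn.build(do_quantization=True, dataset='./dataset.txt')   # dataset.txt lists calibration image paths
rknn.export_rknn('yolov5s.rknn')
rknn.release()
```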

85 changes: 78 additions & 7 deletions export.py
@@ -56,6 +56,11 @@
import warnings
from pathlib import Path


# enable the RKNN export hack before model modules are imported (models/common.py checks this env var at import time)
if '--rknpu' in sys.argv:
os.environ['RKNN_model_hack'] = "1"

import pandas as pd
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
@@ -68,7 +68,7 @@
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative

from models.experimental import attempt_load
from models.yolo import ClassificationModel, Detect, DetectionModel, SegmentationModel
from models.yolo import ClassificationModel, Detect, DetectionModel, SegmentationModel, Segment
from utils.dataloaders import LoadImages
from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save)
@@ -545,11 +550,76 @@ def run(
m.dynamic = dynamic
m.export = True

if os.getenv('RKNN_model_hack', '0') in ['1']:
from models.common import Focus
from models.common import Conv
from models.common_rk_plug_in import surrogate_focus
if isinstance(model.model[0], Focus):
# For older yolov5 models whose stem is a Focus module
surrogate_focous = surrogate_focus(int(model.model[0].conv.conv.weight.shape[1]/4),
model.model[0].conv.conv.weight.shape[0],
k=tuple(model.model[0].conv.conv.weight.shape[2:4]),
s=model.model[0].conv.conv.stride,
p=model.model[0].conv.conv.padding,
g=model.model[0].conv.conv.groups,
act=True)
surrogate_focous.conv.conv.weight = model.model[0].conv.conv.weight
surrogate_focous.conv.conv.bias = model.model[0].conv.conv.bias
surrogate_focous.conv.act = model.model[0].conv.act
temp_i = model.model[0].i
temp_f = model.model[0].f

model.model[0] = surrogate_focous
model.model[0].i = temp_i
model.model[0].f = temp_f
model.model[0].eval()
elif isinstance(model.model[0], Conv) and model.model[0].conv.kernel_size == (6, 6):
# For yolov5 v6.x models whose stem is a 6x6 stride-2 Conv instead of Focus
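# rewrite it as space-to-depth (2x2) + a 3x3 stride-1 conv: each 2x2 phase of the 6x6 kernel is copied into the matching space-to-depth channels below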
surrogate_focous = surrogate_focus(model.model[0].conv.weight.shape[1],
model.model[0].conv.weight.shape[0],
k=(3,3), # 6/2, 6/2
s=1,
p=(1,1), # 2/2, 2/2
g=model.model[0].conv.groups,
act=hasattr(model.model[0], 'act'))
surrogate_focous.conv.conv.weight[:,:3,:,:] = model.model[0].conv.weight[:,:,::2,::2]
surrogate_focous.conv.conv.weight[:,3:6,:,:] = model.model[0].conv.weight[:,:,1::2,::2]
surrogate_focous.conv.conv.weight[:,6:9,:,:] = model.model[0].conv.weight[:,:,::2,1::2]
surrogate_focous.conv.conv.weight[:,9:,:,:] = model.model[0].conv.weight[:,:,1::2,1::2]
surrogate_focous.conv.conv.bias = model.model[0].conv.bias
surrogate_focous.conv.act = model.model[0].act
temp_i = model.model[0].i
temp_f = model.model[0].f

model.model[0] = surrogate_focous
model.model[0].i = temp_i
model.model[0].f = temp_f
model.model[0].eval()

if os.getenv('RKNN_model_hack', '0') in ['1']:
if isinstance(model.model[-1], Detect):
# save anchors
print('---> save anchors for RKNN')
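# Detect.anchors are stored in grid units per level; multiply by each level's stride to get pixel units (9 anchors x 2 values)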
RK_anchors = model.model[-1].stride.reshape(3,1).repeat(1,3).reshape(-1,1)* model.model[-1].anchors.reshape(9,2)
with open('RK_anchors.txt', 'w') as anf:
# anf.write(str(model.model[-1].na)+'\n')
for _v in RK_anchors.numpy().flatten():
anf.write(str(_v)+'\n')
RK_anchors = RK_anchors.tolist()
print(RK_anchors)

if isinstance(model.model[-1], Segment):
print("export segment model for RKNPU")
model.model[-1]._register_seg_seperate(True)
else:
print("export detect model for RKNPU")
model.model[-1]._register_detect_seperate(True)

for _ in range(2):
y = model(im) # dry runs
if half and not coreml:
im, model = im.half(), model.half() # to FP16
shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape
shape = tuple((y[0] if (isinstance(y, tuple) or (isinstance(y, list))) else y).shape) # model output shape
metadata = {'stride': int(max(model.stride)), 'names': model.names} # model metadata
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")

@@ -632,11 +702,11 @@ def parse_opt():
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
parser.add_argument(
'--include',
nargs='+',
default=['torchscript'],
help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle')
parser.add_argument('--include',
nargs='+',
default=['onnx'],
help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
parser.add_argument('--rknpu', action='store_true', help='RKNN npu platform')
opt = parser.parse_args()
print_args(vars(opt))
return opt
@@ -649,4 +719,5 @@ def main(opt):

if __name__ == "__main__":
opt = parse_opt()
del opt.rknpu
main(opt)
125 changes: 91 additions & 34 deletions models/common.py
@@ -3,6 +3,7 @@
Common modules
"""

import os
import ast
import contextlib
import json
@@ -51,7 +52,8 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
super().__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
self.bn = nn.BatchNorm2d(c2)
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
# self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
self.act = nn.ReLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

def forward(self, x):
return self.act(self.bn(self.conv(x)))
@@ -131,7 +133,8 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, nu
self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
self.cv4 = Conv(2 * c_, c2, 1, 1)
self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
self.act = nn.SiLU()
# self.act = nn.SiLU()
self.act = nn.ReLU()
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

def forward(self, x):
@@ -200,38 +203,92 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))


class SPP(nn.Module):
# Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
def __init__(self, c1, c2, k=(5, 9, 13)):
super().__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])

def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))


class SPPF(nn.Module):
# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
super().__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_ * 4, c2, 1, 1)
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
y1 = self.m(x)
y2 = self.m(y1)
return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
if os.getenv('RKNN_model_hack', '0') == '0':
class SPP(nn.Module):
# Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
def __init__(self, c1, c2, k=(5, 9, 13)):
super().__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])

def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
elif os.getenv('RKNN_model_hack', '0') in ['1']:
# TODO remove this hack when rknn-toolkit1/2 add this optimize rules
class SPP(nn.Module):
def __init__(self, c1, c2, k=(5, 9, 13)):
super().__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
for value in k:
assert value % 2 == 1 and value != 1, f'kernel size {value} must be an odd number greater than 1 for the RKNN model hack'

def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
y = [x]
for maxpool in self.m:
kernel_size = maxpool.kernel_size
m = x
for i in range(math.floor(kernel_size/2)):
m = torch.nn.functional.max_pool2d(m, 3, 1, 1)
y = [*y, m]
return self.cv2(torch.cat(y, 1))


if os.getenv('RKNN_model_hack', '0') in ['0']:
class SPPF(nn.Module):
# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
super().__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_ * 4, c2, 1, 1)
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
y1 = self.m(x)
y2 = self.m(y1)
return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
elif os.getenv('RKNN_model_hack', '0') in ['1']:
class SPPF(nn.Module):
# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
super().__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_ * 4, c2, 1, 1)
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
# emulate the single k x k max pool with floor(k/2) stacked 3x3 max pools (identical result, RKNN-friendly)
y = [x]
kernel_size = self.m.kernel_size
_3x3_stack = math.floor(kernel_size/2)
for i in range(3):
m = y[-1]
for _ in range(_3x3_stack):
m = torch.nn.functional.max_pool2d(m, 3, 1, 1)
y = [*y, m]
return self.cv2(torch.cat(y, 1))


class Focus(nn.Module):
30 changes: 30 additions & 0 deletions models/common_rk_plug_in.py
@@ -0,0 +1,30 @@
# This file contains modules common to various models

import torch
import torch.nn as nn
from models.common import Conv


class surrogate_focus(nn.Module):
# surrogate_focus wh information into c-space
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
super(surrogate_focus, self).__init__()
self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act)

with torch.no_grad():
self.convsp = nn.Conv2d(3, 12, (2, 2), groups=1, bias=False, stride=(2, 2))
self.convsp.weight.data = torch.zeros(self.convsp.weight.shape).float()
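# fixed (non-trainable) weights implement space-to-depth: output channel i*3+j copies input channel j from one of the four 2x2 sub-pixel positions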
for i in range(4):
for j in range(3):
ch = i*3 + j
if ch>=0 and ch<3:
self.convsp.weight[ch:ch+1, j:j+1, 0, 0] = 1
elif ch>=3 and ch<6:
self.convsp.weight[ch:ch+1, j:j+1, 1, 0] = 1
elif ch>=6 and ch<9:
self.convsp.weight[ch:ch+1, j:j+1, 0, 1] = 1
elif ch>=9 and ch<12:
self.convsp.weight[ch:ch+1, j:j+1, 1, 1] = 1

def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
return self.conv(self.convsp(x))
3 changes: 2 additions & 1 deletion models/experimental.py
@@ -51,7 +51,8 @@ def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): # ch_in, ch_out, kern
self.m = nn.ModuleList([
nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
self.bn = nn.BatchNorm2d(c2)
self.act = nn.SiLU()
# self.act = nn.SiLU()
self.act = nn.ReLU()

def forward(self, x):
return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))