Skip to content

Commit

Permalink
Add option to quantize per-tensor (#12516)
Browse files Browse the repository at this point in the history
* Add option to quantize per-tensor

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
3 people committed Dec 18, 2023
1 parent f33d42d commit 63555c8
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,8 @@ def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):


@try_export
def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
def export_tflite(keras_model, im, file, int8, per_tensor, data, nms, agnostic_nms,
prefix=colorstr('TensorFlow Lite:')):
# YOLOv5 TensorFlow Lite export
import tensorflow as tf

Expand All @@ -469,6 +470,8 @@ def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=c
converter.inference_input_type = tf.uint8 # or tf.int8
converter.inference_output_type = tf.uint8 # or tf.int8
converter.experimental_new_quantizer = True
if per_tensor:
converter._experimental_disable_per_channel = True
f = str(file).replace('.pt', '-int8.tflite')
if nms or agnostic_nms:
converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
Expand Down Expand Up @@ -713,6 +716,7 @@ def run(
keras=False, # use Keras
optimize=False, # TorchScript: optimize for mobile
int8=False, # CoreML/TF INT8 quantization
per_tensor=False, # TF per tensor quantization
dynamic=False, # ONNX/TF/TensorRT: dynamic axes
simplify=False, # ONNX: simplify model
opset=12, # ONNX: opset version
Expand Down Expand Up @@ -798,7 +802,14 @@ def run(
if pb or tfjs: # pb prerequisite to tfjs
f[6], _ = export_pb(s_model, file)
if tflite or edgetpu:
f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
f[7], _ = export_tflite(s_model,
im,
file,
int8 or edgetpu,
per_tensor,
data=data,
nms=nms,
agnostic_nms=agnostic_nms)
if edgetpu:
f[8], _ = export_edgetpu(file)
add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs))
Expand Down Expand Up @@ -837,6 +848,7 @@ def parse_opt(known=False):
parser.add_argument('--keras', action='store_true', help='TF: use Keras')
parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
parser.add_argument('--int8', action='store_true', help='CoreML/TF/OpenVINO INT8 quantization')
parser.add_argument('--per-tensor', action='store_true', help='TF per-tensor quantization')
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
parser.add_argument('--opset', type=int, default=17, help='ONNX: opset version')
Expand Down

0 comments on commit 63555c8

Please sign in to comment.