Model to TensorRT #391

Open
hackgyh opened this issue Apr 27, 2024 · 0 comments

hackgyh commented Apr 27, 2024

I am trying to deploy a model to an NVIDIA Jetson Orin, which requires converting it to TensorRT. However, the conversion fails.

The conversion process consists of two steps. The first step is converting the model to ONNX format. Below is the code I wrote, using the SuperPoint model as an example.

import onnx
import torch
from torch import nn
from hloc.utils.base_model import dynamic_load
from hloc import extractors, extract_features
import numpy as np
import cv2
import PIL.Image  # "import PIL" alone does not guarantee that PIL.Image is loaded
from third_party.SuperGluePretrainedNetwork.models import superpoint

def generate(conf, onnx=True):
    feature_conf = conf
    feature_conf['model']['max_keypoints'] = 2048
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = dynamic_load(extractors, feature_conf['model']['name'])
    net = net(feature_conf['model']).eval().to(device)
    return net

def resize_image(image, size, interp):
    if interp.startswith('cv2_'):
        interp = getattr(cv2, 'INTER_'+interp[len('cv2_'):].upper())
        h, w = image.shape[:2]
        if interp == cv2.INTER_AREA and (w < size[0] or h < size[1]):
            interp = cv2.INTER_LINEAR
        resized = cv2.resize(image, size, interpolation=interp)
    elif interp.startswith('pil_'):
        interp = getattr(PIL.Image, interp[len('pil_'):].upper())
        resized = PIL.Image.fromarray(image.astype(np.uint8))
        resized = resized.resize(size, resample=interp)
        resized = np.asarray(resized, dtype=image.dtype)
    else:
        raise ValueError(
            f'Unknown interpolation {interp}.')
    return resized

def convert_img(img, conf):
    image = img
    size = image.shape[:2][::-1]

    if conf['resize_max'] and max(size) > conf['resize_max']:
        scale = conf['resize_max'] / max(size)
        size_new = tuple(int(round(x*scale)) for x in size)
        image = resize_image(image, size_new, 'cv2_area')

    if conf['grayscale']:
        image = image[None]
    else:
        image = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    image = image / 255.

    data = {
        'image': torch.unsqueeze(torch.from_numpy(image.astype(np.float32)), dim=0),
        'original_size': torch.tensor(size),
    }
    return data

def convert_to_onnx(simplify, model_path):
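    # net = generate(extract_features.confs['superpoint_aachen'], onnx=True)
    # Pass only the 'model' sub-config: SuperPoint merges its argument into its
    # default config, so passing the whole conf (output/model/preprocessing)
    # would leave nms_radius and max_keypoints at their default values.
    net = superpoint.SuperPoint(extract_features.confs['superpoint_aachen']['model']).eval()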

    device = 'cpu'

    im                  = convert_img(np.zeros((1242, 2208)), extract_features.confs['superpoint_aachen']['preprocessing'])
    print(im['image'].shape)
    input_layer_names   = ["image"]
    output_layer_names  = ["output"]
    
    # Export the model
    print(f'Starting export with onnx {onnx.__version__}.')
    torch.onnx.export(net,
                    im['image'].to(device, non_blocking=True),
                    f               = model_path,
                    verbose         = False,
                    opset_version   = 16,
                    training        = torch.onnx.TrainingMode.EVAL,
                    do_constant_folding = True,
                    input_names     = input_layer_names,
                    output_names    = output_layer_names,
                    dynamic_axes    = None)

    # Checks
    model_onnx = onnx.load(model_path)  # load onnx model
    print("run checker..")
    onnx.checker.check_model(model_onnx)  # check onnx model

    # Simplify onnx
    if simplify:
        import onnxsim
        print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')
        model_onnx, check = onnxsim.simplify(
            model_onnx,
            dynamic_input_shape=False,
            input_shapes=None)
        assert check, 'assert check failed'
        onnx.save(model_onnx, model_path)

    print('Onnx model save as {}'.format(model_path))

if __name__ == '__main__':
    convert_to_onnx(True, "/home/plac/Hierarchical-Localization/onnx/test.onnx")

This step completes and the ONNX checker passes, although the tracer emits several warnings:

Loaded SuperPoint model
torch.Size([1, 1, 576, 1024])
Starting export with onnx 1.15.0.
/home/plac/Hierarchical-Localization/third_party/SuperGluePretrainedNetwork/models/superpoint.py:175: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  keypoints = [
/home/plac/Hierarchical-Localization/third_party/SuperGluePretrainedNetwork/models/superpoint.py:178: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
/home/plac/Hierarchical-Localization/third_party/SuperGluePretrainedNetwork/models/superpoint.py:178: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
/home/plac/Hierarchical-Localization/third_party/SuperGluePretrainedNetwork/models/superpoint.py:201: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  for k, d in zip(keypoints, descriptors)]
/home/plac/Hierarchical-Localization/third_party/SuperGluePretrainedNetwork/models/superpoint.py:85: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)],
/home/plac/Hierarchical-Localization/third_party/SuperGluePretrainedNetwork/models/superpoint.py:85: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)],
/home/plac/.conda/envs/pytorch/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py:5408: UserWarning: Exporting aten::index operator of advanced indexing in opset 16 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.
  warnings.warn(
/home/plac/.conda/envs/pytorch/lib/python3.10/site-packages/torch/onnx/_internal/jit_utils.py:258: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525541990/work/torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
/home/plac/.conda/envs/pytorch/lib/python3.10/site-packages/torch/onnx/utils.py:687: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525541990/work/torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
  _C._jit_pass_onnx_graph_shape_type_inference(
/home/plac/.conda/envs/pytorch/lib/python3.10/site-packages/torch/onnx/utils.py:1178: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525541990/work/torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
  _C._jit_pass_onnx_graph_shape_type_inference(
run checker..
Simplifying with onnx-simplifier 0.4.36.
Onnx model save as /home/plac/Hierarchical-Localization/onnx/test.onnx

The second step is converting the ONNX model to TensorRT. Below is the code for this step.

# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT


import argparse
import os
from calibrator import DatasetCalibrator
import tensorrt as trt
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
from PIL import Image
import numpy as np


parser = argparse.ArgumentParser()
parser.add_argument('onnx', type=str, help='Path to the ONNX model.')
parser.add_argument('--output', type=str, default=None, help='Path to output the optimized TensorRT engine')
parser.add_argument('--max_workspace_size', type=int, default=1<<25, help='Max workspace size for TensorRT engine.')
parser.add_argument('--int8', action='store_true')
parser.add_argument('--fp16', action='store_true')
parser.add_argument('--batch_size', type=int, default=1)
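# default=None so DLA is used only when --dla_core is given; with default=0 the
# "args.dla_core is not None" check below is always true.
parser.add_argument('--dla_core', type=int, default=None, help='DLA core to build for; omit to build for the GPU.')
parser.add_argument('--gpu_fallback', action='store_true')
parser.add_argument('--dataset_path', type=str, default='/home/plac/slv/data/opensfm/collection/20240317031008/images')
parser.add_argument('--input_w', type=int, default=2208)
parser.add_argument('--input_h', type=int, default=1242)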
args = parser.parse_args()

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((args.input_h, args.input_w), interpolation=InterpolationMode.BICUBIC),
])

class RoadSegDataset(Dataset):
    """Road Segmentation dataset."""

    def __init__(self, root, train, transform=None):
        self.root = root
        self.root_files = np.asarray(sorted(os.listdir(root)))
        self.train = train
        self.transform = transform

    def __len__(self):
        return len(self.root_files)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root, self.root_files[idx])
        image = Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        return image

train_dataset = RoadSegDataset(
    root=args.dataset_path, 
    train=True,
    transform=transform
)

data = torch.zeros(args.batch_size, 3, args.input_h, args.input_w).cuda()

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
builder.max_batch_size = args.batch_size
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
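onnx_parser = trt.OnnxParser(network, logger)  # separate name from the argparse parser

with open(args.onnx, 'rb') as f:
    if not onnx_parser.parse(f.read()):
        # Report every parser error instead of building a partial network.
        for i in range(onnx_parser.num_errors):
            print(onnx_parser.get_error(i))
        raise SystemExit('Failed to parse the ONNX model.')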

profile = builder.create_optimization_profile()
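# Use the parsed network's own input name and static shape: the ONNX export
# named its input 'image' (not 'input'), and a static-shape input only accepts
# profile dimensions equal to that shape.
input_tensor = network.get_input(0)
profile.set_shape(input_tensor.name, input_tensor.shape, input_tensor.shape, input_tensor.shape)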

config = builder.create_builder_config()

config.max_workspace_size = args.max_workspace_size

if args.fp16:
    config.set_flag(trt.BuilderFlag.FP16)

if args.int8:
    config.set_flag(trt.BuilderFlag.INT8)
    config.int8_calibrator = DatasetCalibrator(data, train_dataset)

if args.dla_core is not None:
    config.default_device_type = trt.DeviceType.DLA
    config.DLA_core = args.dla_core

if args.gpu_fallback:
    config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
    
config.add_optimization_profile(profile)
config.set_calibration_profile(profile)

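engine = builder.build_serialized_network(network, config)
if engine is None:
    # build_serialized_network returns None on failure; stop here rather than
    # failing later inside f.write().
    raise SystemExit('TensorRT engine build failed; see the log above.')

if args.output is not None:
    with open(args.output, 'wb') as f:
        f.write(engine)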
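The defaults --input_w 2208 and --input_h 1242 describe the raw image, but the exported network's input is the preprocessed image. A small sketch that mirrors the resizing rule in convert_img (assuming the superpoint_aachen preprocessing values grayscale=True and resize_max=1024, which agree with the torch.Size([1, 1, 576, 1024]) printed during export; resized_size is a hypothetical helper) shows the shape TensorRT actually receives:

```python
def resized_size(width: int, height: int, resize_max: int) -> tuple:
    """Mirror convert_img: shrink so the longer side equals resize_max."""
    size = (width, height)
    if resize_max and max(size) > resize_max:
        scale = resize_max / max(size)
        size = tuple(int(round(x * scale)) for x in size)
    return size


w, h = resized_size(2208, 1242, 1024)
print((1, 1, h, w))  # grayscale NCHW input: (1, 1, 576, 1024)
```

When --int8 is used, the calibration tensor data (3 channels at 1242 x 2208) and the dataset transform would therefore need to produce single-channel 576 x 1024 images to match this model.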

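The TensorRT log shows two failures: the ONNX parser rejects the /MaxPool node ("at least 5 dimensions are required for input"), and the build then fails with "Network must have at least one output", so build_serialized_network returns None and f.write(engine) raises a TypeError. These errors might be due to TensorRT not supporting certain operations used in the model. Could you provide a solution for this?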

[04/26/2024-21:38:00] [TRT] [I] [MemUsageChange] Init CUDA: CPU +2, GPU +0, now: CPU 1133, GPU 615 (MiB)
[04/26/2024-21:38:06] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +1444, GPU +270, now: CPU 2653, GPU 885 (MiB)
/home/plac/Hierarchical-Localization/onnx2tensorRT.py:70: DeprecationWarning: Use network created with NetworkDefinitionCreationFlag::EXPLICIT_BATCH flag instead.
  builder.max_batch_size = args.batch_size
[04/26/2024-21:38:07] [TRT] [W] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[04/26/2024-21:38:07] [TRT] [W] onnx2trt_utils.cpp:400: One or more weights outside the range of INT32 was clamped
[04/26/2024-21:38:07] [TRT] [E] /MaxPool: at least 5 dimensions are required for input.
[04/26/2024-21:38:07] [TRT] [E] [graphShapeAnalyzer.cpp::needTypeAndDimensions::2212] Error Code 4: Internal Error (/MaxPool: output shape can not be computed)
/home/plac/Hierarchical-Localization/onnx2tensorRT.py:87: DeprecationWarning: Use set_memory_pool_limit instead.
  config.max_workspace_size = args.max_workspace_size
[04/26/2024-21:38:07] [TRT] [E] 4: [network.cpp::validate::2882] Error Code 4: Internal Error (Network must have at least one output)
Traceback (most recent call last):
  File "/home/plac/Hierarchical-Localization/onnx2tensorRT.py", line 110, in <module>
    f.write(engine)
TypeError: a bytes-like object is required, not 'NoneType'
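A likely cause, assuming the /MaxPool node comes from simple_nms in third_party/SuperGluePretrainedNetwork/models/superpoint.py: that function calls max_pool2d on the (B, H, W) score map, so the exported MaxPool sees a 3-D input. The sketch below is a possible workaround, not a verified fix: it adds a channel axis around each pooling call so pooling runs on (B, 1, H, W). The rest mirrors the structure of SuperGlue's simple_nms as recalled here (compare it against the local file); the helper max_pool_4d is hypothetical.

```python
import torch
import torch.nn.functional as F


def max_pool_4d(scores: torch.Tensor, nms_radius: int) -> torch.Tensor:
    """Max-pool a (B, H, W) map by temporarily adding a channel axis."""
    pooled = F.max_pool2d(
        scores.unsqueeze(1),  # (B, H, W) -> (B, 1, H, W)
        kernel_size=nms_radius * 2 + 1,
        stride=1,
        padding=nms_radius,
    )
    return pooled.squeeze(1)  # (B, 1, H, W) -> (B, H, W)


def simple_nms(scores: torch.Tensor, nms_radius: int) -> torch.Tensor:
    """Non-maximum suppression on a (B, H, W) score map using 4-D pooling."""
    assert nms_radius >= 0

    def max_pool(x):
        return max_pool_4d(x, nms_radius)

    zeros = torch.zeros_like(scores)
    max_mask = scores == max_pool(scores)
    for _ in range(2):
        # Suppress the neighbourhoods of current maxima, then look for new ones.
        supp_mask = max_pool(max_mask.float()) > 0
        supp_scores = torch.where(supp_mask, zeros, scores)
        new_max_mask = supp_scores == max_pool(supp_scores)
        max_mask = max_mask | (new_max_mask & (~supp_mask))
    return torch.where(max_mask, scores, zeros)
```

Even if this removes the MaxPool error, the tracer warnings above (iteration over keypoints, data-dependent indexing) indicate that the keypoint-selection part of the graph is specialized to the example input, so that part may also need to be moved out of the exported model or rewritten.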

Here is my environment:

Ubuntu 22.04
Python 3.10.13
pycolmap 0.6.1
onnx 1.15.0
torch 1.13.1
CUDA 11.7
TensorRT 8.6.1
hloc 0.6.1

Thank you!
