🐛 [Bug] TensorRT engine exceptions are not raised #2367

Open
ralbertazzi opened this issue Oct 6, 2023 · 2 comments · May be fixed by #2709
Labels
bug Something isn't working No Activity

Comments

@ralbertazzi

Bug Description

I would expect torch to raise an exception when inference fails for any reason, such as a wrong input tensor shape or dtype. Instead, an error message is printed to the console but the program continues as if the call had succeeded. This can have serious implications in production environments.

To Reproduce

I have a model compiled for float16 that accepts a static input shape of (1, 3, 538, 538):

import torch
import torch_tensorrt

model = torch.jit.load("model.ts")
input_tensor = torch.zeros((1, 3, 538, 538), dtype=torch.float16, device="cuda")
output_tensor = model(input_tensor)

This is what happens if I pass a wrong shape

>>> out = model(torch.zeros((1, 3, 500, 500), dtype=torch.float16, device="cuda"))
ERROR: [Torch-TensorRT] - 3: [executionContext.cpp::setInputShape::2020] Error Code 3: API Usage Error (Parameter check failed at: runtime/api/executionContext.cpp::setInputShape::2020, condition: engineDims.d[i] == dims.d[i]. Static dimension mismatch while setting input shape.
)
>>> # NO EXCEPTION IS RAISED

This is what happens if I pass a wrong dtype

>>> out = model(torch.zeros((1, 3, 538, 538), dtype=torch.float32, device="cuda"))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/wizard/mambaforge/envs/remini/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/<model>.py", line 8, in forward
    input_0: Tensor) -> Tensor:
    __torch___<model>_trt_engine_ = self_1.__torch___<model>_trt_engine_
    _0 = ops.tensorrt.execute_engine([input_0], __torch___<model>_trt_engine_)
         ~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    _1, = _0
    return _1

Traceback of TorchScript, original code (most recent call last):
RuntimeError: [Error thrown at core/runtime/execute_engine.cpp:136] Expected inputs[i].dtype() == expected_type to be true but got false
Expected input tensors to have type Half, found type float


>>> # NO EXCEPTION IS RAISED

Expected behavior

An exception should be raised if the TensorRT Engine returns an error.
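
Until the runtime surfaces engine errors itself, a caller-side guard is one possible stopgap. The following is only a minimal sketch, not part of Torch-TensorRT: `check_engine_input` and `InputMismatchError` are hypothetical names, and the expected shape and dtype are assumed to be known by the caller (here, the static shape and float16 from the reproduction above). It relies only on the tensor exposing `.shape` and `.dtype`, as `torch.Tensor` does.

```python
from typing import Any, Sequence


class InputMismatchError(ValueError):
    """Hypothetical error type: the tensor does not match what the engine expects."""


def check_engine_input(tensor: Any, expected_shape: Sequence[int], expected_dtype: Any) -> None:
    """Raise before a mismatched tensor reaches the compiled engine.

    Works with any object exposing `.shape` and `.dtype` (e.g. torch.Tensor).
    The expected values are supplied by the caller; nothing is read from the engine.
    """
    actual_shape = tuple(tensor.shape)
    if actual_shape != tuple(expected_shape):
        raise InputMismatchError(
            f"expected shape {tuple(expected_shape)}, got {actual_shape}"
        )
    if tensor.dtype != expected_dtype:
        raise InputMismatchError(
            f"expected dtype {expected_dtype}, got {tensor.dtype}"
        )
```

With the model from the reproduction, this would be called as `check_engine_input(input_tensor, (1, 3, 538, 538), torch.float16)` right before `model(input_tensor)`. It only covers shape and dtype mismatches known in advance, so it does not replace proper error propagation from the engine.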

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 1.4.0
  • PyTorch Version (e.g. 1.0): 2.0.1
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip (custom whl)
  • Build command you used (if compiling from source): -
  • Are you using local sources or building from archives: -
  • Python version: 3.9
  • CUDA version: 11.8
  • GPU models and configuration: Tesla T4 / Tesla L4
  • Any other relevant information: tensorrt 8.5.3.1
@ralbertazzi added the bug (Something isn't working) label Oct 6, 2023
@ralbertazzi changed the title from "🐛 [Bug] Encountered bug when using Torch-TensorRT" to "🐛 [Bug] TensorRT engine exceptions are not raised" Oct 6, 2023

github-actions bot commented Jan 5, 2024

This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.

@ralbertazzi

Here's a comment to keep the issue alive.

@gcuendet gcuendet linked a pull request Mar 25, 2024 that will close this issue