🐛 [Bug] torch.export.load fails to load dynamic-shaped model #2792

Open · HolyWu opened this issue Apr 29, 2024 · 0 comments
Labels: bug (Something isn't working)

HolyWu (Contributor) commented Apr 29, 2024

Bug Description

Saving and reloading works fine with a static-shaped model, but torch.export.load fails with a dynamic-shaped one.

DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%arg0_1,), kwargs = {})
    return (relu,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%arg0_1,), kwargs = {})
    return (relu,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%arg0_1,), kwargs = {})
    return (relu,)
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_long_and_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False)

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.relu.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.relu.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Submodule name: _run_on_acc_0
 Input shapes: [{'min_shape': (1, 3, 1, 1), 'opt_shape': (1, 3, 16, 16), 'max_shape': (1, 3, 128, 128)}]
 graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%arg0_1,), kwargs = {})
    return relu
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +4, GPU +0, now: CPU 12512, GPU 1009 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +2651, GPU +308, now: CPU 15387, GPU 1317 (MiB)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Graph to be compiled to TensorRT: graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%arg0_1,), kwargs = {})
    return relu
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: arg0_1 [shape=[1, 3, -1, -1], dtype=DataType.FLOAT]
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /m/relu (kind: aten.relu.default, args: ('arg0_1 <tensorrt.ITensor [shape=(1, 3, -1, -1), dtype=DataType.FLOAT]>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(1, 3, -1, -1), dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.000980
INFO:torch_tensorrt [TensorRT Conversion Context]:Global timing cache in use. Profiling results in this builder pass will be stored.
INFO:torch_tensorrt [TensorRT Conversion Context]:Detected 1 inputs and 1 output network tensors.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Host Persistent Memory: 256
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Device Persistent Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Scratch Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Started assigning block shifts. This will take 2 steps to complete.
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Algorithm ShiftNTopDown took 0.0798ms to assign 2 blocks to 2 nodes requiring 1024 bytes.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Activation Memory: 1024
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Weights Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Engine generation completed in 1.34254 seconds.
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 4 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 3558 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 26 bytes of code generator cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 16 timing cache entries
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:01.343781
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 16644 bytes of Memory
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 1 Total Operators, of which 1 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_long_and_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False)

  Graph Structure:

   Inputs: List[Tensor: {'min_shape': (1, 3, 1, 1), 'opt_shape': (1, 3, 16, 16), 'max_shape': (1, 3, 128, 128)}@float32]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: {'min_shape': (1, 3, 1, 1), 'opt_shape': (1, 3, 16, 16), 'max_shape': (1, 3, 128, 128)}@float32]
     Number of Operators in Engine: 1
     Engine Outputs: Tensor: (1, 3, 16, 16)@float32
    ...
   Outputs: List[Tensor: (1, 3, 16, 16)@float32]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 1.0
   Most Operators in a TRT Engine: 1

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s)
WARNING: [Torch-TensorRT] - Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
WARNING: [Torch-TensorRT] - Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
WARNING:py.warnings:C:\Python311\Lib\site-packages\torch_tensorrt\dynamo\_exporter.py:364: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  engine_node = gm.graph.get_attr(engine_name)

WARNING:py.warnings:C:\Python311\Lib\site-packages\torch\fx\graph.py:1460: UserWarning: Node _run_on_acc_0_engine target _run_on_acc_0_engine _run_on_acc_0_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

Traceback (most recent call last):
  File "C:\Users\HolyWu\Downloads\test.py", line 36, in <module>
    loaded_model = torch.export.load("trt.ep").module()
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\export\__init__.py", line 299, in load
    return load(
           ^^^^^
  File "C:\Python311\Lib\site-packages\torch\_export\__init__.py", line 304, in load
    ep = deserialize(artifact, expected_opset_version)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\_export\serde\serialize.py", line 1999, in deserialize
    .deserialize(
     ^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\_export\serde\serialize.py", line 1829, in deserialize
    .deserialize(
     ^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\_export\serde\serialize.py", line 1502, in deserialize
    self.deserialize_graph(serialized_graph_module.graph)
  File "C:\Python311\Lib\site-packages\torch\_export\serde\serialize.py", line 1301, in deserialize_graph
    meta_val = self.deserialize_tensor_meta(tensor_value)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\_export\serde\serialize.py", line 1274, in deserialize_tensor_meta
    torch.empty_strided(
  File "C:\Python311\Lib\site-packages\torch\utils\_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\_subclasses\fake_tensor.py", line 896, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\_subclasses\fake_tensor.py", line 1241, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\_subclasses\fake_tensor.py", line 974, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\_subclasses\fake_tensor.py", line 1431, in _dispatch_impl
    op_impl_out = op_impl(self, func, *args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\_subclasses\fake_impls.py", line 179, in constructors
    r = func(*args, **new_kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\_ops.py", line 594, in __call__
    return self_._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\fx\experimental\sym_node.py", line 400, in expect_size
    r = b.expect_true(file, line)
        ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\fx\experimental\sym_node.py", line 386, in expect_true
    return self.guard_bool(file, line)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\fx\experimental\sym_node.py", line 374, in guard_bool
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\fx\experimental\recording.py", line 231, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\fx\experimental\symbolic_shapes.py", line 4115, in evaluate_expr
    static_expr = self._maybe_evaluate_static(expr,
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\fx\experimental\symbolic_shapes.py", line 1144, in wrapper
    return fn_cache(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\fx\experimental\symbolic_shapes.py", line 3526, in _maybe_evaluate_static
    vr = self.var_to_range[k]
         ~~~~~~~~~~~~~~~~~^^^
KeyError: s0
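
Per the traceback, the failure happens while deserializing tensor metadata: deserialize_tensor_meta rebuilds each tensor's meta value with torch.empty_strided, and evaluating the symbolic size s0 consults the ShapeEnv's var_to_range, which has no entry for s0. For comparison, a plain torch.export round trip of the same module with dynamic spatial dimensions is expected to load back cleanly, which would point at the symbol bookkeeping in Torch-TensorRT's exporter rather than at torch.export itself. A minimal control sketch (the Dim names h and w are arbitrary; ranges omitted for simplicity):

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.m = nn.ReLU()

    def forward(self, x):
        return self.m(x)


# Mark the two spatial dimensions of x as dynamic.
dim_h = torch.export.Dim("h")
dim_w = torch.export.Dim("w")
ep = torch.export.export(
    MyModule().eval(),
    (torch.randn(1, 3, 8, 8),),
    dynamic_shapes={"x": {2: dim_h, 3: dim_w}},
)
torch.export.save(ep, "plain.ep")
torch.export.load("plain.ep").module()  # loads back without KeyError: s0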

To Reproduce

import torch
import torch.nn as nn
import torch_tensorrt


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.m = nn.ReLU()

    def forward(self, x):
        return self.m(x)


with torch.inference_mode():
    model = MyModule().eval().cuda()
    inputs = [
        torch_tensorrt.Input(
            min_shape=(1, 3, 1, 1), opt_shape=(1, 3, 16, 16), max_shape=(1, 3, 128, 128), dtype=torch.float, name="x"
        )
    ]

    optimized_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        enabled_precisions={torch.float},
        debug=True,
        min_block_size=1,
    )

    x = torch.randn((1, 3, 8, 8), dtype=torch.float, device="cuda")
    optimized_model(x)  # inference on the compiled module works

    torch_tensorrt.save(optimized_model, "trt.ep", inputs=[x])
    loaded_model = torch.export.load("trt.ep").module()  # raises KeyError: s0
    loaded_model(x)
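
Until this is fixed, a possible interim workaround is to serialize the compiled module through TorchScript instead of the ExportedProgram format, which sidesteps the failing torch.export deserialization path. A sketch continuing the script above, assuming torch_tensorrt.save accepts output_format="torchscript" as in recent releases:

    # Hypothetical workaround: persist as TorchScript so loading goes through
    # torch.jit.load instead of torch.export.load.
    torch_tensorrt.save(optimized_model, "trt.ts", output_format="torchscript", inputs=[x])
    loaded_ts = torch.jit.load("trt.ts").cuda()
    loaded_ts(x)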

Environment

  • Torch-TensorRT Version: 1a4ffe4
  • PyTorch Version: 2.3.0+cu121
  • CPU Architecture: x64
  • OS: Windows 11
  • How you installed PyTorch: pip
  • Python version: 3.11.9
  • CUDA version: 12.4
  • GPU models and configuration: GeForce RTX 3050