
🐛 [Bug] torch.export.load fails to load file saved with torch_tensorrt.save in PyTorch 2.4.0.dev #2791

Closed
HolyWu opened this issue Apr 28, 2024 · 3 comments
Labels
bug Something isn't working

HolyWu (Contributor) commented Apr 28, 2024

Bug Description

torch.export.load raises an AssertionError in deserialize_torch_artifact when loading a file that was saved with torch_tensorrt.save under PyTorch 2.4.0.dev. Full debug output and traceback:

DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %x : [num_users=1] = placeholder[target=x]
    %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%x, [2, 3]), kwargs = {})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %view, %permute), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%addmm, [1, 1, 2, 3]), kwargs = {})
    return (view_1,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %x : [num_users=1] = placeholder[target=x]
    %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%x, [2, 3]), kwargs = {})
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %view, %_frozen_param0), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%addmm, [1, 1, 2, 3]), kwargs = {})
    return (view_1,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.view_to_reshape:Graph after replacing view with reshape:
graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %x : [num_users=1] = placeholder[target=x]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %reshape_default : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%x, [2, 3]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %reshape_default, %_frozen_param0), kwargs = {})
    %reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%addmm, [1, 1, 2, 3]), kwargs = {})
    return (reshape_default_1,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %x : [num_users=1] = placeholder[target=x]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %reshape_default : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%x, [2, 3]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %reshape_default, %_frozen_param0), kwargs = {})
    %reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%addmm, [1, 1, 2, 3]), kwargs = {})
    return (reshape_default_1,)
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_long_and_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False)

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.reshape.default + Operator Count: 2
- torch.ops.aten.addmm.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 3 operators out of 3 in subgraph.
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.reshape.default + Operator Count: 2
- torch.ops.aten.addmm.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Submodule name: _run_on_acc_0
 Input shapes: [(1, 1, 2, 3)]
 graph():
    %x : [num_users=1] = placeholder[target=x]
    %reshape_default : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%x, [2, 3]), kwargs = {})
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %reshape_default, %_frozen_param0), kwargs = {})
    %reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%addmm, [1, 1, 2, 3]), kwargs = {})
    return reshape_default_1
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +1, GPU +0, now: CPU 13575, GPU 1045 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +2666, GPU +308, now: CPU 16477, GPU 1353 (MiB)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Graph to be compiled to TensorRT: graph():
    %x : [num_users=1] = placeholder[target=x]
    %reshape_default : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%x, [2, 3]), kwargs = {})
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %reshape_default, %_frozen_param0), kwargs = {})
    %reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%addmm, [1, 1, 2, 3]), kwargs = {})
    return reshape_default_1
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[1, 1, 2, 3], dtype=DataType.HALF]
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /linear/reshape_default (kind: aten.reshape.default, args: ('x <tensorrt.ITensor [shape=(1, 1, 2, 3), dtype=DataType.HALF]>', [2, 3]))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /linear/addmm (kind: aten.addmm.default, args: ('<torch.Tensor as np.ndarray [shape=(3,), dtype=float16]>', '[SHUFFLE]-[aten_ops.reshape.default]-[/linear/reshape_default]_output <tensorrt.ITensor [shape=(2, 3), dtype=DataType.HALF]>', '<torch.Tensor as np.ndarray [shape=(3, 3), dtype=float16]>'))
DEBUG:torch_tensorrt.dynamo.conversion.converter_utils:Freezing tensor /linear/addmm_constant_0 to TRT IConstantLayer
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /linear/reshape_default_1 (kind: aten.reshape.default, args: ('[ELEMENTWISE]-[aten_ops.addmm.default]-[/linear/addmm_add]_output_addmm.default <tensorrt.ITensor [shape=(2, 3), dtype=DataType.HALF]>', [1, 1, 2, 3]))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(1, 1, 2, 3), dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.003905
INFO:torch_tensorrt [TensorRT Conversion Context]:Global timing cache in use. Profiling results in this builder pass will be stored.
INFO:torch_tensorrt [TensorRT Conversion Context]:Detected 1 inputs and 1 output network tensors.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Host Persistent Memory: 7648
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Device Persistent Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Scratch Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Started assigning block shifts. This will take 4 steps to complete.
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Algorithm ShiftNTopDown took 0.085ms to assign 4 blocks to 4 nodes requiring 2048 bytes.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Activation Memory: 1024
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Weights Memory: 274
INFO:torch_tensorrt [TensorRT Conversion Context]:Engine generation completed in 2.42034 seconds.
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 1 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 3798 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 26 bytes of code generator cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 130 timing cache entries
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:02.423098
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 118436 bytes of Memory
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 3 Total Operators, of which 3 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_long_and_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False)

  Graph Structure:

   Inputs: Tuple(Tensor: (1, 1, 2, 3)@float16)
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (1, 1, 2, 3)@float16]
     Number of Operators in Engine: 3
     Engine Outputs: Tensor: (1, 1, 2, 3)@float16
    ...
   Outputs: List[Tensor: (1, 1, 2, 3)@float16]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 3.0
   Most Operators in a TRT Engine: 3

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=3 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=3 which generates 1 TRT engine(s)

model(*inputs):
tensor([[[[ 0.4319, -0.1377,  1.3047],
          [ 0.4023,  0.2119, -0.0819]]]], device='cuda:0', dtype=torch.float16)
loaded_eager(*inputs):
tensor([[[[ 0.4319, -0.1377,  1.3047],
          [ 0.4023,  0.2119, -0.0819]]]], device='cuda:0', dtype=torch.float16)
WARNING: [Torch-TensorRT] - Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
optimized_model(*inputs):
tensor([[[[ 0.4319, -0.1377,  1.3047],
          [ 0.4023,  0.2119, -0.0819]]]], device='cuda:0', dtype=torch.float16)

WARNING: [Torch-TensorRT] - Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
WARNING:py.warnings:C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\_exporter.py:364: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  engine_node = gm.graph.get_attr(engine_name)

WARNING:py.warnings:C:\Python312\Lib\site-packages\torch\fx\graph.py:1536: UserWarning: Node _run_on_acc_0_engine target _run_on_acc_0_engine _run_on_acc_0_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

Traceback (most recent call last):
  File "C:\Users\HolyWu\Downloads\test.py", line 38, in <module>
    loaded_trt = torch.export.load("trt.ep").module()
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\export\__init__.py", line 299, in load
    return load(
           ^^^^^
  File "C:\Python312\Lib\site-packages\torch\_export\__init__.py", line 315, in load
    ep = deserialize(artifact, expected_opset_version)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_export\serde\serialize.py", line 2388, in deserialize
    .deserialize(
     ^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_export\serde\serialize.py", line 2213, in deserialize
    .deserialize(
     ^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_export\serde\serialize.py", line 1811, in deserialize
    self.example_inputs = deserialize_torch_artifact(example_inputs)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_export\serde\serialize.py", line 314, in deserialize_torch_artifact
    assert isinstance(artifact, (tuple, dict))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
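
The assertion fires inside deserialize_torch_artifact because the example_inputs payload stored in the archive does not deserialize to the tuple/dict that torch.export.load expects. To see which payloads actually got written, a minimal diagnostic sketch (the archive layout is an internal detail of torch.export.save and may differ between PyTorch versions) is to list the entries of the saved .ep file:

import zipfile

# The .ep file written by torch.export.save / torch_tensorrt.save is a zip
# archive; listing its entries shows which payloads (e.g. example inputs)
# were written and how large they are. Entry names are an implementation
# detail and may change between PyTorch versions.
with zipfile.ZipFile("trt.ep") as archive:
    for info in archive.infolist():
        print(f"{info.filename}: {info.file_size} bytes")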

To Reproduce

import torch
import torch.nn as nn
import torch_tensorrt


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)


with torch.inference_mode():
    model = MyModule().eval().cuda().half()
    inputs = (torch.randn((1, 1, 2, 3), dtype=torch.half, device="cuda"),)

    torch.export.save(torch.export.export(model, inputs), "eager.ep")
    loaded_eager = torch.export.load("eager.ep").module()

    optimized_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        enabled_precisions={torch.half},
        debug=True,
        min_block_size=1,
    )

    print("")
    print(f"model(*inputs):\n{model(*inputs)}")
    print(f"loaded_eager(*inputs):\n{loaded_eager(*inputs)}")
    print(f"optimized_model(*inputs):\n{optimized_model(*inputs)}")
    print("")

    torch_tensorrt.save(optimized_model, "trt.ep", output_format="exported_program", inputs=inputs)
    loaded_trt = torch.export.load("trt.ep").module()
    print(f"loaded_trt(*inputs):\n{loaded_trt(*inputs)}")

Environment

  • Torch-TensorRT Version (e.g. 1.0.0): 1a4ffe4
  • PyTorch Version (e.g. 1.0): 2.4.0.dev20240427+cu121
  • CPU Architecture: x64
  • OS (e.g., Linux): Windows 11
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.12
  • CUDA version: 12.4
  • GPU models and configuration: GeForce RTX 3050
  • Any other relevant information:
HolyWu added the bug (Something isn't working) label on Apr 28, 2024
HolyWu (Contributor, Author) commented Apr 28, 2024

Hmm... it works when building against PyTorch 2.3.0, so it's currently broken only with PyTorch 2.4.0.dev.
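
Until a fixed build is available, a minimal fallback sketch against the repro above (it keeps using the in-memory compiled module whenever deserialization fails; the except clause matches the AssertionError in the traceback):

# Hedged fallback sketch: attempt the round trip, but keep the in-memory
# compiled module on torch builds where deserialization still fails with
# the AssertionError shown above.
try:
    loaded_trt = torch.export.load("trt.ep").module()
except AssertionError:
    print("torch.export.load failed on this torch build; falling back to the in-memory module")
    loaded_trt = optimized_model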

HolyWu changed the title from "🐛 [Bug] torch.export.load fails to load file saved with torch_tensorrt.save" to "🐛 [Bug] torch.export.load fails to load file saved with torch_tensorrt.save in PyTorch 2.4.0.dev" on Apr 29, 2024
peri044 (Collaborator) commented Apr 29, 2024

Yes, this is a known issue and we have brought it up with PyTorch. Investigation is ongoing.

HolyWu (Contributor, Author) commented May 16, 2024

Looks like it's already fixed in the latest PyTorch nightly build. Maybe you can re-enable deserialization in test_export_serde.

HolyWu closed this as completed on May 16, 2024