
🐛 [Bug] SpecViolationError: Node.meta _to_copy_default is missing val field when using unsqueeze #2799

Open
airalcorn2 opened this issue Apr 30, 2024 · 2 comments
Labels
bug Something isn't working

@airalcorn2

Bug Description

The code below produces the following error:

SpecViolationError: Node.meta _to_copy_default is missing val field.

The same code works fine with Torch-TensorRT 2.1.0-rc9 and PyTorch 2.1.2.
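For context, this error appears to come from torch.export's graph verifier, which expects every node's meta to carry a val entry (FakeTensor metadata describing the node's output). A minimal sketch of what that metadata looks like, using the standard torch.export API and a hypothetical stand-in module:

import torch


class Tiny(torch.nn.Module):
    # Hypothetical stand-in for the repro model below.
    def forward(self, x):
        return x.unsqueeze(-1)


# torch.export attaches a 'val' (FakeTensor) entry to each node's meta;
# the verifier raises SpecViolationError when that entry is missing.
ep = torch.export.export(Tiny(), (torch.randn(1, 300),))
for node in ep.graph.nodes:
    print(node.op, node.target, "val" in node.meta)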

To Reproduce

import torch
import torch_tensorrt


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, topk_ind):
        gather_index = topk_ind.unsqueeze(-1)
        return gather_index


def main():
    model = Model().cuda()
    model.eval()
    enc_outputs_class = torch.randn(1, 8400, 80).cuda()
    num_queries = 300
    # torch.topk returns int64 indices, hence the int64 Input dtype below.
    _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, num_queries, dim=1)

    inputs = [
        torch_tensorrt.Input(topk_ind.shape, dtype=torch.int64),
    ]
    enabled_precisions = {torch.half, torch.float32}
    trt_model = torch_tensorrt.compile(
        model,
        inputs=inputs,
        enabled_precisions=enabled_precisions,
        truncate_long_and_double=True,
        min_block_size=1,
    )


if __name__ == "__main__":
    main()

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.2.0
  • PyTorch Version (e.g. 1.0): 2.2.0
  • CPU Architecture: i7-12800H
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.10.12
  • CUDA version: 12.2
  • GPU models and configuration: GeForce RTX 3080 Ti
  • Any other relevant information:
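
As the template notes, build information is printed when debug messages are turned on. A minimal sketch, assuming the torch_tensorrt.logging context managers available in the 2.x Python API:

import torch_tensorrt

# Assumption: torch_tensorrt.logging.debug() acts as a context manager that
# raises the log verbosity for everything compiled inside the block.
with torch_tensorrt.logging.debug():
    trt_model = torch_tensorrt.compile(
        model,
        inputs=inputs,
        enabled_precisions=enabled_precisions,
        truncate_long_and_double=True,
        min_block_size=1,
    )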

Additional context

airalcorn2 added the bug (Something isn't working) label on Apr 30, 2024
@airalcorn2 (Author)

Notably, the below code works:

import torch
import torch_tensorrt


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, topk_ind):
        gather_index = topk_ind.unsqueeze(-1)
        return gather_index


def main():
    model = Model().cuda()
    model.eval()
    topk_ind = torch.randint(8400, size=(1, 300)).cuda()
    inputs = [
        # Declaring the input as int32 (rather than topk's int64) compiles fine.
        torch_tensorrt.Input(topk_ind.shape, dtype=torch.int32),
    ]
    enabled_precisions = {torch.half, torch.float32}
    trt_model = torch_tensorrt.compile(
        model,
        inputs=inputs,
        enabled_precisions=enabled_precisions,
        truncate_long_and_double=True,
        min_block_size=1,
    )


if __name__ == "__main__":
    main()

which suggests the unsqueeze itself isn't the problem. The original code I was testing, which raised the error, was:

import torch
import torch_tensorrt


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.num_queries = 300

    def forward(self, enc_outputs_class):
        # topk runs inside the graph here, so its int64 indices feed unsqueeze.
        _, topk_ind = torch.topk(
            enc_outputs_class.max(-1).values, self.num_queries, dim=1
        )
        gather_index = topk_ind.unsqueeze(-1)
        return gather_index


def main():
    model = Model().cuda()
    model.eval()
    enc_outputs_class = torch.randn(1, 8400, 80).cuda()
    inputs = [
        torch_tensorrt.Input(enc_outputs_class.shape),
    ]
    enabled_precisions = {torch.half, torch.float32}
    trt_model = torch_tensorrt.compile(
        model,
        inputs=inputs,
        enabled_precisions=enabled_precisions,
        truncate_long_and_double=True,
        min_block_size=1,
    )


if __name__ == "__main__":
    main()

so it seems to have something to do with topk; the failing node name, _to_copy_default, hints that a dtype-cast node inserted for the int64 indices is the one missing its val metadata. A possible workaround is sketched below.
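
Given that the int32 case above compiles, one untested workaround might be to downcast the topk indices before the unsqueeze. This is only a sketch based on that observation, not a confirmed fix:

import torch


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.num_queries = 300

    def forward(self, enc_outputs_class):
        _, topk_ind = torch.topk(
            enc_outputs_class.max(-1).values, self.num_queries, dim=1
        )
        # Hypothetical workaround: cast the int64 indices to int32 before
        # unsqueeze, mirroring the int32 input case that compiles cleanly.
        gather_index = topk_ind.to(torch.int32).unsqueeze(-1)
        return gather_index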

@airalcorn2 (Author)

Adding output_format="torchscript" seems to solve the issue:

import torch
import torch_tensorrt


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.num_queries = 300

    def forward(self, enc_outputs_class):
        _, topk_ind = torch.topk(
            enc_outputs_class.max(-1).values, self.num_queries, dim=1
        )
        gather_index = topk_ind.unsqueeze(-1)
        return gather_index


def main():
    model = Model().cuda()
    model.eval()
    enc_outputs_class = torch.randn(1, 8400, 80).cuda()
    inputs = [
        torch_tensorrt.Input(enc_outputs_class.shape),
    ]
    enabled_precisions = {torch.half, torch.float32}
    trt_model = torch_tensorrt.compile(
        model,
        inputs=inputs,
        enabled_precisions=enabled_precisions,
        truncate_long_and_double=True,
        min_block_size=1,
        # Returning TorchScript instead of an ExportedProgram avoids the error.
        output_format="torchscript",
    )


if __name__ == "__main__":
    main()
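
Presumably output_format="torchscript" sidesteps the failure because compile then returns a TorchScript module rather than re-exporting the result through torch.export, where the verifier error appears to originate. Assuming the returned object is a torch.jit module, it can be serialized the usual way:

import torch

# Assumption: with output_format="torchscript", compile returns a TorchScript
# module that torch.jit.save can serialize directly.
torch.jit.save(trt_model, "trt_model.ts")
loaded = torch.jit.load("trt_model.ts").cuda()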
