The code below produces the following error:
```
SpecViolationError: Node.meta _to_copy_default is missing val field.
```
The same code works fine with Torch-TensorRT 2.1.0-rc9 and PyTorch 2.1.2.
```python
import torch
import torch_tensorrt


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, topk_ind):
        gather_index = topk_ind.unsqueeze(-1)
        return gather_index


def main():
    model = Model().cuda()
    model.eval()

    enc_outputs_class = torch.randn(1, 8400, 80).cuda()
    num_queries = 300
    _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, num_queries, dim=1)

    inputs = [
        torch_tensorrt.Input(topk_ind.shape, dtype=torch.int64),
    ]
    enabled_precisions = {torch.half, torch.float32}
    trt_model = torch_tensorrt.compile(
        model,
        inputs=inputs,
        enabled_precisions=enabled_precisions,
        truncate_long_and_double=True,
        min_block_size=1,
    )


if __name__ == "__main__":
    main()
```
Notably, the below code works:
```python
import torch
import torch_tensorrt


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, topk_ind):
        gather_index = topk_ind.unsqueeze(-1)
        return gather_index


def main():
    model = Model().cuda()
    model.eval()

    topk_ind = torch.randint(8400, size=(1, 300)).cuda()

    inputs = [
        torch_tensorrt.Input(topk_ind.shape, dtype=torch.int32),
    ]
    enabled_precisions = {torch.half, torch.float32}
    trt_model = torch_tensorrt.compile(
        model,
        inputs=inputs,
        enabled_precisions=enabled_precisions,
        truncate_long_and_double=True,
        min_block_size=1,
    )


if __name__ == "__main__":
    main()
```
which suggests it's not the `unsqueeze` causing problems. The original code I was testing that raised the error was:
```python
import torch
import torch_tensorrt


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.num_queries = 300

    def forward(self, enc_outputs_class):
        _, topk_ind = torch.topk(
            enc_outputs_class.max(-1).values, self.num_queries, dim=1
        )
        gather_index = topk_ind.unsqueeze(-1)
        return gather_index


def main():
    model = Model().cuda()
    model.eval()

    enc_outputs_class = torch.randn(1, 8400, 80).cuda()

    inputs = [
        torch_tensorrt.Input(enc_outputs_class.shape),
    ]
    enabled_precisions = {torch.half, torch.float32}
    trt_model = torch_tensorrt.compile(
        model,
        inputs=inputs,
        enabled_precisions=enabled_precisions,
        truncate_long_and_double=True,
        min_block_size=1,
    )


if __name__ == "__main__":
    main()
```
so it seems like it has something to do with `topk`.
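One possibly relevant detail (my own observation, not stated in the thread): `torch.topk` always returns its indices as `torch.int64`, so with `truncate_long_and_double=True` the lowering has to insert a downcast on that output — which would line up with the `_to_copy_default` node named in the error. The index dtype is easy to confirm on CPU without TensorRT:

```python
import torch

# torch.topk returns (values, indices); the indices are always torch.int64,
# which truncate_long_and_double=True would then have to cast down.
scores = torch.randn(1, 8400, 80)
_, topk_ind = torch.topk(scores.max(-1).values, 300, dim=1)
print(topk_ind.dtype)   # torch.int64
print(topk_ind.shape)   # torch.Size([1, 300])
```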
Adding `output_format="torchscript"` seems to solve the issue:
```python
import torch
import torch_tensorrt


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.num_queries = 300

    def forward(self, enc_outputs_class):
        _, topk_ind = torch.topk(
            enc_outputs_class.max(-1).values, self.num_queries, dim=1
        )
        gather_index = topk_ind.unsqueeze(-1)
        return gather_index


def main():
    model = Model().cuda()
    model.eval()

    enc_outputs_class = torch.randn(1, 8400, 80).cuda()

    inputs = [
        torch_tensorrt.Input(enc_outputs_class.shape),
    ]
    enabled_precisions = {torch.half, torch.float32}
    trt_model = torch_tensorrt.compile(
        model,
        inputs=inputs,
        enabled_precisions=enabled_precisions,
        truncate_long_and_double=True,
        min_block_size=1,
        output_format="torchscript",
    )


if __name__ == "__main__":
    main()
```
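A usage note on the workaround (my sketch, not from the thread): with `output_format="torchscript"` the object returned by `torch_tensorrt.compile` is a TorchScript module, so it can be round-tripped with `torch.jit.save` / `torch.jit.load`. The mechanics can be demonstrated with a plain scripted module as a stand-in for `trt_model`, since they don't depend on TensorRT being installed:

```python
import io

import torch


class Model(torch.nn.Module):
    def forward(self, topk_ind):
        return topk_ind.unsqueeze(-1)


# Stand-in for the compiled trt_model; saving/loading works the same way.
scripted = torch.jit.script(Model())

buf = io.BytesIO()  # a file path works here too
torch.jit.save(scripted, buf)
buf.seek(0)
loaded = torch.jit.load(buf)

out = loaded(torch.randint(8400, size=(1, 300)))
print(out.shape)  # torch.Size([1, 300, 1])
```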
Environment

- Torch-TensorRT installed via (conda, pip, libtorch, source): pip