
❓ [Question] How to specify that certain aten operators must be run by LibTorch in C++? #2830

Open
demuxin opened this issue May 13, 2024 · 7 comments
Labels: question (Further information is requested)

Comments

demuxin commented May 13, 2024

❓ Question

When I compile the SwinTransformer model using Torch-TensorRT, an error appears:

terminate called after throwing an instance of 'c10::Error'
  what():  0 INTERNAL ASSERT FAILED at "../torch/csrc/jit/ir/alias_analysis.cpp":615, please report a bug to PyTorch. We don't have an op for aten::floor_divide but it isn't a special case.  Argument types: int, int, 

Candidates:
        aten::floor_divide(Tensor self, Tensor other) -> Tensor
        aten::floor_divide.Scalar(Tensor self, Scalar other) -> Tensor
        aten::floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
        aten::floor_divide.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)

I checked out this link; the error occurs because Torch-TensorRT doesn't support the % op.

Fine, I can choose to run floor_divide with LibTorch instead:

torchtrt::ts::CompileSpec compile_settings({ input });
compile_settings.enabled_precisions.insert(build_type);
compile_settings.workspace_size = _1_GB;
compile_settings.truncate_long_and_double = true;
compile_settings.num_avg_timing_iters = 1;
compile_settings.torch_executed_ops.push_back("aten::floor_divide");  // here
torchtrt::ts::compile(model, compile_settings);

It's strange that this setting does not take effect; the error persists.

What can I do about this error?

Furthermore, how do I specify that certain aten operators must be run by LibTorch in C++?

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • PyTorch Version (e.g., 1.0): 2.2.1
  • CPU Architecture: x86
  • OS (e.g., Linux): Ubuntu 22.04
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version: 12.2
  • GPU models and configuration:
  • Any other relevant information:
demuxin added the question label May 13, 2024
demuxin (Author) commented May 15, 2024

I came up with a solution: I use the code below to replace the % op:

# Remainder via truncating division; for the non-negative operands used here
# this matches Python's % (the two differ for negative values).
def TakeRemainder(x: int, y: int) -> int:
    return x - y * int(x / y)

And it works.

I still want to know why this setting doesn't take effect:

compile_settings.torch_executed_ops.push_back("aten::floor_divide"); 

gs-olive (Collaborator) commented:

Hi - thanks for the report. I think this may be related to the following lowering pass, where it's possible that both inputs are upcasted integers, so we accidentally construct a schema which is no longer valid:

      case c10::aten::floor_divide:
        new_node = g->create(c10::aten::floordiv, user->inputs(), 1);
        new_node->insertAfter(user);
        new_node->outputs()[0]->setType(c10::IntType::get());
        user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
        user->destroy();
        break;

Regarding why compile_settings.torch_executed_ops.push_back("aten::floor_divide"); doesn't work - this is likely because the lowering pass puts the graph in an inconsistent or invalid state, so it doesn't have the opportunity to exclude conversion of floor_divide before failure, since the "lowering" phase happens prior to partitioning and conversion to TRT/Torch.
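To make the idea concrete, here is a minimal sketch of the kind of type-safety guard such a rewrite could carry. This is an illustration only, not the actual Torch-TensorRT patch; try_replace_with_floordiv is a hypothetical helper name, and whether this check would catch the exact failure above is an assumption:

#include <torch/csrc/jit/ir/ir.h>

// Hypothetical helper (illustration only): rewrite `user` (an aten::floor_divide
// node) into aten::floordiv, but only commit the rewrite when the new node
// matches a registered operator schema for its input types.
bool try_replace_with_floordiv(torch::jit::Graph* g, torch::jit::Node* user) {
  auto* new_node = g->create(c10::aten::floordiv, user->inputs(), 1);
  new_node->insertAfter(user);
  if (!new_node->maybeSchema()) {
    // No registered schema matches these input types; undo the rewrite and keep
    // the original node so later phases (e.g. partitioning) can decide what to do.
    new_node->destroy();
    return false;
  }
  new_node->outputs()[0]->setType(c10::IntType::get());
  user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
  user->destroy();
  return true;
}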

demuxin (Author) commented May 16, 2024

> Hi - thanks for the report. I think this may be related to the following lowering pass, where it's possible that both inputs are upcasted integers, so we accidentally construct a schema which is no longer valid:

So this is a bug, right? Will you fix this bug in the future?

gs-olive (Collaborator) commented:

Yes, this appears to be a bug, and we can work on a fix for it. Do you have a reproducer script or model we could use to recreate the error?

demuxin (Author) commented May 17, 2024

Here is the code:

torch::Device* device_ = new torch::Device(torch::DeviceType::CUDA);
device_->set_index(0);

torch::jit::script::Module model = torch::jit::load(model_path);
model.to("cuda");
model.eval();
model.to(torch::kHalf);

std::vector<int64_t> input_dim{1, 3, 832, 1440};
auto input = torchtrt::Input(input_dim, torchtrt::DataType::kHalf);

size_t _1_GB = 1 << 30;
torchtrt::ts::CompileSpec compile_settings({ input });
compile_settings.enabled_precisions.insert(torchtrt::DataType::kHalf);
compile_settings.workspace_size = _1_GB;
compile_settings.truncate_long_and_double = true;
compile_settings.num_avg_timing_iters = 1;
torchtrt::ts::compile(model, compile_settings);

Additionally, I have provided the model via Google Drive.

gs-olive (Collaborator) commented May 24, 2024

Hello - thanks for the details. I am unable to access the model at that link; is it available somewhere else? Also, could you provide the full debug log as well, using the following logging level: torchtrt::logging::set_reportable_log_level(torchtrt::logging::Level::kGRAPH)?
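For reference, here is a minimal sketch of where that call could be placed, mirroring the reproducer above. The torch_tensorrt/logging.h include path and the trt_mod name are assumptions on my part, not taken from the reproducer:

#include "torch_tensorrt/logging.h"   // assumed header for the C++ logging API

// Raise verbosity before compiling so the lowered and partitioned graphs are logged.
torchtrt::logging::set_reportable_log_level(torchtrt::logging::Level::kGRAPH);
auto trt_mod = torchtrt::ts::compile(model, compile_settings);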

demuxin (Author) commented May 24, 2024

I changed the access settings for the model; the link is accessible now.
