Autograd Function error with torch.script #485

Open
RobColeman opened this issue May 24, 2023 · 1 comment
Labels: enhancement (New feature or request)

Comments


RobColeman commented May 24, 2023

When trying to compile TabNet with TorchScript for portability in training, compilation fails in the Sparsemax and Entmax15 functions.

```
Could not export Python function call 'Entmax15Function'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
```

This is a known PyTorch limitation: TorchScript cannot compile calls to custom `torch.autograd.Function` subclasses, and the upstream issue appears to have been triaged with no plan to fix it:
pytorch/pytorch#22329

What is the current behavior?
Compiling the network to TorchScript fails.

If the current behavior is a bug, please provide the steps to reproduce.

```python
scripted_tabnet_network = torch.jit.script(tabnet_model.network)
```

or

```python
traced_script_module = torch.jit.trace(
    tabnet_model.network, input_features_example
)
```

Expected behavior
The network should compile with either `torch.jit.script` or `torch.jit.trace`.

Other relevant information:
tabnet version: 3.1.1 & 4.0
python version: 3.8+

Suggested fix

Include alternative implementations of Sparsemax and Entmax15 that do not use `torch.autograd.Function`.
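As a minimal sketch of that fix for Sparsemax only, assuming the goal is simply to make the module scriptable: the hypothetical `ScriptableSparsemax` below (not part of pytorch-tabnet) computes sparsemax with plain tensor ops (sort, cumulative sum, threshold, clamp) and lets autograd derive the gradient instead of a hand-written backward. Because the threshold is the sum of the supported entries minus 1, divided by the integer support size, autograd recovers the usual sparsemax Jacobian almost everywhere. Speed and memory use may differ from the custom-Function version.

```python
from typing import List

import torch
from torch import nn


class ScriptableSparsemax(nn.Module):
    """Hypothetical sparsemax written with plain tensor ops so that
    torch.jit.script can compile it; gradients come from autograd."""

    def __init__(self, dim: int = -1):
        super().__init__()
        self.dim = dim

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        dim = self.dim
        if dim < 0:
            dim = dim + z.dim()
        # Sparsemax is shift-invariant; subtracting the max aids stability.
        z = z - torch.max(z, dim=dim, keepdim=True)[0]
        z_sorted, _ = torch.sort(z, dim=dim, descending=True)
        cumsum = torch.cumsum(z_sorted, dim=dim)
        # k = 1..K, reshaped so it broadcasts along `dim`.
        shape: List[int] = []
        for i in range(z.dim()):
            shape.append(-1 if i == dim else 1)
        k = torch.arange(1, z.size(dim) + 1, dtype=z.dtype, device=z.device).view(shape)
        # Sorted entry k is in the support if 1 + k * z_(k) > sum_{j<=k} z_(j).
        support = (1 + k * z_sorted) > cumsum
        k_z = torch.sum(support, dim=dim, keepdim=True)
        # Threshold tau = (sum of supported entries - 1) / support size.
        tau = (torch.gather(cumsum, dim, k_z - 1) - 1) / k_z.to(z.dtype)
        return torch.clamp(z - tau, min=0.0)


if __name__ == "__main__":
    scripted = torch.jit.script(ScriptableSparsemax(dim=-1))
    print(scripted(torch.tensor([[1.0, 2.0, 3.0], [0.1, 0.2, 0.3]])))
```

An Entmax15 rewrite could follow the same pattern, but its sort-based threshold involves a square root whose gradient at zero may need guarding against NaNs, so it is not sketched here.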

RobColeman added the bug (Something isn't working) label on May 24, 2023
Collaborator

Optimox commented May 25, 2023

Hello @RobColeman,

Do you absolutely need TorchScript for your production environment? If Python is installed in production, you can install pytorch-tabnet and simply use its save/load framework.

Otherwise, please don't hesitate to open a PR with updated, scriptable versions of Entmax15 and Sparsemax.

Optimox added the enhancement (New feature or request) label and removed the bug (Something isn't working) label on May 25, 2023