
Save TabNet in ONNX format #277

Open · luigisaetta opened this issue Mar 28, 2021 · 20 comments
Labels: enhancement (New feature or request)

@luigisaetta

Feature request

Is it possible to save TabNet in ONNX format?

What is the expected behavior?

What is the motivation or use case for adding/changing the behavior?
ONNX is quickly becoming the de-facto standard for saving models, partly because this way you avoid importing packages when you package a model for inference.

How should this be implemented in your opinion?

Are you willing to work on this yourself?
Well, for now I don't have a precise idea, but I'm willing to help if I get some suggestions on where to start.

The feature could probably be implemented as a notebook example, so no changes to the core implementation would be needed.

luigisaetta added the enhancement (New feature or request) label on Mar 28, 2021
@Optimox (Collaborator) commented Mar 28, 2021

I'm not familiar with ONNX, but it would be quite easy to save the network as a traced script (from PyTorch JIT: https://pytorch.org/docs/stable/jit.html), which could be used for inference without the need for pytorch-tabnet, and even without Python itself (it can be called from C++).

Training the model and then saving model.network in eval mode should work without problems. I think ONNX would work the same way. But these production-ready requirements can be specific to each environment (some will have Python inside Docker with the library available for inference, some only C++; some will be interested in getting the explanations along with the predictions, while others will only care about predictions). So I feel we could only give some examples of how this would work, but since the network is accessible and is just a simple torch.nn.Module, I feel it is a bit beyond the scope of the library.

Feel free to open a PR giving examples for either ONNX or jit and I'll be happy to review (not sure about adding onnx as a dependency in the repo, however). If you have questions I might help with jit; I guess it would look something like this:

import torch

def save_torch_script(tab_model, X_infer_ex, saving_path, model_name):
    """
    Utility function to save a tabnet model as a torch script.

    Parameters
    ----------
    - tab_model : pytorch-tabnet model
        A trained model whose network will be saved
    - X_infer_ex : torch.Tensor of size (B, D) (EDIT: a tensor, not a 2D np.array)
        Batch containing B examples with D features
    - saving_path : str
        Path where the file is saved
    - model_name : str
        Name of the file to create, should not contain the extension

    Returns
    -------
    - traced_script_module : torch.jit.ScriptModule
        Traced model that has been saved
    """
    # Put the network in eval mode so dropout/batch-norm behave deterministically.
    tab_model.network.eval()
    # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
    traced_script_module = torch.jit.trace(tab_model.network, X_infer_ex)
    traced_script_module.save(saving_path + model_name + ".pt")
    return traced_script_module

I think that's it! (This will only trace the forward, without explanations; you'll need to create a wrapper with a custom forward function to get both preds and explanations, as sketched below.)
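
A minimal sketch of such a wrapper, assuming the underlying network exposes forward() and forward_masks() as pytorch_tabnet.tab_network.TabNet does (the class name TabNetWithExplain is hypothetical):

import torch

class TabNetWithExplain(torch.nn.Module):
    """Hypothetical wrapper whose forward returns predictions and masks."""

    def __init__(self, network):
        super().__init__()
        self.network = network

    def forward(self, x):
        # network(x) returns (logits, sparsity loss) in pytorch-tabnet
        out, _ = self.network(x)
        # forward_masks(x) returns (aggregated explanation mask, per-step masks)
        m_explain, _ = self.network.forward_masks(x)
        return out, m_explain

Tracing TabNetWithExplain(tab_model.network) with torch.jit.trace would then capture both outputs.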

@luigisaetta (Author)

Hi @Optimox, thanks for your quick answer.

My main use case is to be able to package the trained model in a REST service for predictions, in Python.

As far as I understand (I'm not a great expert on PyTorch), a TabNetClassifier is built around a torch.nn.Module, so as explained here:

https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

we should be able to export through tracing, using torch.onnx.export.

As soon as I have time, I'll follow up; I need to see what the parameters mean.

I agree with you that adding onnx to TabNet is not what should be done; I was thinking of adding an example notebook and best practices, along the lines of the sketch below.
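
A hedged sketch of that export, assuming clf is a trained TabNetClassifier and n_features is the number of features (the file name and input/output names are illustrative); note that, as reported later in this thread, the export can still fail on custom autograd functions such as SparsemaxFunction:

import torch

n_features = 54  # illustrative feature count
clf.network.eval()
# torch.onnx.export traces the module, so the example input must be a Tensor.
dummy_input = torch.randn(1, n_features, dtype=torch.float32)
torch.onnx.export(
    clf.network,
    dummy_input,
    "tabnet.onnx",
    input_names=["input"],
    output_names=["logits", "m_loss"],
    dynamic_axes={"input": {0: "batch"}},  # allow variable batch size
)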

@luigisaetta (Author) commented Mar 29, 2021

It seems more difficult than I expected. When I call:

torch.onnx.export(clf.network, dummy_input, "tabnet1", verbose=True)

I get the following error:

RuntimeError: Only tuples, lists, and Variables supported as JIT inputs/outputs. Dictionaries and strings are also accepted but their usage is not recommended. But got unsupported type numpy.ndarray

It seems that:

  • I cannot export if the input is a NumPy array,
  • but TabNetClassifier doesn't accept a Torch tensor as input.

What kind of data can I pass as input? Only a NumPy array? Strange enough, since under the hood it is PyTorch; it should be easy to accept a Tensor.
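
For the error itself, the example input passed to torch.onnx.export just needs to be a tensor; a minimal fix, with an illustrative stand-in for the inference data:

import numpy as np
import torch

# The example input for tracing must be a torch.Tensor, not a NumPy array.
# (A second blocker, NumPy calls inside the network, is discussed below.)
X_test = np.random.rand(10, 54).astype(np.float32)  # illustrative stand-in data
dummy_input = torch.from_numpy(X_test[:1])  # a 1-row batch as a torch.Tensor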

@Hartorn (Contributor) commented Mar 29, 2021

Hi,

I will try to see how this can work.
To be sure: are you trying to save & load in Python? Or save in Python and load in C++?

@luigisaetta (Author)

> Hi,
>
> I will try to see how this can work.
> To be sure: are you trying to save & load in Python? Or save in Python and load in C++?

Only Python. I want to export the trained model to ONNX so that I can (for example) develop a REST service without having to install pytorch-tabnet, using only the ONNX Runtime.
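
For reference, serving the exported file would then only need the onnxruntime package; a minimal sketch, assuming the model was exported as tabnet.onnx with an input named "input" (both names are illustrative):

import numpy as np
import onnxruntime as ort

n_features = 54  # illustrative, must match the exported model
session = ort.InferenceSession("tabnet.onnx")
x = np.random.rand(4, n_features).astype(np.float32)  # a batch of 4 rows
logits = session.run(None, {"input": x})[0]  # first output of the graph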

@Hartorn (Contributor) commented Mar 29, 2021

> Hi,
> I will try to see how this can work.
> To be sure: are you trying to save & load in Python? Or save in Python and load in C++?
>
> Only Python. I want to export the trained model to ONNX so that I can (for example) develop a REST service without having to install pytorch-tabnet, using only the ONNX Runtime.

For that, you can use the PyTorch save method, but I will come back with some tests using ONNX.
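
One pure-Python option (still requiring pytorch-tabnet at inference time) is the library's own zip-based persistence; a minimal sketch, assuming clf is a trained TabNetClassifier and X_test is a 2D NumPy array (illustrative name):

from pytorch_tabnet.tab_model import TabNetClassifier

# Save: writes ./tabnet_model.zip (weights plus model parameters).
saved_path = clf.save_model("./tabnet_model")

# Load into a fresh wrapper and predict as usual.
loaded = TabNetClassifier()
loaded.load_model(saved_path)
preds = loaded.predict(X_test)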

@Optimox (Collaborator) commented Mar 29, 2021

@luigisaetta I think trying to ONNXify (whatever this is called) the entire TabNetClassifier class is doomed to fail; I'll be very surprised if you manage to export everything with ONNX or jit (@Hartorn I know you are full of surprises :) ).
I think you should focus on exporting the network only, which is accessible via the network attribute.

TabNetClassifier does not take tensors as input, but the network does, so it's weird.

@Hartorn (Contributor) commented Mar 29, 2021

@Optimox You are right, only the network should be exported, but I also have to check what the input format is (only one input, or several for the embeddings, and so on).

I will see if we can set up some kind of optional dependency to have a custom exporter, or at least have a notebook covering this.
I should manage to get something working, based on:

https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html#
http://onnx.ai/sklearn-onnx/auto_examples/plot_custom_model.html#sphx-glr-auto-examples-plot-custom-model-py
http://onnx.ai/sklearn-onnx/auto_examples/plot_custom_parser_alternative.html#sphx-glr-auto-examples-plot-custom-parser-alternative-py
http://onnx.ai/sklearn-onnx/auto_examples/plot_pipeline_lightgbm.html#sphx-glr-auto-examples-plot-pipeline-lightgbm-py

Will try to have a look this week.

@luigisaetta (Author)

> @luigisaetta I think trying to ONNXify (whatever this is called) the entire TabNetClassifier class is doomed to fail; I'll be very surprised if you manage to export everything with ONNX or jit (@Hartorn I know you are full of surprises :) ).
> I think you should focus on exporting the network only, which is accessible via the network attribute.
>
> TabNetClassifier does not take tensors as input, but the network does, so it's weird.

Hi, as you can see from my comment above, I get an error when I apply torch.onnx.export to clf.network, so as far as I understand I am already trying to export only the network. But since it wants an example input... it doesn't accept a NumPy array.

@Optimox (Collaborator) commented Apr 2, 2021

@luigisaetta sorry, in the docstring example I wrote numpy array but it should be a torch.Tensor. Does it work with tensors?

@luigisaetta (Author)

> @luigisaetta sorry, in the docstring example I wrote numpy array but it should be a torch.Tensor. Does it work with tensors?

No, it doesn't. I'll find my test and post the error here. Basically, I think it calls a method that exists on NumPy arrays but not on Torch tensors.

@luigisaetta (Author)

@Optimox I have published an article on Towards Data Science: https://towardsdatascience.com/pytorch-tabnet-integration-with-mlflow-cb14f3920cb0
There, I talk about pytorch-tabnet and its integration with MLflow.

@Optimox (Collaborator) commented Apr 16, 2021

Hello @luigisaetta,

Great article, very detailed! I'm happy to see that integration with MLflow is made so easy by the callbacks. @queraq worked on this and I'm sure neither of us had this specific usage in mind.

There is one part of the article where I think a bit of clarification would be welcome: the Encoder-Decoder part. In fact, TabNet models are only a sort of encoder (plus all the sequential attention part); there is no decoder at all. The only reason a decoder part exists in the code is to enable self-supervised pre-training, which needs one. Since you do not mention pre-training in the article, I think you should not talk about an encoder-decoder model, or maybe you could add a paragraph about TabNetPretrainer.

Anyway, great article; thanks for sharing it with us and giving credit to the repo.

But... the article does not tell me whether you managed to get the ONNX format working?! :)

Cheers!

@luigisaetta (Author)

@Optimox regarding ONNX, no, I didn't make any progress. The point where I got blocked is that TabNet doesn't seem to accept tensors as input; I think the code calls some methods that exist only for NumPy arrays. Do you have any suggestions?

@Optimox (Collaborator) commented Apr 16, 2021

Hmm, actually I think I know.

If you have a look at this file, https://github.com/dreamquark-ai/tabnet/blob/develop/pytorch_tabnet/tab_network.py, where everything network-related happens, you'll see we are actually using numpy for some stupid reason (laziness and bad habits mainly).

I think it would be very easy to replace all the np.* calls in this code with their torch equivalents.
This might have two positive effects:

  • allow the ONNX format
  • probably speed up the code on GPU

I don't have much time at the moment but I'll definitely change that. If you want to make those changes and see if it works for ONNX, don't hesitate; you can also open a PR and I'll review it carefully. The sketch below shows the kind of change I mean.

Otherwise I'll do this as soon as I can or maybe @eduardocarvp will have a look before me?

I think we might have found your problem :)
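
A hypothetical illustration of that kind of change (the function and the sqrt call are invented for the example, not taken from tab_network.py):

import torch

def scale_features(x: torch.Tensor) -> torch.Tensor:
    # Before (mixes NumPy into the traced graph and bakes in constants):
    #     return x / np.sqrt(x.shape[1])
    # After (pure torch, stays cleanly traceable and exportable):
    return x / torch.sqrt(torch.tensor(x.shape[1], dtype=x.dtype))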

@rxbh2019 commented Jul 25, 2021

Hi, has anyone made any progress on this yet? It would be really appreciated if you could share a bit about how the export would work. Right now I am stuck at exporting, and the error tells me this:
"ONNX export failed: Couldn't export Python operator Entmax15Function"

Could anyone help?

@mythicaa

> I'm not familiar with ONNX, but it would be quite easy to save the network as a traced script […] I guess it would look something like this: (see the save_torch_script example above)

Hi! I tried this but I get this error:
Could not export Python function call 'SparsemaxFunction'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to constants:
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py(109): forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1118): _slow_forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1130): _call_impl
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py(640): forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1118): _slow_forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1130): _call_impl
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py(160): forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1118): _slow_forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1130): _call_impl
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py(471): forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1118): _slow_forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1130): _call_impl
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py(586): forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1118): _slow_forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1130): _call_impl
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/jit/_trace.py(967): trace_module
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/jit/_trace.py(750): trace
(23):
/databricks/python/lib/python3.8/site-packages/IPython/core/interactiveshell.py(3437): run_code
/databricks/python/lib/python3.8/site-packages/IPython/core/interactiveshell.py(3357): run_ast_nodes
/databricks/python/lib/python3.8/site-packages/IPython/core/interactiveshell.py(3165): run_cell_async
/databricks/python/lib/python3.8/site-packages/IPython/core/async_helpers.py(68): _pseudo_sync_runner
/databricks/python/lib/python3.8/site-packages/IPython/core/interactiveshell.py(2940): _run_cell
/databricks/python/lib/python3.8/site-packages/IPython/core/interactiveshell.py(2894): run_cell
/databricks/python_shell/scripts/PythonShellImpl.py(757): run_cell
/databricks/python_shell/scripts/PythonShellImpl.py(269): run
/databricks/python_shell/scripts/PythonShellImpl.py(1234): launch_process
/databricks/python_shell/scripts/PythonShell.py(29):

Is there any way to fix this?

@Optimox (Collaborator) commented Sep 30, 2022

Yes, you probably need to 'scriptify' sparsemax (and the entmax functions) so that they can be accepted for tracing.

I don't know how hard it would be; you can try adding @script on top of the definitions of sparsemax and entmax and see if it works. One possible direction is sketched below.
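
One possible direction, sketched here with heavy caveats: custom torch.autograd.Function classes can declare a symbolic static method that torch.onnx.export uses instead of the Python operator. Sparsemax has no standard ONNX op, so the mapping below substitutes Softmax purely as a placeholder; it changes the model's semantics and is only meant to show the mechanism:

import torch

class SparsemaxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dim=-1):
        # ... keep the original sparsemax forward from pytorch_tabnet/sparsemax.py
        ...

    @staticmethod
    def symbolic(g, input, dim=-1):
        # Placeholder mapping to an existing ONNX op; a faithful export would
        # need a custom ONNX operator that actually implements sparsemax.
        return g.op("Softmax", input, axis_i=dim)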

@mythicaa commented Sep 30, 2022

Thanks for the reply. I'm trying to speed up inference, and TorchScript is one way I was trying. Is there any other, more straightforward method you would suggest for speeding it up before I try TorchScript?

@duanckham

Oh guys, is it still not possible to export to ONNX right now?
