jeyabbalas/tabnet


TensorFlow 2 Keras implementation of TabNet

A TensorFlow 2 Keras implementation of TabNet from the paper: TabNet: Attentive Interpretable Tabular Learning. The authors propose TabNet, a neural network architecture that learns a canonical representation of tabular data. It is shown to perform competitively with, or better than, state-of-the-art tabular learning methods such as XGBoost, CatBoost, and LightGBM. TabNet is also interpretable: it can generate both global and individualized (per-instance) feature importances.
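The interpretability comes from TabNet's attentive transformer, which uses the sparsemax activation to produce feature-selection masks that are exactly zero for unused features. As a rough illustration of why the masks are sparse (a minimal NumPy sketch of sparsemax, not the repository's code):

```python
import numpy as np

def sparsemax(z):
    """Sparsemax (Martins & Astudillo, 2016): a softmax-like projection
    onto the probability simplex that can assign exactly zero weight to
    low-scoring features, which is what makes TabNet's masks sparse."""
    z = np.asarray(z, dtype=float)
    z_sorted = np.sort(z)[::-1]            # scores in descending order
    k = np.arange(1, z.size + 1)
    cumsum = np.cumsum(z_sorted)
    support = 1 + k * z_sorted > cumsum    # indices kept in the support
    k_max = k[support][-1]
    tau = (cumsum[support][-1] - 1.0) / k_max  # threshold
    return np.maximum(z - tau, 0.0)

mask = sparsemax([2.0, 1.0, -1.0])
# The mask sums to 1 like softmax, but the weakest feature gets exactly 0.
print(mask)  # → [1. 0. 0.]
```

Unlike softmax, which assigns every feature a strictly positive weight, sparsemax lets each decision step attend to only a few features, so the aggregated masks can be read directly as feature importances.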

Citation: Arık, S. Ö., & Pfister, T. (2020). TabNet: Attentive interpretable tabular learning. arXiv.

This implementation closely follows the PyTorch implementation of TabNet linked here, which is explained in this helpful video by Sebastian Fischman. In my opinion, it is the most reliable and flexible implementation of TabNet I could find; I was unable to find a comparably reliable and flexible implementation in TensorFlow.

I re-implemented TabNet in TensorFlow 2 Keras here mainly to enable reuse of, and experimentation with, this architecture from within the TensorFlow ecosystem, and to take advantage of the Keras API.

Usage

Note

The current TensorFlow implementation of Ghost Batch Normalization requires the virtual batch size to be a factor of the overall batch size even at inference time. This behavior is incorrect: inference batches should not have to satisfy that constraint. As a result, I don't recommend using Ghost Batch Normalization; disable it by setting the TabNet parameter virtual_batch_size = None. Track this issue here.
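To see where the divisibility requirement comes from, here is a minimal NumPy sketch of what Ghost Batch Normalization does (illustrative only, not the repository's code): the batch is split into "virtual" batches that are each normalized with their own statistics, so the split is only possible when the virtual batch size divides the overall batch size.

```python
import numpy as np

def ghost_batch_norm(x, virtual_batch_size=None, eps=1e-5):
    """Illustrative Ghost Batch Normalization over axis 0.

    The batch is reshaped into groups of `virtual_batch_size` rows and
    each group is normalized with its own mean/variance. This reshape is
    why the virtual batch size must evenly divide the batch size.
    """
    n = x.shape[0]
    if virtual_batch_size is None:
        virtual_batch_size = n  # falls back to plain batch normalization
    if n % virtual_batch_size != 0:
        raise ValueError(
            f"batch size {n} is not divisible by "
            f"virtual batch size {virtual_batch_size}")
    chunks = x.reshape(-1, virtual_batch_size, *x.shape[1:])
    mean = chunks.mean(axis=1, keepdims=True)
    var = chunks.var(axis=1, keepdims=True)
    return ((chunks - mean) / np.sqrt(var + eps)).reshape(x.shape)

x = np.arange(8.0).reshape(8, 1)
out = ghost_batch_norm(x, virtual_batch_size=4)  # 8 rows → 2 virtual batches
```

Passing virtual_batch_size=None normalizes over the whole batch, which is why setting the corresponding TabNet parameter to None sidesteps the divisibility problem.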
