
Lemon-cmd/energy-transformer-graph


Energy Transformer For Graph Classification

A novel Transformer variant that is both an associative memory model and a continuous dynamical system with a tractable energy function, whose dynamics are guaranteed to converge to a fixed point. See our paper for full details. Other official implementations of our work include ET for Graph Anomaly Detection (PyTorch), ET for Image (PyTorch), and ET for Image (JAX).
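To make the energy-based view concrete, here is a toy sketch in JAX. It is not the Energy Transformer energy from the paper: it uses a single dense-associative-memory (modern Hopfield) term, E(x) = 0.5 * ||x||^2 - (1/beta) * log sum_mu exp(beta * xi_mu . x), with unit-norm stored patterns xi_mu, and runs plain gradient descent on x. The attention energy, layer normalization, and the paper's exact update rule are all omitted. The names `hopfield_energy` and `descend` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def hopfield_energy(x, patterns, beta=1.0):
    """Toy associative-memory energy (illustrative only, not the full ET energy)."""
    # Quadratic term keeps the state bounded; logsumexp term rewards alignment with stored patterns.
    return 0.5 * jnp.dot(x, x) - logsumexp(beta * patterns @ x) / beta


def descend(x0, patterns, beta=1.0, step_size=0.1, num_steps=100):
    """Plain gradient descent on the toy energy; returns final state and energy trace."""
    energy_fn = jax.jit(lambda x: hopfield_energy(x, patterns, beta))
    grad_fn = jax.jit(jax.grad(lambda x: hopfield_energy(x, patterns, beta)))
    x = x0
    energies = [energy_fn(x)]
    for _ in range(num_steps):
        x = x - step_size * grad_fn(x)  # move downhill on the energy landscape
        energies.append(energy_fn(x))
    return x, jnp.stack(energies)


# Demo: 5 unit-norm memories in 8 dimensions and a random initial state.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
patterns = jax.random.normal(k1, (5, 8))
patterns = patterns / jnp.linalg.norm(patterns, axis=1, keepdims=True)
x0 = jax.random.normal(k2, (8,))
x_final, energies = descend(x0, patterns)
print("initial energy:", float(energies[0]), "final energy:", float(energies[-1]))
```

With unit-norm patterns and beta = 1, the Hessian of this toy energy is I - beta * Cov(xi) with eigenvalues in [0, 1], so a step size of 0.1 gives a non-increasing energy trace (up to floating-point rounding) that settles at a fixed point.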


Installation

pip install -r requirements.txt
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Note: read the official JAX installation guide to properly enable GPU support and for further details. See also JAX Versions for installing a specific JAX/CUDA version. Credits to Ben Hoover for the diagrams.


Test the installation by starting Python and running the following code to check whether JAX can see the GPU:

import jax
print(jax.local_devices())
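For scripts that should stop early on a CPU-only install, a boolean check can be more convenient than reading the device list. This is a minimal sketch; the helper name `gpu_available` is hypothetical. It relies on `jax.devices("gpu")` raising a `RuntimeError` when no GPU backend can be initialized.

```python
import jax


def gpu_available() -> bool:
    """Return True if JAX can see at least one GPU device."""
    try:
        return len(jax.devices("gpu")) > 0
    except RuntimeError:
        # Raised when no GPU backend is available, e.g. a CPU-only jaxlib.
        return False


print("default backend:", jax.default_backend())
print("GPU available:", gpu_available())
```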

Setting up data

PyTorch Geometric provides datasets and dataloaders that automatically download the data the first time the code is run. To use a different dataset, change the dataset name passed to TUDataset or GNNBenchmarkDataset, for example:

from torch_geometric.datasets import GNNBenchmarkDataset

model_name = data_name = 'CIFAR10'
train_data = GNNBenchmarkDataset(root='../data/', name=data_name, split='train')
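Switching between the two dataset families can be wrapped in a small helper. This is a sketch under our own assumptions: `dataset_source`, `load_dataset`, and `GNN_BENCHMARK_NAMES` are hypothetical names, not part of this repository. The set lists the dataset names accepted by PyTorch Geometric's `GNNBenchmarkDataset`; any other name is assumed to be a TUDataset. Note that `TUDataset` has no predefined train/val/test splits, so `split` is ignored for it.

```python
# Hypothetical helper; names are illustrative and not part of this repository.
GNN_BENCHMARK_NAMES = frozenset({"PATTERN", "CLUSTER", "MNIST", "CIFAR10", "TSP", "CSL"})


def dataset_source(name: str) -> str:
    """Return the PyTorch Geometric dataset class assumed to serve `name`."""
    return "GNNBenchmarkDataset" if name in GNN_BENCHMARK_NAMES else "TUDataset"


def load_dataset(name: str, root: str = "../data/", split: str = "train"):
    """Return the dataset, downloading it under `root` on first use."""
    # Imported lazily so the dispatch logic above can be used without PyG installed.
    from torch_geometric.datasets import GNNBenchmarkDataset, TUDataset

    if dataset_source(name) == "GNNBenchmarkDataset":
        return GNNBenchmarkDataset(root=root, name=name, split=split)
    return TUDataset(root=root, name=name)  # no predefined splits; `split` is ignored
```

For example, `load_dataset('CIFAR10', split='test')` returns the CIFAR10 test split, while `load_dataset('PROTEINS')` returns the full PROTEINS TUDataset, which you would then split yourself.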

See if it works

The experiments are provided as Jupyter notebooks in the nbs folder. For example, to run the CIFAR10 evaluation notebook from the repository root:

./run_nb_inplace nbs/eval_cifar10.ipynb
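We assume `run_nb_inplace` executes a notebook and writes the outputs back into the same file; its contents are not shown here. If the script is unavailable on your platform, a roughly equivalent sketch calls Jupyter's `nbconvert` CLI. The helper names below are hypothetical; the `--to notebook`, `--execute`, and `--inplace` flags are standard `nbconvert` options, and `--ExecutePreprocessor.timeout=-1` disables the per-cell timeout for long-running cells.

```python
import subprocess


def nbconvert_inplace_command(notebook_path: str) -> list[str]:
    """Build an nbconvert command that executes a notebook and overwrites it with outputs."""
    return [
        "jupyter", "nbconvert",
        "--to", "notebook",
        "--execute",
        "--inplace",
        "--ExecutePreprocessor.timeout=-1",  # no per-cell timeout (training cells can be long)
        notebook_path,
    ]


def run_notebook_inplace(notebook_path: str) -> None:
    """Run the notebook; raises CalledProcessError if any cell fails."""
    subprocess.run(nbconvert_inplace_command(notebook_path), check=True)
```

Usage: `run_notebook_inplace("nbs/eval_cifar10.ipynb")`.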

Training from scratch

Several pretrained models are provided, so remove them or move them to a different folder before training; otherwise they will be reloaded.
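One way to set the provided checkpoints aside is to move the whole folder to a backup location. This is a minimal sketch: `stash_pretrained` and the backup folder name are hypothetical, and recreating an empty `saved_models` folder assumes the training notebook writes its checkpoints back into it.

```python
import shutil
from pathlib import Path
from typing import Optional


def stash_pretrained(models_dir: str = "saved_models",
                     backup_dir: str = "saved_models_pretrained") -> Optional[Path]:
    """Move existing pretrained files out of `models_dir`; return the backup path."""
    src, dst = Path(models_dir), Path(backup_dir)
    if not src.exists() or not any(src.iterdir()):
        return None  # nothing to stash
    if dst.exists():
        raise FileExistsError(f"{dst} already exists; refusing to overwrite it")
    shutil.move(str(src), str(dst))
    src.mkdir()  # assumption: training saves new checkpoints into models_dir
    return dst
```

To restore the provided models later, move the backup folder back in place of `saved_models`.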

./run_nb_inplace nbs/cifar10.ipynb

Pretrained Models

Some pretrained models are included in the saved_models folder; the rest can be downloaded from the Google Drive link.

Citation

If you find the code or the work useful, please cite our work:

@article{hoover2023energy,
  title={Energy Transformer},
  author={Hoover, Benjamin and Liang, Yuchen and Pham, Bao and Panda, Rameswar and Strobelt, Hendrik and Chau, Duen Horng and Zaki, Mohammed J and Krotov, Dmitry},
  journal={arXiv preprint arXiv:2302.07253},
  year={2023}
}
