Skip to content

Second-order differentiable PyTorch GRUs in JIT with TorchScript

License

Notifications You must be signed in to change notification settings

Maghoumi/JitGRU

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

7 Commits
 
 
 
 
 
 
 
 

Repository files navigation

JitGRU: GRU with PyTorch's TorchScript

A simple implementation of GRUs using PyTorch's JIT (TorchScript). The API follows that of torch.nn.GRU. Should run reasonably fast.

But... why?

At the time of writing, PyTorch does not support second order derivatives for GRUs with CUDA (see this issue). As a result, any loss function that depends on computing the second derivatives of GRUs doesn't work on out of the box. I needed double backward() calls for a project, so here it is!

How to use

The main implementation is available in jit_gru.py. I've implemented equivalents of torch.nn.GRUCell and torch.nn.GRU in that file. Look at the test cases that I've included in the implementation. Those should help you get started.

Bi-Directional GRUs

Support for bi-directional GRUs with variable input lengths was recently added (credits go to @elixir-code). This implementation is available separately in jit_bigru.py. See the included test cases in that file for example usage.

Demo Project

Checkout DeepNAG, which contains a GAN-based sequence generation model, as well as a non-adversarial sequence generator. The GAN-based sequence generator in the aforementioned repository is trained with the improved Wasserstein GAN loss function, and relies on the code from this repository.

Support/Citing

If you find our work useful, please consider starring this repository and citing our work:

@phdthesis{maghoumi2020dissertation,
  title={{Deep Recurrent Networks for Gesture Recognition and Synthesis}},
  author={Mehran Maghoumi},
  year={2020},
  school={University of Central Florida Orlando, Florida}
}

@misc{maghoumi2020deepnag,
      title={{DeepNAG: Deep Non-Adversarial Gesture Generation}}, 
      author={Mehran Maghoumi and Eugene M. Taranta II and Joseph J. LaViola Jr},
      year={2020},
      eprint={2011.09149},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Contribution

I'm actively using this implementation, so contributions are greatly welcome as they help my work too. If you think you can improve this project, or implement something more efficiently, then feel free to submit pull requests!

License

This project is licensed under the MIT License.

About

Second-order differentiable PyTorch GRUs in JIT with TorchScript

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages