Variational Wasserstein gradient flow

This is the official Python implementation of the paper Variational Wasserstein gradient flow (paper on arXiv) by Jiaojiao Fan, Qinsheng Zhang, Amirhossein Taghvaei and Yongxin Chen.

The repository contains reproducible PyTorch source code for computing Wasserstein gradient flows in high dimensions, using a variational estimate of the target functional.
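For orientation, a Wasserstein gradient flow is commonly discretized in time with the JKO (Jordan-Kinderlehrer-Otto) scheme. The sketch below states that standard scheme only; the symbols (step size $a$, target functional $\mathcal{F}$, iterates $\rho_k$) are chosen here for illustration and are assumptions that may not match the notation used in the paper or the code.

```latex
\rho_{k+1} = \operatorname*{arg\,min}_{\rho} \; \frac{1}{2a} W_2^2(\rho, \rho_k) + \mathcal{F}(\rho)
```

Here $W_2$ denotes the 2-Wasserstein distance and each step trades off staying close to the previous iterate $\rho_k$ against decreasing $\mathcal{F}$.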

Repository structure

The codebase is tested on CUDA version 11.4 and PyTorch version 1.10.1+cu113.
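As a quick sanity check before running the experiments, the following minimal sketch compares a local PyTorch install against the version quoted above. It is not part of the repository: the helper names (`split_version`, `same_release`, `describe_environment`) are hypothetical, and the only values taken from this README are the version strings. The `+cu113` suffix is the PyTorch build tag (a build against CUDA 11.3), so the sketch compares only the release part.

```python
import importlib.util

# Version quoted in this README; all helper names below are hypothetical.
TESTED_TORCH = "1.10.1+cu113"


def split_version(version):
    """Split a version string like '1.10.1+cu113' into (release, build tag)."""
    release, _, build = version.partition("+")
    return release, build


def same_release(installed, tested=TESTED_TORCH):
    """Return True when the release parts match, ignoring the build tag."""
    return split_version(installed)[0] == split_version(tested)[0]


def describe_environment():
    """Return a short report string; does not raise if PyTorch is missing."""
    if importlib.util.find_spec("torch") is None:
        return "PyTorch is not installed"
    import torch

    status = "matches" if same_release(torch.__version__) else "differs from"
    return (
        f"PyTorch {torch.__version__} {status} tested {TESTED_TORCH}; "
        f"CUDA build: {torch.version.cuda}; "
        f"GPU available: {torch.cuda.is_available()}"
    )


if __name__ == "__main__":
    print(describe_environment())
```

A mismatch does not necessarily mean the code will fail; it only flags that the environment differs from the one the authors report testing.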

To reproduce all experiments except image generation, go to the toy folder and follow the instructions in toy/README.md:

cd toy

To reproduce the image generation experiment, go to the image folder and follow the instructions in image/README.md:

cd image

Citation

@inproceedings{
  fan2022variational,
  title={Variational Wasserstein gradient flow},
  author={Fan, Jiaojiao and Zhang, Qinsheng and Taghvaei, Amirhossein and Chen, Yongxin},
  booktitle={International Conference on Machine Learning},
  year={2022}
}