This repository provides a minimal example for training a flow matching model in a pretrained VAE's latent space to generate MNIST digits.
Denote the data distribution (the VAE latents of the MNIST digits) as $p_1$ and the standard Gaussian as $p_0 = \mathcal{N}(0, I)$. Given $x_0 \sim p_0$ and $x_1 \sim p_1$, define the linear interpolation

$$x_t = (1 - t)\,x_0 + t\,x_1, \qquad t \in [0, 1].$$

Here, $x_t$ traces a straight path from noise to data, whose velocity is the constant $x_1 - x_0$. We sample $t \sim \mathcal{U}[0, 1]$ and train a network $v_\theta(x_t, t)$ to regress this velocity:

$$\mathcal{L}(\theta) = \mathbb{E}_{t,\, x_0,\, x_1}\left[\lVert v_\theta(x_t, t) - (x_1 - x_0) \rVert^2\right].$$
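The objective above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the repository's actual code; `v_theta` and `flow_matching_loss` are placeholder names:

```python
import torch

def flow_matching_loss(v_theta, x1):
    """Conditional flow matching loss on a batch of latents x1.

    Interpolates linearly between Gaussian noise x0 and data x1, then
    regresses the model onto the path's constant velocity x1 - x0.
    """
    x0 = torch.randn_like(x1)                      # noise sample x0 ~ N(0, I)
    t = torch.rand(x1.size(0), device=x1.device)   # uniform time t in [0, 1]
    t_ = t.view(-1, *([1] * (x1.dim() - 1)))       # broadcast t over latent dims
    xt = (1 - t_) * x0 + t_ * x1                   # point on the straight path
    target = x1 - x0                               # constant target velocity
    return torch.mean((v_theta(xt, t) - target) ** 2)
```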
As a sanity check, we can visualize the original and reconstructed digits to verify that the pretrained VAE encodes and reconstructs the images faithfully.
| Original | Reconstructed |
| --- | --- |
Samples generated by the model at different stages of training:

| Epoch 1 | Epoch 50 | Epoch 100 |
| --- | --- | --- |

| Epoch 200 | Epoch 300 | Final |
| --- | --- | --- |
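Generation amounts to integrating the learned velocity field from noise at $t = 0$ to data at $t = 1$, then decoding the resulting latents with the VAE. A simple Euler sketch, using the same hypothetical `v_theta` interface as above (the repository may use a different integrator or step count):

```python
import torch

@torch.no_grad()
def sample(v_theta, shape, steps=100, device="cpu"):
    """Generate latents by Euler-integrating dx/dt = v_theta(x, t)."""
    x = torch.randn(shape, device=device)  # start from Gaussian noise at t = 0
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.full((shape[0],), i * dt, device=device)
        x = x + dt * v_theta(x, t)         # one Euler step toward t = 1
    return x                               # pass through the VAE decoder to get images
```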
The code is tested with PyTorch 1.13 and CUDA 11.7. The remaining dependencies can be installed with:

```shell
pip install -r requirements.txt
```
Run:

```shell
python main.py
```
The transformer code is adapted from the official DiT repository: https://github.com/facebookresearch/DiT
The flow matching model is adapted from https://github.com/gle-bellier/flow-matching/tree/main