radiradev/flowmatching-bdt

Flow-Matching BDT for tabular data

A minimal implementation of a generative model for tabular data trained with flow matching. No deep learning: XGBoost is used to learn the generative model.
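For reference, the block below gives the standard conditional flow-matching objective with straight-line paths between noise and data (the zero-noise setting used in torchcfm). This README does not state the exact path, noise distribution, or time sampling the package uses, so treat this as a sketch of the general technique rather than its precise formulation. The boosted-tree regressor plays the role of the velocity model $v_\theta$.

```latex
% x_1: a data row,  x_0 \sim \mathcal{N}(0, I): noise,  t \sim \mathcal{U}[0, 1]
\begin{aligned}
x_t &= (1 - t)\,x_0 + t\,x_1,
\qquad u_t = \frac{\mathrm{d}x_t}{\mathrm{d}t} = x_1 - x_0, \\
\mathcal{L}(\theta) &= \mathbb{E}_{t,\,x_0,\,x_1}
  \big\lVert v_\theta(x_t, t) - (x_1 - x_0) \big\rVert^2 .
\end{aligned}
% Sampling: draw x(0) \sim \mathcal{N}(0, I) and integrate
% dx/dt = v_\theta(x, t) from t = 0 (noise) to t = 1 (data).
```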

The original implementation is available in forest-diffusion. Another implementation is available in the torchcfm library.

Unlike the forest-diffusion implementation, this one is simpler because it relies on XGBoost's ability to predict multiple regression outputs, so a single model can predict all output dimensions together.

Installation

pip install flowmatching-bdt

Usage

from sklearn.datasets import make_moons
from flowmatching_bdt import FlowMatchingBDT

data, _ = make_moons(n_samples=1000, noise=0.1, random_state=0)
model = FlowMatchingBDT()

# train the model
model.fit(data)

# get new samples
num_samples = 1000
samples = model.predict(num_samples=num_samples)
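In the standard flow-matching recipe, new samples are produced by integrating the learned velocity field from Gaussian noise at t = 0 to data at t = 1. The sketch below uses a fixed-step Euler solver; euler_sample, n_steps and the velocity callable are hypothetical names, and predict may use a different solver, step count, or noise distribution.

```python
import numpy as np


def euler_sample(velocity, num_samples, dim, n_steps=50, seed=0):
    """Integrate dx/dt = velocity(x, t) from t=0 (noise) to t=1 (data).

    `velocity` maps x of shape (n, dim) and t of shape (n, 1) to an array of
    shape (n, dim), e.g. a wrapper around a fitted multi-output regressor.
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((num_samples, dim))  # start from Gaussian noise
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = np.full((num_samples, 1), k * dt)    # current time, one per sample
        x = x + dt * velocity(x, t)              # explicit Euler step
    return x
```

For a regressor trained on [x_t, t] features, the callable could be lambda x, t: model.predict(np.hstack([x, t])). As a sanity check, the exact straight-line velocity towards a single point m is (m - x) / (1 - t), and with it the last Euler step lands every sample on m.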

For conditional generation, pass the labels as conditions when fitting, and one condition per requested sample when predicting:

import numpy as np
from sklearn.datasets import make_moons
from flowmatching_bdt import FlowMatchingBDT

data, labels = make_moons(n_samples=1000, noise=0.1, random_state=42)
model = FlowMatchingBDT()

# train the model
model.fit(data, conditions=labels)

# get new samples
num_samples = 1000
conditions = np.ones(num_samples)  # label 1 for every sample, i.e. draw from the second moon
samples = model.predict(num_samples=num_samples, conditions=conditions)
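One common way to make a tree-based velocity model conditional is to append the condition to its inputs, so the learned field can differ per label. The sketch below assumes that approach and a scalar condition per row; build_conditional_training_set is a hypothetical helper, and this README does not say how flowmatching_bdt handles conditions internally.

```python
import numpy as np


def build_conditional_training_set(x1, conditions, n_repeats=10, seed=0):
    """Return (features, targets) for conditional flow matching.

    Hypothetical helper: the condition is appended as an extra input column.
    """
    rng = np.random.default_rng(seed)
    x1 = np.repeat(np.asarray(x1, dtype=float), n_repeats, axis=0)
    # Repeat conditions the same way so each copy keeps its own label.
    c = np.repeat(np.asarray(conditions, dtype=float).reshape(-1, 1), n_repeats, axis=0)
    x0 = rng.standard_normal(x1.shape)       # noise endpoint
    t = rng.uniform(size=(x1.shape[0], 1))   # time in [0, 1)
    xt = (1.0 - t) * x0 + t * x1             # straight-line interpolation
    features = np.hstack([xt, t, c])         # inputs: [x_t, t, condition]
    targets = x1 - x0                        # regression target: velocity
    return features, targets
```

At sampling time, the same condition column would be appended to [x, t] at every integration step.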

Resources

To learn more about flow matching for generative modelling, check out these resources:

  1. Introduction to Flow Matching (Tor Fjelde, Emilie Mathieu, Vincent Dutordoir)
  2. Generating Tabular Data with XGBoost (Alexia Jolicoeur-Martineau, author of the ForestFlow paper)

Citations

@inproceedings{jolicoeur2024generating,
  title={Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees},
  author={Jolicoeur-Martineau, Alexia and Fatras, Kilian and Kachman, Tal},
  booktitle={International Conference on Artificial Intelligence and Statistics},
  pages={1288--1296},
  year={2024},
  organization={PMLR}
}

Acknowledgements

This repository is heavily inspired by, and borrows parts from, lucidrains (project structure) and torchcfm.