Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add KL divergence terms for Latent SDEs #402

Open
wants to merge 15 commits into
base: dev
Choose a base branch
from

Conversation

lockwo
Copy link
Contributor

@lockwo lockwo commented Apr 17, 2024

Addresses #401. Revives #104. Based on that PR, I made the minimal requirements to get it up to current version (e.g. taking callables instead of ODE terms since we can't make these .vf becuase _broadcast_and_upcast requires that aug_y and drift(aug_y) are the same shape, but they aren't).

@lockwo
Copy link
Contributor Author

lockwo commented Apr 17, 2024

Before going further (there is a lot I am going to improve/polish) I wanted to check with your thoughts on the general approach of KL being terms and exposing the user to a function that converts their problem. An alternative could be something like in torchsde where it's part of the intregration method, i.e. the user flags it at integration time.

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On make these terms: I don't have super strong feelings, but compared to the original PR we have now more clearly defined what a term is in Diffrax, and I think there are other points on the design space.

To be precise: given a diffeq of the form

dy = f(y, z) da + g(y, z) db
dz = h(y, z) da + k(y, z) db

then this would be represented in Diffrax as

terms = (
    MultiTerm(f, g),
    MultiTerm(h, k),
)

In general: everything inside a MultiTerm(...) is all applied to the same dfoo. For example the SDE-specific solvers consume a MultiTerm[ODETerm, AbstractTerm], for the drift and diffusion.
Meanwhile the PyTree structure of terms themselves corresponds to different dfoo and dbar. For example semi-implicit Euler takes a pair of (AbstractTerm, AbstractTerm), corresponding to the two components that are being evolved.

In this case there is an argument that the extra KL-divergence term should really correspond to a new dfoo, and that as such the correct thing to do is to instead replace terms with (terms, kl_term), and then provide a wrapper solver which understands this alternative term structure.

diffrax/_kl_term.py Outdated Show resolved Hide resolved
@patrick-kidger
Copy link
Owner

On the topic of Lineax: indeed, this should definitely make handling PyTrees much easier.

@lockwo lockwo changed the base branch from main to dev April 23, 2024 19:02
@lockwo
Copy link
Contributor Author

lockwo commented Apr 24, 2024

I think your idea makes a lot of sense, and I made a fair amount of progress on the solver wrapper approach.

@lockwo lockwo marked this pull request as ready for review April 27, 2024 06:26
@lockwo
Copy link
Contributor Author

lockwo commented Apr 27, 2024

Ok, I polished things up. I went with a sort of hybrid approach where the users specifies the SDEs as you described, then just wraps a solver and everything works smoothly. However, I did create internal terms, in order to get an arbitrary solver to integrate through the KL computation, that was the best way I could think of to do so, but they are completely hidden from the user. I also added the example (can be modified to add more text, or remove pmap although I do like having an example with distribution especially since its painfully slow without it) and a test and updated the docs. Taking it off draft now since its a real PR.

@frankschae
Copy link

This is a very cool feature/example! It looks like one needs to specify

levy_area=diffrax.BrownianIncrement

in diffrax.UnsafeBrownianPath

@lockwo
Copy link
Contributor Author

lockwo commented May 8, 2024

Thanks @frankschae , good catch!

@lockwo
Copy link
Contributor Author

lockwo commented May 10, 2024

The test failures are all just the safe map 0.4.27 stuff

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants