
Caching of `x`, `y = f(x)`, and `log|det(J)|`

@stefanwebb released this 03 Feb 02:07

In this release, we add caching of intermediate values for Bijectors.

In practice, this means you can often reduce computation by calculating log|det(J)| at the same time as y = f(x). It is also useful for performing variational inference with Bijectors that don't have an explicit inverse: since the input is cached alongside the output, inverting y = f(x) reduces to a lookup. The mechanism by which this is achieved is a subclass of torch.Tensor called BijectiveTensor that bundles together (x, y, context, bijector, log_det_J).
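To make the mechanism concrete, here is a minimal sketch of the BijectiveTensor idea in plain PyTorch. The `from_forward` helper, the private attribute names, and the `forward_with_cache` wrapper are hypothetical illustrations of the caching pattern, not the library's actual implementation:

```python
import torch
from torch.distributions.transforms import ExpTransform


class BijectiveTensor(torch.Tensor):
    """Sketch of a tensor that remembers where it came from."""

    @staticmethod
    def from_forward(x, y, bijector, context, log_det_J):
        # Wrap y so downstream code sees an ordinary tensor, while the
        # provenance (x, bijector, context, log|det J|) rides along with it.
        t = y.as_subclass(BijectiveTensor)
        t._x, t._bijector, t._context, t._log_det_J = x, bijector, context, log_det_J
        return t

    def parent(self):
        # The cached input: inverting f(x) becomes a lookup, so no
        # explicit inverse of the bijector is required.
        return self._x


def forward_with_cache(bijector, x, context=None):
    # Hypothetical wrapper: compute y = f(x) and log|det J| in one pass,
    # then bundle everything into a BijectiveTensor.
    y = bijector(x)
    log_det_J = bijector.log_abs_det_jacobian(x, y)
    return BijectiveTensor.from_forward(x, y, bijector, context, log_det_J)


x = torch.randn(3)
y = forward_with_cache(ExpTransform(), x)
assert isinstance(y, torch.Tensor)   # behaves like a normal tensor
print(y._log_det_J)                  # cached, no recomputation needed
print(torch.equal(y.parent(), x))    # input recovered from the cache
```

Because a BijectiveTensor is still a torch.Tensor, it can flow through the rest of a model unchanged; code that knows about the cache can use it, and everything else simply sees the plain value y.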

A special shout-out to @vmoens for coming up with this neat solution and taking the implementation lead! Looking forward to your future contributions 🥳