Solving simple dynamics: ControlTerm piecewise product #358

Open
hsimonfroy opened this issue Jan 22, 2024 · 3 comments
Labels
question User queries

Comments

@hsimonfroy

Hi,
Thanks first for developing this nice package.

For context, I intend to use diffrax to implement custom Langevin-like dynamics, but my issue reduces to the following. Say I want to implement a simple $n$-dimensional Brownian motion:
$$dX = dB$$
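For reference, discretizing this with a fixed step $\Delta t$ (Euler-Maruyama, assumed here purely for illustration) and writing the diffusion as a general vector field $g$ (with $g \equiv 1$ for $dX = dB$) gives an update that acts on each component separately, which is why an element-wise product is the natural one here:

$$X^{i}_{k+1} = X^{i}_{k} + g^{i}(t_k, X_k)\,\Delta B^{i}_{k}, \qquad \Delta B^{i}_{k} \sim \mathcal{N}(0, \Delta t), \qquad i = 1, \dots, n.$$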

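I could try:

```python
import jax.numpy as jnp
from diffrax import ControlTerm, VirtualBrownianTree

diffusion = lambda t, y, args: jnp.ones(n)  # n, t0, t1, dt, seed (a PRNG key) defined elsewhere
brownian_motion = VirtualBrownianTree(t0, t1, tol=dt/2, shape=(n,), key=seed)
terms = ControlTerm(diffusion, brownian_motion)
```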

but it would not work, because of this line defining the vf-control product for `ControlTerm`s. `tensordot(vf, contr, axes=ndim(contr))` fully contracts the tensors (over all dimensions of `contr`), so in my case it returns the scalar $\sum_{i=1}^{n} dB_i$, whereas I generally need the element-wise (Hadamard) product `vf * contr`.
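A minimal illustration of the difference, using plain `jax.numpy` rather than diffrax itself. A random normal vector stands in for the Brownian increment $dB$, and the contraction mirrors the `tensordot(vf, contr, axes=ndim(contr))` call described above, not necessarily diffrax's exact source:

```python
import jax
import jax.numpy as jnp

n = 4
dB = jax.random.normal(jax.random.PRNGKey(0), (n,))  # stand-in for a Brownian increment
vf = jnp.ones(n)                                      # diffusion from the issue: jnp.ones(n)

# Full contraction over every axis of dB: a scalar, sum_i vf[i] * dB[i]
contracted = jnp.tensordot(vf, dB, axes=jnp.ndim(dB))

# Element-wise (Hadamard) product: one entry per component, shape (n,)
elementwise = vf * dB
```

With `vf = jnp.ones(n)`, `contracted` collapses all $n$ increments into a single number, while `elementwise` keeps them independent, which is what $dX = dB$ requires.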

For now, the only way I have found to get an element-wise `ControlTerm` product is to increase the dimensionality of the vector field, e.g. in this case `diffusion = lambda t, y, args: jnp.eye(n)`, which is far more expensive ($O(n^2)$ instead of $O(n)$) and will not scale to my applications. I am also not sure that `jax.experimental.sparse` matrices would help.
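A small sketch of why the identity-matrix workaround is correct but wasteful, again with plain `jax.numpy` and a random stand-in for $dB$: contracting `jnp.eye(n)` against $dB$ reproduces $dB$ exactly, but it stores $n^2$ entries to do what $n$ entries and an element-wise product already do.

```python
import jax
import jax.numpy as jnp

n = 4
dB = jax.random.normal(jax.random.PRNGKey(0), (n,))

# Workaround from the issue: identity matrix as the diffusion, n * n entries
dense_vf = jnp.eye(n)
via_eye = jnp.tensordot(dense_vf, dB, axes=jnp.ndim(dB))  # same as dense_vf @ dB

# Element-wise alternative: only n entries
diag_vf = jnp.ones(n)
via_diag = diag_vf * dB
```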

I understand that matrix products, or even higher-rank tensor products, may be required in some applications. This recent question and this diffrax Neural SDE example both pair a matrix-valued diffusion vector field with a vector-valued Brownian control. However, if I am not mistaken, the same full tensor contraction means that a matrix product between a matrix-valued vf and a matrix-valued control is not currently easy to implement.
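A quick check of that claim with plain `jax.numpy` (random square matrices stand in for the vector field and the control increment): with `axes=ndim(contr)`, a matrix-valued control has both of its axes contracted, so the result is a scalar rather than the matrix product.

```python
import jax
import jax.numpy as jnp

n = 3
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
vf = jax.random.normal(key_a, (n, n))     # matrix-valued vector field
contr = jax.random.normal(key_b, (n, n))  # matrix-valued control increment

# axes=ndim(contr)=2 contracts both axes: sum_ij vf[i, j] * contr[i, j], a scalar
full = jnp.tensordot(vf, contr, axes=jnp.ndim(contr))

# The matrix product contracts only one axis and keeps an (n, n) result
matmul = vf @ contr
```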

So my question is: did I miss a way to implement an element-wise `ControlTerm` product?
I think it should be possible to implement one:

  • without much additional computational cost;
  • while preserving PyTree structure (without having to tree-flatten everything manually), which diffrax already handles well;
  • and, optionally, while allowing other tensor contractions, such as a matrix-matrix product.

I could not think of a single `einsum` that could replace `tensordot(a, b, ndim(b))` and fit all cases, but maybe having a way to specify which product `_prod` function to use in `ControlTerm` could be an idea? Or maybe I just missed a simple way to do all of the above.
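As a sketch of what "choosing the product" could look like while keeping PyTree structure, each candidate product can be written as a leaf-by-leaf `jax.tree_util.tree_map`. These three functions are hypothetical illustrations, not diffrax API; `tensordot_prod` mirrors the contraction described above rather than diffrax's actual source.

```python
import operator

import jax
import jax.numpy as jnp


def tensordot_prod(vf, control):
    # Full contraction over every axis of each control leaf
    return jax.tree_util.tree_map(
        lambda v, c: jnp.tensordot(v, c, axes=jnp.ndim(c)), vf, control
    )


def elementwise_prod(vf, control):
    # Hadamard product, leaf by leaf; PyTree structure is preserved
    return jax.tree_util.tree_map(operator.mul, vf, control)


def matmul_prod(vf, control):
    # Matrix-matrix (or matrix-vector) product, leaf by leaf
    return jax.tree_util.tree_map(jnp.matmul, vf, control)


# PyTree-valued vector field and control (e.g. position and velocity noise)
vf = {"x": jnp.ones(3), "v": 2.0 * jnp.ones(3)}
control = {"x": jnp.arange(3.0), "v": jnp.arange(3.0)}
out = elementwise_prod(vf, control)
```

Since every function only needs matching leaves, none of them requires flattening the PyTree by hand; only the per-leaf operation changes.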

Thanks in advance!

@patrick-kidger
Owner

You want diffrax.WeaklyDiagonalControlTerm instead of just diffrax.ControlTerm. :)

> maybe having a way to specify which product `_prod` function to use in `ControlTerm` could be an idea?

For this more general case, you can subclass diffrax.AbstractTerm and then implement the appropriate product you have in mind. (Just like the built-in terms!)

@patrick-kidger patrick-kidger added the question User queries label Jan 25, 2024
@hsimonfroy
Author

hsimonfroy commented Feb 5, 2024

Thanks, works fine!

The SDE part of Getting Started redirects to the Terms page, so I should have seen it ;)

Also, concerning the Brownian control: it seems that changes to `VirtualBrownianTree` mean diffrax no longer supports reverse-time SDEs in the new 0.5.0 version. One gets a `t0 must be strictly less than t1` error (and swapping `t0` and `t1` in the call does not help), whereas reverse-time ODEs still work fine.

@patrick-kidger
Owner

patrick-kidger commented Feb 5, 2024

Ah, interesting point about the Brownian motion. FWIW, since it's just a control, I think it shouldn't matter too much: just swap them before passing them to the control. That said, I'd be happy to accept a PR that makes this "just work" (ideally negating the generated samples if t0 > t1).
