Hi,
First, thanks for developing this nice package.
For context, I intend to use diffrax to implement custom Langevin-like dynamics, but my issue can be reduced to the following. Let's say I want to implement a simple $n$-dimensional Brownian motion:
$$dX = dB$$
I can try doing this with a plain `ControlTerm`, but it wouldn't work because of this line defining the vf-contr product for `ControlTerm`s. What `tensordot(vf, contr, axes=ndim(contr))` does is fully contract the tensors (over all the dimensions of `contr`), so in my case it would return a scalar (e.g. $\sum_i dB_i$ for an all-ones vector field), whereas I generally need the element-wise (Hadamard) product `vf * contr`.
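To make the problem concrete, here is a small sketch of that contraction using plain `jax.numpy`. The random vector `dB` only stands in for a Brownian increment (it is not produced by diffrax), and `jnp.tensordot(vf, dB, axes=jnp.ndim(dB))` mirrors the expression quoted above.

```python
import jax
import jax.numpy as jnp

n = 4
key = jax.random.PRNGKey(0)

vf = jnp.ones(n)                   # diffusion for dX = dB: all ones
dB = jax.random.normal(key, (n,))  # stand-in for a Brownian increment

# ControlTerm-style product: contract over all ndim(dB) axes of the control.
full = jnp.tensordot(vf, dB, axes=jnp.ndim(dB))
print(full.shape)      # () -- a scalar, equal to dB.sum()

# What dX = dB actually needs: an element-wise (Hadamard) product.
hadamard = vf * dB
print(hadamard.shape)  # (4,)
```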
For now, the only way I found to implement an element-wise `ControlTerm` product is to increase the dimensionality of the vector field, e.g. in this case writing `diffusion = lambda t, y, args: jnp.eye(n)`, which is far more expensive ($O(n^2)$ instead of $O(n)$) and will not scale to my applications. I am also not sure whether `jax.experimental.sparse` matrices would help.
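A short sketch of why the `jnp.eye(n)` workaround gives the right answer but at quadratic cost. It reuses the same `tensordot` contraction, with a random vector standing in for the Brownian increment.

```python
import jax
import jax.numpy as jnp

n = 4
dB = jax.random.normal(jax.random.PRNGKey(0), (n,))  # stand-in Brownian increment

# Workaround: a matrix-valued diffusion, so the full contraction over the
# control's single axis becomes a matrix-vector product eye(n) @ dB == dB.
diffusion = lambda t, y, args: jnp.eye(n)
vf = diffusion(0.0, None, None)

increment = jnp.tensordot(vf, dB, axes=jnp.ndim(dB))
print(increment.shape, vf.size)  # (4,) 16 -- n**2 entries to produce n values
```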
I understand that matrix products, or even higher-rank tensor products, may be required in some applications. This recent question, and this diffrax example of a Neural SDE, both have a matrix-valued diffusion vector field and a vector-valued Brownian control. However, if I'm not mistaken, for the same reason (full tensor contraction), a matrix product between a matrix-valued vf and a matrix-valued control is not currently easy to implement.
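For the matrix-valued case, a minimal sketch (square matrices assumed, so that the contraction type-checks at all): contracting over both axes of the control yields the scalar Frobenius inner product, not the matrix product `vf @ dW`. The random `dW` is only a stand-in for a matrix-valued control increment.

```python
import jax
import jax.numpy as jnp

n = 3
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
vf = jax.random.normal(key_a, (n, n))  # matrix-valued diffusion
dW = jax.random.normal(key_b, (n, n))  # matrix-valued control increment

# ControlTerm-style product: sums over both axes -> scalar sum(vf * dW).
full = jnp.tensordot(vf, dW, axes=jnp.ndim(dW))

# The matrix product one would usually want -> shape (n, n).
matmul = vf @ dW
```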
So my question is: did I miss a way to implement an element-wise `ControlTerm` product?
I think it should be possible to implement it:
- without much additional computational cost;
- while preserving PyTree structures (without having to flatten everything manually), as diffrax already handles them well;
- and, optionally, while allowing other tensor contractions such as the matrix-matrix product.
I could not think of any `einsum` to replace `tensordot(a, b, ndim(b))` that would fit well in all cases, but maybe a way to specify which `_prod` function `ControlTerm` uses could be an idea? Or maybe I just missed a simple way to do all of the above.
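One way to picture the "choose your own `_prod`" idea is a product built from an `einsum` spec and applied leaf by leaf with `jax.tree_util.tree_map`, which keeps the PyTree structure intact. `make_prod` is a hypothetical helper, not part of diffrax. Each spec covers only one case, which is consistent with the observation that no single `einsum` replaces `tensordot(a, b, ndim(b))` everywhere.

```python
import jax
import jax.numpy as jnp


def make_prod(spec):
    """Hypothetical factory: a vf-control product defined by an einsum spec.

    Applied leaf by leaf, so the PyTree structure shared by `vf` and
    `control` is preserved without manual flattening.
    """
    def prod(vf, control):
        return jax.tree_util.tree_map(
            lambda v, c: jnp.einsum(spec, v, c), vf, control
        )
    return prod


hadamard_prod = make_prod("i,i->i")   # dX = dB: element-wise
matvec_prod = make_prod("ij,j->i")    # matrix diffusion, vector control
matmat_prod = make_prod("ij,jk->ik")  # matrix diffusion, matrix control

vf = {"x": jnp.ones(3), "v": 2.0 * jnp.ones(3)}
dB = {"x": jnp.arange(3.0), "v": jnp.arange(3.0)}
out = hadamard_prod(vf, dB)  # {"v": [0., 2., 4.], "x": [0., 1., 2.]}
```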
Thanks in advance!
You want `diffrax.WeaklyDiagonalControlTerm` instead of just `diffrax.ControlTerm`. :)
> maybe having a way to specify which product `_prod` function to use in `ControlTerm` could be an idea?
For this more general case, you can subclass `diffrax.AbstractTerm` and then implement the appropriate product you have in mind. (Just like the built-in terms!)
Also, concerning the Brownian control: it seems that changes to `VirtualBrownianTree` mean diffrax no longer supports reverse-time SDEs in the new 0.5.0 version. One gets a `t0 must be strictly less than t1` error (and swapping `t0` and `t1` in the call does not help), whereas reverse-time ODEs still work fine.
Ah, interesting point about the Brownian motion. FWIW, since it's just a control, I think it shouldn't matter too much: just switch them before passing them to the control. That said, I'd be happy to add a PR that makes this "just work" (ideally negating the generated samples if t0 > t1).
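The suggested behaviour (switch the endpoints and negate when t0 > t1) can be sketched independently of diffrax. `ordered_increment` is a hypothetical helper, not a diffrax function. With a real `VirtualBrownianTree`, `evaluate` would be the tree's `evaluate` method, built over the increasing interval. It uses no Python branching on the times, so it also traces under `jax.jit`.

```python
import jax.numpy as jnp


def ordered_increment(evaluate, s, u):
    """Hypothetical helper: increment of a path from time s to time u.

    `evaluate(lo, hi)` is assumed to accept only lo < hi (as the error
    above suggests). Reverse-time increments come from swapping the
    endpoints and negating, since B(u) - B(s) = -(B(s) - B(u)).
    Note: s == u would still be rejected by such an `evaluate`.
    """
    lo = jnp.minimum(s, u)
    hi = jnp.maximum(s, u)
    sign = jnp.where(s <= u, 1.0, -1.0)
    return sign * evaluate(lo, hi)


# Toy deterministic "path" f(t) = t**2, so evaluate(a, b) = f(b) - f(a).
toy = lambda a, b: b**2 - a**2
forward = ordered_increment(toy, 0.5, 1.0)  # 0.75
reverse = ordered_increment(toy, 1.0, 0.5)  # -0.75
```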