New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Events #387
base: dev
Are you sure you want to change the base?
Events #387
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks very clean! I've gone through and left some comments.
diffrax/_integrate.py
Outdated
event_result: Optional[PyTree[Union[BoolScalarLike, RealScalarLike]]] = None | ||
event_mask: Optional[PyTree[BoolScalarLike]] = None | ||
dense_info_for_event: Optional[DenseInfo] = None | ||
tprevprev: Optional[FloatScalarLike] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think default arguments are needed here.
(This is also just a general good-practice-rule-of-thumb: try to avoid default arguments where possible, as they're a common source of unexpected behaviour. So in practice this usually means putting them on public interfaces.)
diffrax/_integrate.py
Outdated
@@ -227,14 +238,45 @@ def _maybe_static(static_x: Optional[ArrayLike], x: ArrayLike) -> ArrayLike: | |||
return x | |||
|
|||
|
|||
def _is_cond_fn(x: Any) -> bool: | |||
return isinstance(x, Callable) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can replace this whole function with just the built-in callable
.
diffrax/_integrate.py
Outdated
new_event_result: Union[BoolScalarLike, RealScalarLike], | ||
) -> BoolScalarLike: | ||
return jnp.sign(jnp.array(old_event_result, float)) != jnp.sign( | ||
jnp.array(new_event_result, float) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't use the Python builtin float
(or int
etc.) as a dtype. This will actually give different behaviour on different platforms (MacOS vs Windows etc.) for some strange reason.
I think in this case you're probably looking for dtype=jnp.result_type(old_event_result.dtype, jnp.float32)
?
diffrax/_integrate.py
Outdated
@@ -309,6 +351,14 @@ def body_fun_aux(state): | |||
# everything breaks.) See #143. | |||
y_error = jtu.tree_map(lambda x: jnp.where(jnp.isnan(x), jnp.inf, x), y_error) | |||
|
|||
# Save info for event handling | |||
if event is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: remove the not
and flip teh branches.
diffrax/_integrate.py
Outdated
_, _, dense_info_for_event, _, _ = solver.step( | ||
terms, | ||
tprev, | ||
tnext, | ||
y0, | ||
args, | ||
solver_state, | ||
made_jump, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that this will actually evaluate the function, I think essentially pointlessly (as it will be re-evaluated in the loop).
I think using jax.eval_shape
or eqx.filter_eval_shape
will allow you to get the structure/dtype/shape of dense_info_for_event
without having to actually do any work at runtime. (And without having to compile solver.step
an additional time during compilation time either, which is a nontrivial concern.)
diffrax/_integrate.py
Outdated
if event is not None: | ||
|
||
def _call_event(_cond_fn): | ||
return _cond_fn( | ||
init_state, | ||
y=y0, | ||
solver=solver, | ||
stepsize_controller=stepsize_controller, | ||
saveat=saveat, | ||
t0=t0, | ||
t1=t1, | ||
dt0=dt0, | ||
max_steps=max_steps, | ||
terms=terms, | ||
args=args, | ||
) | ||
|
||
event_result = jtu.tree_map(_call_event, event.cond_fn, is_leaf=_is_cond_fn) | ||
event_mask = jtu.tree_map(lambda x: False, event.cond_fn, is_leaf=_is_cond_fn) | ||
init_state = eqx.tree_at( | ||
lambda s: s.event_result, init_state, event_result, is_leaf=_is_none | ||
) | ||
init_state = eqx.tree_at( | ||
lambda s: s.event_mask, init_state, event_mask, is_leaf=_is_none | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(a) I think if you put this before init_state = ...
then you can just pass in this information during initialisation, without tree_at
.
(b) I think again you can use eval_shape
to initialise arrays of the appropriate sort without having to actually evaluate each cond_fn
.
diffrax/_integrate.py
Outdated
if event is not None: | ||
event_mask = final_state.event_mask | ||
event_happened = _event_happened(event_mask) | ||
tevent = final_state.tprev |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: prefer single-assignment form:
if event.root_finder is None:
tevent = final_state.tprev
else:
# all the rest of the branch below
tevent = ...
one of these software best-practices for writing easy-to-read code.
Thank you for the comments, Patrick. I have gone through and made changes accordingly. The only thing I would disagree with is your comment on not having to evaluate |
Just letting you know that I've not forgotten about this! I'm trying to focus on getting #344 in first, and then I'm hoping to return to this. They're both quite large changes so I don't want them to step on each other's toes. |
76a9441
to
34cbe5c
Compare
Okay, #344 is in! I'd love to get this in next. I appeciate that's rather a lot of merge conflicts. If you're able to rebase on to the latest |
The main changes are: 1. Added the generic Event class: ``` class Event: event_function: PyTree[EventFn] root_finder: Optional[optx.AbstractRootFinder] = None ``` EventFn is defined as: ``` class EventFn(eqx.Module): cond_fn: Callable[..., Union[BoolScalarLike, RealScalarLike]] transition_fn: Optional[Callable[[PyTree[ArrayLike]], PyTree[ArrayLike]]] = ( lambda x: x ) ```` 2. Added root finding procedure in diffeqsolve to find exact event times that are differentiable. This is only done when root_finder is not None in the given Event class. 3. Added event_mask to the Solution class so that, when multiple event functions are passed, the user can see which one was triggered for a given solve. Hopefully this new event-handling is sufficiently generic to handle all kinds of events in a unified and simple manner. The main benefit is that we can now differentiate through events. So far the current implementation is enough to deal with ODEs, but I suspect more is needed for dealing with SDEs. The new approach reduces to the old approach when passing only one EventFn with a boolean cond_fn and no transition_fn. For now the transition_fn is not used, but it will be useful when adding non-terminating events. Similarly, we might add other attributes to EventFn to distinguish between different types of events. No event cases in root-finding At the end of the root-fining step (L1146 in _integrate.py), I changed: ``` return jtu.tree_map( _call_real, event.event_fn, final_state.event_result, final_state.event_compare, is_leaf=_is_event_fn, ) ``` to ``` results = jtu.tree_map( _call_real, event.event_fn, final_state.event_result, final_state.event_compare, is_leaf=_is_event_fn, ) results_ravel, _ = jfu.ravel_pytree(results) return jnp.where(event_happened, results_ravel, final_state.tprev - t) ``` Thus, if no event occurs the root-find will return tprev as desired. Before call_real() was constantly 0 in this case which caused in error in the root-find. Added EventFn and Event to diffrax/__init__.py Added tests for new event handling I added new tests for the updated event implementation which, apart from the old ones, also checks that the right event time is found in the case where a root-find is called and that the derivatives match the theoretical derivatives. Furthermore, I marked the following tests, that rely on the old event implementation, with @pytest.mark.skip: - test_event.py::test_discrete_terminate1 - test_event.py::test_discrete_terminate2 - test_event.py::test_event_backsolve - test_adjoint.py::test_implicit In order to avoid pyright errors I had to add # pyright: ignore in a few places in the the old test referenced above. Deleted old event implementation I deleted the following two classes: - diffrax._event.DiscreteTerminatingEvent - diffrax._event.SteadyStateEvent These were also removed from the diffrax.__init__.py Minor changes to event hadnling The changes are the following: - Tweaked the event API and got rid of the EventFn class. Now there is only an Event class: ``` class Event(eqx.Module): cond_fn: PyTree[Callable[..., Union[BoolScalarLike, RealScalarLike]]] root_finder: Optional[optx.AbstractRootFinder] = None ``` - Changed the way boolean condition functions are handled in the root finding step. Now instead of calling _bool_event_gradient, we simply return result = final_state.tprev - t. - Removed all cases where jtu.ravel_pytree was used. - Changed "teventprev" to "tprevprev" and "event_compare" to "event_mask" in the State class. - Updated tests.py and __init__.py to reflect the changes. Minor changes for simplicity I slightly changed the initialization of the event attributes in the state in _integrate.py mainly for aesthetic reasons. Made changes according to comments on patrick-kidger#387 No event case Changed it so that the final value of the solve is returned in cases where no event happens instead of evaluating the interpolator.
Perfect, I rebased and squashed all the commits into a single big one. Quite a few tests are failing when I run it locally now, but I just wanted to update it so you could have a look. Two thoughts since we last touched base:
|
Hi! Maybe i am using it wrong but at the moment I can't get the root_finder to do smth. I get the same event times when i use the Newton Method as a root finder or just 45 def test_continuous_event_time():
46 term = diffrax.ODETerm(lambda t, y, args: 1.0)
47 solver = diffrax.Tsit5()
48 t0 = 0
49 t1 = jnp.inf
50 dt0 = 1.0
51 y0 = -10.0
52
53 def cond_fn(state, y, **kwargs):
54 assert isinstance(state.y, jax.Array)
55 return y
56
57 #root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
58 root_finder = None
59 event = diffrax.Event(cond_fn, root_finder)
60 sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event)
61 assert jnp.all(jnp.isclose(cast(Array, sol.ts), 10.0, 1e-5)) Or in my code the resulting event times do not change if having a root finder or not: 7 def bouncing_ball():
8 g = 9.81
9 damping = 0.8
10 max_bounces = 10
11 vx_0 = 5.0
12
13 def dynamics(t, y, args):
14 x, y, vx, vy = y
15 dxdt = vx
16 dydt = vy
17 dvxdt = 0
18 dvydt = -g
19 return jnp.array([dxdt, dydt, dvxdt, dvydt])
20
21 def cond_fn(state, y, **kwargs):
22 return y[1] < 0
23
24 y0 = jnp.array([0.0, 10.0, vx_0, 0.0])
25 t0, t1 = 0, float('inf')
26
27 times = []
28 states = []
29
30 for _ in range(max_bounces):
31 root_finder = optx.Newton(1e-5, 1e-5)
32 #root_finder = None
33 event = Event(cond_fn, root_finder=root_finder)
34 solver = Tsit5()
35
36 sol = diffeqsolve(ODETerm(dynamics), solver, t0, t1, 0.01, y0, event=event)
37
38 t0 = sol.ts[-1]
39 last_y = sol.ys[-1]
40 y0 = last_y * jnp.array([1, 0, 1, -damping])
41 times.append(sol.ts)
42 states.append(y0)
43
44
45 return jnp.array(times), jnp.array(states) Thanks for the help! (Also sorry If this is the wrong place to ask for this, just let me know where to write this) :) |
Ah, thanks for mentioning this. This is essentially due to the fact that the solution to the ODE is linear and the fact that
In your example the cond_fn returns a boolean. In this case the returned event time is exactly the first step of the solver for which If you want continuous event times, you should specify a real-valued condition function. In your bouncing ball example, this would simply correspond to setting: def cond_fn(state, y, **kwargs):
return y[1] |
Ah i see! Was just a bit confusing that both boolean and comparison with 0 is possible. I tried a real valued cond_fn initially but then it was every time directly triggered at t=0 bc i reset the state to 0 at every jump. I guess setting it to a small value should work. Thank you! :) |
Regarding my confusion: Maybe this is a bad idea but maybe it would make sense to have a warning if both a boolean |
If I pass multiple |
Hmm, I'm not too sure about this. There might be cases where you have one real-valued event function and one boolean. E.g., in the bouncing ball example we might want to add an extra function to
The solution returned by |
Agreed on the first point. Happy to add a warning if all events are Boolean, though -- no strong feelings. |
Hi folks! Very excited about this PR, as I'm thinking about quantum jump applications for dynamiqs. I'm unfortunately running into an error If I try to pass an option different from import diffrax as dx
import optimistix as optx
import jax
import jax.numpy as jnp
term = dx.ODETerm(lambda t, y, args: y)
solver = dx.Tsit5()
t0 = 0
t1 = 100.0
dt0 = 1.0
y0 = 1.0
ts = jnp.arange(t0, t1, dt0)
def cond_fn(state, y, **kwargs):
assert isinstance(state.y, jax.Array)
return y - jnp.exp(1.0)
fn = lambda t, y, args: y
subsaveat_a = dx.SubSaveAt(ts=ts, fn=fn) # save solution regularly
subsaveat_b = dx.SubSaveAt(t1=True) # save last state
saveat = dx.SaveAt(subs=[subsaveat_a, subsaveat_b])
root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = dx.Event(cond_fn, root_finder)
sol = dx.diffeqsolve(
term, solver, t0, t1, dt0, y0, saveat=saveat, event=event
) This runs into an error on line 1204 of ys = jtu.tree_map(lambda _y, _yevent: _y.at[-1].set(_yevent), ys, yevent)
ValueError: Expected list, got Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>. Thanks! |
The main changes are: 1. Added the generic Event class: ``` class Event: event_function: PyTree[EventFn] root_finder: Optional[optx.AbstractRootFinder] = None ``` EventFn is defined as: ``` class EventFn(eqx.Module): cond_fn: Callable[..., Union[BoolScalarLike, RealScalarLike]] transition_fn: Optional[Callable[[PyTree[ArrayLike]], PyTree[ArrayLike]]] = ( lambda x: x ) ```` 2. Added root finding procedure in diffeqsolve to find exact event times that are differentiable. This is only done when root_finder is not None in the given Event class. 3. Added event_mask to the Solution class so that, when multiple event functions are passed, the user can see which one was triggered for a given solve. Hopefully this new event-handling is sufficiently generic to handle all kinds of events in a unified and simple manner. The main benefit is that we can now differentiate through events. So far the current implementation is enough to deal with ODEs, but I suspect more is needed for dealing with SDEs. The new approach reduces to the old approach when passing only one EventFn with a boolean cond_fn and no transition_fn. For now the transition_fn is not used, but it will be useful when adding non-terminating events. Similarly, we might add other attributes to EventFn to distinguish between different types of events. No event cases in root-finding At the end of the root-fining step (L1146 in _integrate.py), I changed: ``` return jtu.tree_map( _call_real, event.event_fn, final_state.event_result, final_state.event_compare, is_leaf=_is_event_fn, ) ``` to ``` results = jtu.tree_map( _call_real, event.event_fn, final_state.event_result, final_state.event_compare, is_leaf=_is_event_fn, ) results_ravel, _ = jfu.ravel_pytree(results) return jnp.where(event_happened, results_ravel, final_state.tprev - t) ``` Thus, if no event occurs the root-find will return tprev as desired. Before call_real() was constantly 0 in this case which caused in error in the root-find. Added EventFn and Event to diffrax/__init__.py Added tests for new event handling I added new tests for the updated event implementation which, apart from the old ones, also checks that the right event time is found in the case where a root-find is called and that the derivatives match the theoretical derivatives. Furthermore, I marked the following tests, that rely on the old event implementation, with @pytest.mark.skip: - test_event.py::test_discrete_terminate1 - test_event.py::test_discrete_terminate2 - test_event.py::test_event_backsolve - test_adjoint.py::test_implicit In order to avoid pyright errors I had to add # pyright: ignore in a few places in the the old test referenced above. Deleted old event implementation I deleted the following two classes: - diffrax._event.DiscreteTerminatingEvent - diffrax._event.SteadyStateEvent These were also removed from the diffrax.__init__.py Minor changes to event hadnling The changes are the following: - Tweaked the event API and got rid of the EventFn class. Now there is only an Event class: ``` class Event(eqx.Module): cond_fn: PyTree[Callable[..., Union[BoolScalarLike, RealScalarLike]]] root_finder: Optional[optx.AbstractRootFinder] = None ``` - Changed the way boolean condition functions are handled in the root finding step. Now instead of calling _bool_event_gradient, we simply return result = final_state.tprev - t. - Removed all cases where jtu.ravel_pytree was used. - Changed "teventprev" to "tprevprev" and "event_compare" to "event_mask" in the State class. - Updated tests.py and __init__.py to reflect the changes. Minor changes for simplicity I slightly changed the initialization of the event attributes in the state in _integrate.py mainly for aesthetic reasons. Made changes according to comments on patrick-kidger#387 No event case Changed it so that the final value of the solve is returned in cases where no event happens instead of evaluating the interpolator.
Previously, updating the last element of ys and ts did not handle the case where multiple `SubSaveAt`s were used. This is now fixed by adding a `jtu.tree_map` in the appropriate place.
Ah, yes I forgot to handle the case where multiple |
Thanks for the quick response! So that fixed the example I posted, however I am still running into issues on slightly more complicated examples more in line with how dynamiqs actually calls import diffrax as dx
import optimistix as optx
import jax
import jax.numpy as jnp
import equinox as eqx
from jax import Array
term = dx.ODETerm(lambda t, y, args: y + t)
solver = dx.Tsit5()
t0 = 0
t1 = 100.0
dt0 = 1.0
y0 = jnp.array([[1.0], [0.0]])
ts = jnp.arange(t0, t1, dt0)
def cond_fn(state, y, **kwargs):
assert isinstance(state.y, jax.Array)
norm = jnp.einsum("ij,ij->", y, y)
return norm - jnp.exp(1.0)
class Saved(eqx.Module):
y: Array
y2: Array
def save_fn(t, y, args):
ynorm = jnp.einsum("ij,ij->", y, y)
return Saved(y, jnp.array([ynorm, 3 * ynorm]))
subsaveat_a = dx.SubSaveAt(ts=ts, fn=save_fn) # save solution regularly
subsaveat_b = dx.SubSaveAt(t1=True) # save last state
saveat = dx.SaveAt(subs=[subsaveat_a, subsaveat_b])
root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = dx.Event(cond_fn, root_finder)
sol = dx.diffeqsolve(
term, solver, t0, t1, dt0, y0, saveat=saveat, event=event
) This runs into ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(2, 2) shape=(2,) Interestingly, the code runs without errors if I replace |
You're right I did not account for the fact that |
Indeed that fixed my MWE! I hate to be such a pain but I am now running into another issue, here is an example that is now much closer to the actual code I am interested in running. import diffrax as dx
import optimistix as optx
import jax.numpy as jnp
L_op = 0.1 * jnp.array([[0.0, 1.0],
[0.0, 0.0]], dtype=complex)
H = 0.0 * L_op
t0 = 0
t1 = 100.0
dt0 = 1.0
y0 = jnp.array([[0.0], [1.0]], dtype=complex)
def vector_field(t, state, _args):
L_d_L = jnp.transpose(L_op) @ L_op
new_state = -1j * (H - 1j * 0.5 * L_d_L) @ state
return new_state
def cond_fn(state, **kwargs):
psi = state.y
prob = jnp.abs(jnp.einsum("id,id->", jnp.conj(psi), psi))
return prob - 0.95
term = dx.ODETerm(vector_field)
root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = dx.Event(cond_fn, root_finder)
sol = dx.diffeqsolve(
term, dx.Tsit5(), t0, t1, dt0, y0, event=event
) This runs into equinox._errors.EqxRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this. It's possible this could be due to my use of complex numbers, which as I understand are only partly supported in diffrax? However with the previous |
Complex number support is definitely still iffy. Can you try reproducing this without using them? You can still solve the same equation, mathematically speaking, just by splitting things into separate real and imaginary components. |
Right, here is the same example using the complex->real isomorphism described e.g. here (see Eq. (9)). I am getting the same error as before, so it seems then this is not a complex number issue import diffrax as dx
import optimistix as optx
import jax.numpy as jnp
def mat_cmp_to_real(matrix):
re_matrix = jnp.real(matrix)
im_matrix = jnp.imag(matrix)
top_row = jnp.hstack((re_matrix, -im_matrix))
bottom_row = jnp.hstack((im_matrix, re_matrix))
return jnp.vstack((top_row, bottom_row))
def vec_cmp_to_real(vector):
re_vec = jnp.real(vector)
im_vec = jnp.imag(vector)
return jnp.vstack((re_vec, im_vec))
L_op = 0.1 * jnp.array([[0.0, 1.0],
[0.0, 0.0]], dtype=complex)
L_d_L = jnp.transpose(L_op) @ L_op
H = 0.0 * L_op
_prop = -1j * (H - 1j * 0.5 * L_d_L)
_y0 = jnp.array([[0.0], [1.0]], dtype=complex)
prop = mat_cmp_to_real(_prop)
y0 = vec_cmp_to_real(_y0)
t0 = 0
t1 = 100.0
dt0 = 1.0
def vector_field(t, state, _args):
new_state = prop @ state
return new_state
def cond_fn(state, **kwargs):
psi = state.y
prob = jnp.abs(jnp.einsum("id,id->", jnp.conj(psi), psi))
return prob - 0.95
term = dx.ODETerm(vector_field)
root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = dx.Event(cond_fn, root_finder)
sol = dx.diffeqsolve(
term, dx.Tsit5(), t0, t1, dt0, y0, event=event
) |
- Semantic change: boolean events now trigger when they become truthy (before they occurred when they swap being falsy<->truthy). Note that this required twiddling around a few things as previously it was impossible for an event to trigger on the first step; now they can. - Semantic change: event functions now have the signature `(t, y, args *, terms, solver, **etc)` for consistency with vector fields and with `SaveAt(fn=...)`. - Feature: now backward-compatible with the old discrete terminating events. - Feature: added `diffrax.steady_state_event`. - Bugfix: the final `t` and `y` from an event are now saved in the correct index of `ts` and `ys`, rather than just always being saved at index `-1`. - Bugfix: at one point `args` referred to the `args` coming from a root find rather than the overall `diffeqsolve`. - Bugfix: the current `state.tprev` was used instead of the previous state's `tnext`. (These are usually but not always the same -- in particular when around jumps.) - Bugfix: added some checks when the condition function of an event does not return a bool/float scalar. - Performance: includes a fastpath for skipping the rootfind if no events are triggered. - Performance: now avoiding tracing for the shape of `dense_info` twice when using adaptive step size controllers alongside events. - Performance: avoided quadratic loop for figuring out what was the first event to trigger. - Chore: added support for the possibility of the final root find (for the time of the event) failing. - Chore: removed some dead code (`_bool_event_gradient`). - Chore: removed references in the docs to the old `discrete_terminating_event`. In addition, some drive-bys: - Fixed warnings about pending deprecations `jnp.clip(..., a_min=..., a_max=...)`. - Had `aux_stats` (in `_integrate.py`) forward to the overall output statistics. In practice this is empty but it's worth doing for the future.
Thankyou @dkweiss31! Anyway, as promised! Getting this in is my next priority. As such I've gone through and submitted a PR against this branch here. I don't claim that everything I've done is necessarily correct, so @cholberg I'd appreciate a review! :D |
@dkweiss31, sorry for being a little unresponsive these last couple of days! I will be quite busy the next 2-3 days, but will try to look at it asap. @patrick-kidger, oh, that's great! But yea, deadline coming up so can't promise that I'll have time to look at it before Thursday. Will make it a priority afterwards, though :) |
@patrick-kidger FWIW both of my examples run without errors using your branch 387-tweaks! Only change I had to make was changing the signature of my conditional function to def cond_fn(t, y, *args, **kwargs): Moreover, the behavior is as expected and the solve terminates at e.g. 0.95. (Something interesting is that if I ask it to save intermediate values, the saved output includes values for times after the termination time. I assume this is as expected, as the code then backtracks to the point where |
I'm glad it now works! I have no idea what I changed to make that happen, although I did have a variety of small bugfixes as well. I think saving anything after the termination time should be considered a bug though :) I knew I must have missed something! Assuming the above MWE demonstrates this, I'll leave this to @cholberg to consider post-NeurIPS-deadline :) |
One more thing (Sorry might just be my bad approach). If i want to have multiple events which happen each at a specific point in time, only the last one passed to the Event class seems to be recognized: 11 def test_continuous_terminate2():
12 term = diffrax.ODETerm(lambda t, y, args: y)
13 solver = diffrax.Tsit5()
14 t0 = 0
15 t1 = jnp.inf
16 dt0 = 1
17 y0 = 1.0
18
19 event_times = [3, 4, 10]
20 cond_fns = [lambda state, **kwargs: state.tprev - t for t in event_times]
21
22 event = diffrax.Event(cond_fn=cond_fns)
23 sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event) In this case Additionally: Also i get an error in this case if the time dependent event happens (t=9) before the state-value dependent event (t=10): 11 def test_continuous_two_events():
12 term = diffrax.ODETerm(lambda t, y, args: 1.0)
13 solver = diffrax.Tsit5()
14 t0 = 0
15 t1 = jnp.inf
16 dt0 = 1.0
17 y0 = -10.0
18
19 event_time = 9
20
21 def cond_fn_1(state, y, **kwargs):
22 assert isinstance(state.y, jax.Array)
23 return y
24
25 def cond_fn_2(state, y, **kwargs):
26 assert isinstance(state.y, jax.Array)
27 return state.tprev - event_time
28
29
30 root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
31 event = diffrax.Event([cond_fn_1, cond_fn_2], root_finder)
32 sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event) -> Why do you think is this the case and what can I do to change this? I assume if I just pick t1 as the event time it still has the problem that it could overshoot in time by some value between 0 and the step size? |
For the second one that's probably on us. For the first one, you are running afoul of https://docs.astral.sh/ruff/rules/function-uses-loop-variable/ |
Cool thanks! Wow I didnt know about this :D thanks for the help! |
@dkweiss31, @patrick-kidger: I ran your MWE above with |
@LuggiStruggi: With the latest commits, your second example seems to work for me :) |
cool thanks :) sorry maybe i didnt pull recently enough |
Hey friends, pulling from the most recent version of cholberg:dev the following example fails: import diffrax as dx
import optimistix as optx
import jax.numpy as jnp
import equinox as eqx
from jax import Array
L_op = 0.1 * jnp.array([[0.0, 1.0],
[0.0, 0.0]], dtype=complex)
t0 = 0
t1 = 100.0
dt0 = 1.0
y0 = jnp.array([[0.0], [1.0]], dtype=complex)
ts = jnp.arange(t0, t1, dt0)
class Saved(eqx.Module):
y: Array
def save_fn(t, y, args):
ynorm = jnp.einsum("ij,ij->", y, y)
return Saved(jnp.array([ynorm,]))
subsaveat_a = dx.SubSaveAt(ts=ts, fn=save_fn)
subsaveat_b = dx.SubSaveAt(t1=True)
saveat = dx.SaveAt(subs=[subsaveat_a, subsaveat_b])
def vector_field(t, state, _args):
L_d_L = jnp.transpose(L_op) @ L_op
new_state = -1j * (- 1j * 0.5 * L_d_L) @ state
return new_state
def cond_fn(t, y, *args, **kwargs):
psi = y
prob = jnp.abs(jnp.einsum("id,id->", jnp.conj(psi), psi))
return prob - 0.95
term = dx.ODETerm(vector_field)
root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = dx.Event(cond_fn, root_finder)
sol = dx.diffeqsolve(
term, dx.Tsit5(), t0, t1, dt0, y0, event=event, saveat=saveat
) on line 715 of save_index = final_state.save_state.save_index - 1
AttributeError: 'list' object has no attribute 'save_index' |
Updates to how events are handled in diffrax. The main changes are:
Event
.cond_fn
are supported. An event is triggered whenever one of them changes sign.cond_fn
and a root_finder is provided.Some things that still might require a little thinking about: