inf values after triggering event function. #335

Open
KhayrullevJokhongir opened this issue Nov 29, 2023 · 4 comments
Labels
question User queries

Comments


KhayrullevJokhongir commented Nov 29, 2023

I am solving a simple problem below using DiscreteTerminatingEvent. Once the event is triggered, the integration stops, but the solver returns 'inf' values for the time steps following the event's trigger time. Is there a way to avoid this, so that the solver returns function evaluations only for the time steps before the event-trigger time, similar to how solve_ivp in SciPy does?

"import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5, Dopri5, DiscreteTerminatingEvent

def vector_field(t, y, args):
    # Lotka-Volterra predator-prey dynamics.
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    return jnp.array([d_prey, d_predator])

# Define the terminating event function with two conditions.
def terminating_event_fxn(state, args, **kwargs):
    prey_population = state.y[0]
    predator_population = state.y[1]

    # Stop when the prey drops below 5 or the predator exceeds 15.
    return (prey_population < 5) | (predator_population > 15)

# Set up the ODE term, solver, and the initial conditions.
term = ODETerm(vector_field)
solver = Dopri5()
t0 = 0
t1 = 140
dt0 = 0.1
y0 = jnp.array([10.0, 10.0])
args = (0.1, 0.02, 0.4, 0.02)
saveat = SaveAt(ts=jnp.linspace(t0, t1, 1000))

# Define the terminating event.
terminating_event = DiscreteTerminatingEvent(terminating_event_fxn)

# Solve the ODE with the terminating event.
sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args, saveat=saveat, discrete_terminating_event=terminating_event)

# Plot the results.
plt.plot(sol.ts, sol.ys[:, 0], label="Prey")
plt.plot(sol.ts, sol.ys[:, 1], label="Predator")
plt.legend()
plt.show()

print(sol.ys[:, 0].size)
print(sol.ts.shape)

"

patrick-kidger (Owner) commented

I'm afraid not. All JAX arrays have to have a size known at compile time, but the time of the event isn't known until runtime. As such, Diffrax works by initialising an array of the appropriate size (here, the length of saveat.ts) filled with inf, and then fills in this array as the integration progresses. Entries for save times after the event are never written, so they stay inf.

I hope that helps! :)
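
One way to get a solve_ivp-style result is to drop the padded rows after the solve, outside of any jitted code, where array sizes no longer need to be static. The sketch below uses hand-made stand-in arrays rather than a real solution object, and assumes the padded rows of sol.ys are entirely inf, as described above. The names `ts` and `ys` are illustrative substitutes for `np.asarray(sol.ts)` and `np.asarray(sol.ys)`.

```python
import numpy as np

# Stand-ins for sol.ts and sol.ys: five requested save times, with the
# event assumed to fire after the third one, so the last two rows are padding.
ts = np.array([0.0, 1.0, 2.0, np.inf, np.inf])
ys = np.array([
    [10.0, 10.0],
    [9.0, 11.0],
    [8.0, 12.0],
    [np.inf, np.inf],
    [np.inf, np.inf],
])

# A row counts as saved only if every component of the state is finite.
valid = np.isfinite(ys).all(axis=1)

# Boolean indexing changes the array length, which is fine in eager NumPy
# but would not work inside jax.jit.
ts_trimmed = ts[valid]
ys_trimmed = ys[valid]

print(ts_trimmed.shape, ys_trimmed.shape)
```

The trimmed arrays can then be passed straight to the plotting calls in place of sol.ts and sol.ys.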

patrick-kidger added the question (User queries) label Nov 29, 2023
KhayrullevJokhongir (Author) commented

It helps, thank you for the quick reply :)

KhayrullevJokhongir (Author) commented

Can I get some other values instead of inf, for example the state of the system at the last time step before the event was triggered?

patrick-kidger (Owner) commented

Not via Diffrax, but you could probably write some logic of your own to do that afterwards.
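
A minimal sketch of what such post-processing might look like: replace every padded row with the last saved row. `fill_after_event` is a hypothetical helper, not part of Diffrax. It assumes the saved rows form a contiguous prefix of the array and that at least one row was saved before the event. Note that the "last" state here is the state at the last requested save time before termination, which is not necessarily the solver's final internal step.

```python
import numpy as np

def fill_after_event(ys):
    """Replace inf-padded rows with the last finite row (hypothetical helper).

    Assumes the finite rows form a contiguous prefix of `ys` and that
    there is at least one finite row.
    """
    ys = np.asarray(ys)
    valid = np.isfinite(ys).all(axis=1)
    # Index of the last row that was actually written.
    last = valid.sum() - 1
    # Keep saved rows as they are; broadcast the last saved row over the rest.
    return np.where(valid[:, None], ys, ys[last])

# Stand-in for sol.ys: three saved rows followed by two padded rows.
ys = np.array([
    [10.0, 10.0],
    [9.0, 11.0],
    [8.0, 12.0],
    [np.inf, np.inf],
    [np.inf, np.inf],
])

ys_filled = fill_after_event(ys)
print(ys_filled)
```

Because `np.where` keeps the output the same shape as the input, the same expressions should carry over to jax.numpy if the filling needs to happen inside jitted code, though that variant is untested here.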
