Hi, I encountered an issue while using the PIDController. When using complex input in diffeqsolve, I cannot change the coefficients of the PIDController. The default coefficients work fine, but when I try to change them I get an error, which leads me to believe that the time variable is converted to a complex type somewhere.
Here is an MWE to reproduce the error:
```python
# %% ==========================================================================
# Imports
# =============================================================================
import jax.numpy as jnp
import diffrax as dx
from jaxtyping import Scalar

# %% ==========================================================================
# Smallest working example, solving a complex ODE with diffrax
# ODE to solve is dy/dt = iy with y a complex number
# Setting up the ODE
# =============================================================================
def vector_field(t: Scalar, y: Scalar, *args):
    dotY = 1j * y
    return dotY


tsave = jnp.linspace(0.0, 10.0, 1000)
sim_time = tsave[-1]
solver = dx.Tsit5()
saveat = dx.SaveAt(ts=tsave)
y0 = 1.0j

# %% ==========================================================================
# Solve using the default PIDController coefficients
# =============================================================================
stepsize_controller = dx.PIDController(rtol=1e-6, atol=1e-6, pcoeff=0, icoeff=1, dcoeff=0.0)
term = dx.ODETerm(vector_field=vector_field)
res_dx = dx.diffeqsolve(term, solver, t0=0.0, t1=sim_time, dt0=0.01, y0=y0, saveat=saveat,
                        stepsize_controller=stepsize_controller)

# %% ==========================================================================
# Solve using the non-zero P coefficient
# =============================================================================
stepsize_controller = dx.PIDController(rtol=1e-6, atol=1e-6, pcoeff=0.3, icoeff=0.3, dcoeff=0.0)
term = dx.ODETerm(vector_field=vector_field)
res_dx = dx.diffeqsolve(term, solver, t0=0.0, t1=sim_time, dt0=0.01, y0=y0, saveat=saveat,
                        stepsize_controller=stepsize_controller)
## Changing the coefficient in the PID controller seems to raise the error
```
```
ValueError: `body_fun` must have the same input and output structure. Difference is:
State(
  y=c128[],
  tprev=f64[],
- tnext=f64[],
+ tnext=c128[],
  made_jump=bool[],
  solver_state=(bool[], c128[]),
- controller_state=(bool[], bool[], f64[], c128[], c128[]),
+ controller_state=(bool[], bool[], c128[], c128[], c128[]),
  result=EnumerationItem(
    _value=i32[],
    _enumeration=<class 'diffrax._solution.RESULTS'>
  ),
  num_steps=i64[],
  num_accepted_steps=i64[],
  num_rejected_steps=i64[],
  save_state=SaveState(
    saveat_ts_index=i64[],
    ts=_Buffer(
      _array=f64[1000],
      _pred=bool[],
      _tag=<object object at 0x2e4dc0850>,
      _makes_false_steps=False
    ),
    ys=_Buffer(
      _array=c128[1000],
      _pred=bool[],
      _tag=<object object at 0x2e4dc0850>,
      _makes_false_steps=False
    ),
    save_index=i64[]
  ),
  dense_ts=None,
  dense_infos=None,
  dense_save_index=None
)
```
Right! Complex numbers are only partially supported right now. The main blocker has been an XLA bug that we've been waiting on a fix for, although I'm happy to say that it has recently been fixed.
Regardless, Diffrax currently has fairly weak support for complex numbers. You should be able to use them within your vector field, as long as you decompose them into real and imaginary parts wherever you interface with Diffrax.