Complex input in diffeqsolve with PIDController #389

Open
Ricky5389 opened this issue Mar 15, 2024 · 3 comments · Fixed by #391
Labels
bug Something isn't working question User queries

Comments

@Ricky5389

Hi, I encountered an issue while using the PIDController. When using complex input in diffeqsolve, I cannot change the coefficients of the PIDController. The default coefficients work fine, but when I change them I get an error, which leads me to believe that the time variable is converted to a complex type somewhere.
Here is an MWE to reproduce the error:

```python
# %% ==========================================================================
# Imports
# =============================================================================
import jax.numpy as jnp
import diffrax as dx
from jaxtyping import Scalar

# %% ==========================================================================
# Smallest working example, solving a complex ODE with diffrax
# ODE to solve is dy/dt = iy with y a complex number
# Setting up the ODE
# =============================================================================
def vector_field(t: Scalar, y: Scalar, *args):
    dotY = 1j * y
    return dotY

tsave = jnp.linspace(0.0, 10.0, 1000)
sim_time = tsave[-1]

solver = dx.Tsit5()
saveat = dx.SaveAt(ts=tsave)
y0 = 1.0j
# %% ==========================================================================
# Solve using the default PIDController coefficients (pcoeff=0, icoeff=1, dcoeff=0): works
# =============================================================================
stepsize_controller = dx.PIDController(rtol=1e-6, atol=1e-6, pcoeff=0, icoeff=1, dcoeff=0.0)

term = dx.ODETerm(vector_field=vector_field)
res_dx = dx.diffeqsolve(term, solver, t0=0.0, t1=sim_time, dt0=0.01, y0=y0, saveat=saveat,
                        stepsize_controller=stepsize_controller)

# %% ==========================================================================
# Solve using a non-zero P coefficient: raises the error
# =============================================================================
stepsize_controller = dx.PIDController(rtol=1e-6, atol=1e-6, pcoeff=0.3, icoeff=0.3, dcoeff=0.0)

term = dx.ODETerm(vector_field=vector_field)
res_dx = dx.diffeqsolve(term, solver, t0=0.0, t1=sim_time, dt0=0.01, y0=y0, saveat=saveat,
                        stepsize_controller=stepsize_controller)
```


Changing the coefficients of the PID controller raises the following error:

```
ValueError: `body_fun` must have the same input and output structure. Difference is:
  State(
    y=c128[],
    tprev=f64[],
-   tnext=f64[],
+   tnext=c128[],
    made_jump=bool[],
    solver_state=(bool[], c128[]),
-   controller_state=(bool[], bool[], f64[], c128[], c128[]),
+   controller_state=(bool[], bool[], c128[], c128[], c128[]),
    result=EnumerationItem(
      _value=i32[],
      _enumeration=<class 'diffrax._solution.RESULTS'>
    ),
    num_steps=i64[],
    num_accepted_steps=i64[],
    num_rejected_steps=i64[],
    save_state=SaveState(
      saveat_ts_index=i64[],
      ts=_Buffer(
        _array=f64[1000],
        _pred=bool[],
        _tag=<object object at 0x2e4dc0850>,
        _makes_false_steps=False
      ),
      ys=_Buffer(
        _array=c128[1000],
        _pred=bool[],
        _tag=<object object at 0x2e4dc0850>,
        _makes_false_steps=False
      ),
      save_index=i64[]
    ),
    dense_ts=None,
    dense_infos=None,
    dense_save_index=None
  )
```
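The error diff shows `tnext` and one entry of `controller_state` switching from `f64` to `c128` between loop iterations, which is consistent with the guess that a time-like quantity gets promoted to complex. The snippet below is only a hypothetical illustration of that kind of dtype promotion in JAX; it is not Diffrax's actual controller code, and the names `t`, `dt`, `err` and `t_next` are made up for the example.

```python
import jax.numpy as jnp

t = jnp.asarray(0.0)  # real current time
dt = jnp.asarray(0.01)  # real step size
# Hypothetical error estimate that ends up complex because the state y is complex.
err = jnp.asarray(1e-7 + 0j)

# A step-size factor computed from a complex quantity is itself complex...
factor = err ** -0.3
# ...and multiplying a real step by it promotes the next time to complex.
t_next = t + dt * factor

print(dt.dtype, t_next.dtype)
```

Inside a `jax.lax.while_loop`, a promotion like this changes the dtype of the loop carry from one iteration to the next, and that is exactly what the "`body_fun` must have the same input and output structure" check rejects.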
@patrick-kidger
Owner

Right! Complex numbers are only partially supported right now. The main blocker has been waiting on an XLA bug fix, which I'm happy to say has recently landed.

Regardless, right now, Diffrax has pretty weak support for complex numbers. You should be able to use them within your vector field, as long as you decompose them into real and imaginary parts whenever you interface with Diffrax.

patrick-kidger added the bug (Something isn't working) and question (User queries) labels on Mar 15, 2024
@Randl
Contributor

Randl commented Apr 22, 2024

@Ricky5389, can you try the latest dev version? It should work there.

@Ricky5389
Author

Yes, it works.
Thank you!
