
Make custom integrators more similar to OrdinaryDiffEq.jl integrators? #1886

Open
DanielDoehring opened this issue Mar 22, 2024 · 5 comments
Labels
consistency Make Michael happy refactoring Refactoring code without functional changes

Comments

@DanielDoehring
Contributor

DanielDoehring commented Mar 22, 2024

While looking at elixir_euleracoustics_co-rotating_vortex_pair.jl, I noticed that this elixir can only be run with the integrators from OrdinaryDiffEq.jl, because the EulerAcousticsCouplingCallback requires a step! function

@trixi_timeit timer() "Euler solver" step!(integrator_euler)

and an init function

integrator_euler = init(ode_euler, alg, save_everystep = false, dt = 1.0; kwargs...) # dt will be overwritten

neither of which is provided by our existing custom time integrator implementations.

We could, however, add these functions relatively easily, as exemplified here for SimpleIntegrator2N:

Essentially, we would need to provide an init function

function init(ode::ODEProblem, alg::T;
               dt, callback = nothing, kwargs...) where {T <: SimpleAlgorithm2N}
    u = copy(ode.u0)
    du = similar(u)
    u_tmp = similar(u)
    t = first(ode.tspan)
    iter = 0
    integrator = SimpleIntegrator2N(u, du, u_tmp, t, dt, zero(dt), iter, ode.p,
                                    (prob = ode,), ode.f, alg,
                                    SimpleIntegrator2NOptions(callback, ode.tspan;
                                                              kwargs...), false)

    # initialize callbacks
    if callback isa CallbackSet
        foreach(callback.continuous_callbacks) do cb
            error("unsupported")
        end
        foreach(callback.discrete_callbacks) do cb
            cb.initialize(cb, integrator.u, integrator.t, integrator)
        end
    elseif !isnothing(callback)
        error("unsupported")
    end

    return integrator
end

which is essentially the current solve function with only the last line changed (returning the integrator instead of calling solve!(integrator)):

function solve(ode::ODEProblem, alg::T;
               dt, callback = nothing, kwargs...) where {T <: SimpleAlgorithm2N}
    u = copy(ode.u0)
    du = similar(u)
    u_tmp = similar(u)
    t = first(ode.tspan)
    iter = 0
    integrator = SimpleIntegrator2N(u, du, u_tmp, t, dt, zero(dt), iter, ode.p,
                                    (prob = ode,), ode.f, alg,
                                    SimpleIntegrator2NOptions(callback, ode.tspan;
                                                              kwargs...), false)

    # initialize callbacks
    if callback isa CallbackSet
        foreach(callback.continuous_callbacks) do cb
            error("unsupported")
        end
        foreach(callback.discrete_callbacks) do cb
            cb.initialize(cb, integrator.u, integrator.t, integrator)
        end
    elseif !isnothing(callback)
        error("unsupported")
    end

    solve!(integrator)
end

Then, the step! function could be implemented as

function step!(integrator::SimpleIntegrator2N)
    @unpack prob = integrator.sol
    @unpack alg = integrator
    t_end = last(prob.tspan)
    callbacks = integrator.opts.callback

    @assert !integrator.finalstep
    if isnan(integrator.dt)
        error("time step size `dt` is NaN")
    end

    # if the next iteration would push the simulation beyond the end time, set dt accordingly
    if integrator.t + integrator.dt > t_end ||
       isapprox(integrator.t + integrator.dt, t_end)
        integrator.dt = t_end - integrator.t
        terminate!(integrator)
    end

    # one time step
    integrator.u_tmp .= 0
    for stage in eachindex(alg.c)
        t_stage = integrator.t + integrator.dt * alg.c[stage]
        integrator.f(integrator.du, integrator.u, prob.p, t_stage)

        a_stage = alg.a[stage]
        b_stage_dt = alg.b[stage] * integrator.dt
        @trixi_timeit timer() "Runge-Kutta step" begin
            @threaded for i in eachindex(integrator.u)
                integrator.u_tmp[i] = integrator.du[i] -
                                      integrator.u_tmp[i] * a_stage
                integrator.u[i] += integrator.u_tmp[i] * b_stage_dt
            end
        end
    end
    integrator.iter += 1
    integrator.t += integrator.dt

    # handle callbacks
    if callbacks isa CallbackSet
        foreach(callbacks.discrete_callbacks) do cb
            if cb.condition(integrator.u, integrator.t, integrator)
                cb.affect!(integrator)
            end
            return nothing
        end
    end

    # respect maximum number of iterations
    if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep
        @warn "Interrupted. Larger maxiters is needed."
        terminate!(integrator)
    end
end

which is almost identical to the body of the main loop in the current solve! function:

function solve!(integrator::SimpleIntegrator2N)
    @unpack prob = integrator.sol
    @unpack alg = integrator
    t_end = last(prob.tspan)
    callbacks = integrator.opts.callback

    integrator.finalstep = false
    @trixi_timeit timer() "main loop" while !integrator.finalstep
        if isnan(integrator.dt)
            error("time step size `dt` is NaN")
        end

        # if the next iteration would push the simulation beyond the end time, set dt accordingly
        if integrator.t + integrator.dt > t_end ||
           isapprox(integrator.t + integrator.dt, t_end)
            integrator.dt = t_end - integrator.t
            terminate!(integrator)
        end

        # one time step
        integrator.u_tmp .= 0
        for stage in eachindex(alg.c)
            t_stage = integrator.t + integrator.dt * alg.c[stage]
            integrator.f(integrator.du, integrator.u, prob.p, t_stage)

            a_stage = alg.a[stage]
            b_stage_dt = alg.b[stage] * integrator.dt
            @trixi_timeit timer() "Runge-Kutta step" begin
                @threaded for i in eachindex(integrator.u)
                    integrator.u_tmp[i] = integrator.du[i] -
                                          integrator.u_tmp[i] * a_stage
                    integrator.u[i] += integrator.u_tmp[i] * b_stage_dt
                end
            end
        end
        integrator.iter += 1
        integrator.t += integrator.dt

        # handle callbacks
        if callbacks isa CallbackSet
            foreach(callbacks.discrete_callbacks) do cb
                if cb.condition(integrator.u, integrator.t, integrator)
                    cb.affect!(integrator)
                end
                return nothing
            end
        end

        # respect maximum number of iterations
        if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep
            @warn "Interrupted. Larger maxiters is needed."
            terminate!(integrator)
        end
    end

    return TimeIntegratorSolution((first(prob.tspan), integrator.t),
                                  (prob.u0, integrator.u),
                                  integrator.sol.prob)
end
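Both step! and solve! shorten the final step so that the simulation ends exactly at t_end. The isapprox branch matters because of floating-point accumulation: after nine steps of 0.1, t is 0.8999999999999999 rather than 0.9, so a check based on > alone would take a regular step to 0.9999999999999999 and then a tiny extra step of roughly 1e-16. The helper below is hypothetical (clamp_final_step is not part of Trixi.jl) and only isolates that check for illustration.

```julia
# Hypothetical helper (not part of Trixi.jl) isolating the end-time check used in
# step!/solve!: returns the step size to use and whether this is the final step.
function clamp_final_step(t, dt, t_end)
    if t + dt > t_end || isapprox(t + dt, t_end)
        # shorten (or slightly lengthen) the step so that t + dt == t_end
        return t_end - t, true
    end
    return dt, false
end

# A regular step far from the end time is left unchanged.
println(clamp_final_step(0.0, 0.1, 1.0))

# A step that would overshoot t_end is shortened to the remaining time.
println(clamp_final_step(0.95, 0.1, 1.0))

# t + dt lands just below t_end due to rounding; isapprox still marks it as final.
println(clamp_final_step(0.8999999999999999, 0.1, 1.0))
```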

With init and step! available, solve could then be implemented as

function solve(ode::ODEProblem, alg::T;
               dt, callback = nothing, kwargs...) where {T <: SimpleAlgorithm2N}
    # set up the integrator (including callback initialization) without taking a step
    integrator = init(ode, alg; dt = dt, callback = callback, kwargs...)

    @unpack prob = integrator.sol

    integrator.finalstep = false

    # advance one time step at a time until the final step has been taken
    @trixi_timeit timer() "main loop" while !integrator.finalstep
        step!(integrator)
    end # "main loop" timer

    return TimeIntegratorSolution((first(prob.tspan), integrator.t),
                                  (prob.u0, integrator.u),
                                  integrator.sol.prob)
end

which behaves as before.
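To illustrate the init/step!/solve split outside of Trixi.jl, here is a minimal, self-contained sketch. All names (ToyIntegrator2N, toy_init, toy_step!, toy_solve) are hypothetical and this is not the proposed Trixi.jl code: it handles a scalar ODE du/dt = f(u, t), omits callbacks and maxiters, and uses Williamson's three-stage, third-order 2N-storage coefficients (an assumption for illustration) with the same u_tmp/u update as step! above, so the sign of a is flipped relative to Williamson's usual A.

```julia
# Hypothetical toy (not Trixi.jl code): the init/step!/solve split applied to a
# scalar ODE du/dt = f(u, t) with a 2N-storage Runge-Kutta scheme.
mutable struct ToyIntegrator2N{F}
    u::Float64          # current solution
    u_tmp::Float64      # second storage register of the 2N scheme
    t::Float64          # current time
    dt::Float64         # step size (may be adjusted for the final step)
    t_end::Float64      # final time
    f::F                # right-hand side f(u, t)
    iter::Int           # number of completed steps
    finalstep::Bool     # true once the final step has been taken
end

# Williamson's 3-stage, 3rd-order 2N-storage coefficients; `a` has its sign flipped
# to match the update `u_tmp = du - u_tmp * a` used in step! above.
const TOY_A = (0.0, 5 / 9, 153 / 128)
const TOY_B = (1 / 3, 15 / 16, 8 / 15)
const TOY_C = (0.0, 1 / 3, 3 / 4)

# Set up the integrator state without taking any step (cf. `init`).
function toy_init(f, u0, tspan; dt)
    return ToyIntegrator2N(float(u0), 0.0, float(first(tspan)), float(dt),
                           float(last(tspan)), f, 0, false)
end

# Advance by exactly one time step (cf. `step!`).
function toy_step!(integrator::ToyIntegrator2N)
    @assert !integrator.finalstep
    # adjust the last step so that the integrator lands exactly on t_end
    if integrator.t + integrator.dt > integrator.t_end ||
       isapprox(integrator.t + integrator.dt, integrator.t_end)
        integrator.dt = integrator.t_end - integrator.t
        integrator.finalstep = true
    end
    integrator.u_tmp = 0.0
    for stage in eachindex(TOY_C)
        t_stage = integrator.t + integrator.dt * TOY_C[stage]
        du = integrator.f(integrator.u, t_stage)
        integrator.u_tmp = du - integrator.u_tmp * TOY_A[stage]
        integrator.u += integrator.u_tmp * TOY_B[stage] * integrator.dt
    end
    integrator.iter += 1
    integrator.t += integrator.dt
    return integrator
end

# Full solve = init + repeated step! (cf. the refactored `solve`).
function toy_solve(f, u0, tspan; dt)
    integrator = toy_init(f, u0, tspan; dt = dt)
    while !integrator.finalstep
        toy_step!(integrator)
    end
    return integrator
end

# Whole solve of du/dt = -u on [0, 1]; compare with the exact value exp(-1).
sol = toy_solve((u, t) -> -u, 1.0, (0.0, 1.0); dt = 0.1)
println((sol.t, sol.iter, sol.u, exp(-1)))

# Step-by-step use, in the spirit of calling step!(integrator_euler) from a
# callback: advance only up to t = 0.5 and keep the integrator for later.
integrator = toy_init((u, t) -> -u, 1.0, (0.0, 1.0); dt = 0.1)
while integrator.t < 0.5 - 1e-12
    toy_step!(integrator)
end
println((integrator.t, integrator.iter))
```

Since toy_solve is nothing more than toy_init followed by a toy_step! loop, the same integrator object can also be driven to an intermediate time and resumed later, which is the capability the coupling callback relies on.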

@DanielDoehring DanielDoehring added consistency Make Michael happy refactoring Refactoring code without functional changes labels Mar 22, 2024
@ranocha
Member

ranocha commented Mar 22, 2024

I originally implemented it the way it is now to keep it as simple as possible, while keeping the option to use the same functionality we need from the integrators provided by OrdinaryDiffEq.jl. I would be fine with these changes, but @sloede needs to agree as well (since they make the implementation more complex).

To be able to use custom integrators also for the cases you mentioned, we would need to specialize init, solve! etc. from https://github.com/SciML/CommonSolve.jl. This will be a real change since we use Trixi.solve right now with custom time integrators instead of the common solve version.

@DanielDoehring
Contributor Author

DanielDoehring commented Mar 22, 2024

To be able to use custom integrators also for the cases you mentioned, we would need to specialize init, solve! etc. from https://github.com/SciML/CommonSolve.jl. This will be a real change since we use Trixi.solve right now with custom time integrators instead of the common solve version.

I think that if we really wanted to use solve from OrdinaryDiffEq.jl, a lot more would have to be implemented, which I consider unnecessary (at least for now). Thus, I would stick with Trixi.solve and the implementation presented above.

@ranocha
Member

ranocha commented Mar 22, 2024

But we would also need the same step! function as OrdinaryDiffEq.jl without depending on OrdinaryDiffEq.jl, or some special handling in the functions where we use it.
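The underlying Julia rule is that a method always extends one particular generic function owned by some module. A step! defined inside Trixi.jl without importing the owning function would therefore be a separate function that shares only the name, and code calling the step! that OrdinaryDiffEq.jl integrators implement would never reach it. The modules LibA, LibB, and LibC below are hypothetical stand-ins (not real packages) sketching this.

```julia
# LibA owns a generic function `step!` and a helper that calls it, much like a
# callback that calls `step!` on whatever integrator it is given.
module LibA
function step! end                        # generic function with no methods yet
advance!(integrator) = step!(integrator)  # always calls LibA.step!
end

# LibB only reuses the name: this creates a *new* function LibB.step!.
module LibB
mutable struct Integrator
    iter::Int
end
step!(integrator::Integrator) = (integrator.iter += 1; integrator)
end

# LibC extends LibA's function, so LibA.advance! can dispatch to this method.
module LibC
import ..LibA
mutable struct Integrator
    iter::Int
end
LibA.step!(integrator::Integrator) = (integrator.iter += 1; integrator)
end

c = LibA.advance!(LibC.Integrator(0))
println(c.iter)

# LibB's method is invisible to LibA.advance!, so this raises a MethodError.
b_fails = try
    LibA.advance!(LibB.Integrator(0))
    false
catch err
    err isa MethodError
end
println(b_fails)
```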

@sloede
Member

sloede commented Mar 23, 2024

IIUC, you do not want to make the implementation much more complicated, but just refactor solve and solve! into init and step!, right? Thus, if you were to implement it as described above, maybe with one or two additional in-source comments that make it easier to understand for novices, I wouldn't be opposed.

To some extent, time integration is black magic anyway, and not that many people need to deal with its nitty-gritty details unless they need something special; in that case, more modularity is probably helpful.

@DanielDoehring
Contributor Author

IIUC, you do not want to make the implementation much more complicated, but just refactor solve and solve! into init and step!, right?

Yes, that is right!

Thus, if you were to implement it as described above, maybe with one or two additional in-source comments that make it easier to understand for novices, I wouldn't be opposed.

I actually think that the more explicit version could even help illustrate that it is not the ODE algorithm but the ODE integrator that actually solves the problem.


3 participants