Pass full integrator instead of parameters #116

Open
devmotion opened this issue Jul 2, 2019 · 5 comments

@devmotion
Member

As discussed in SciML/DiffEqProblemLibrary.jl#39, especially for the history function it seems reasonable to pass the full integrator as an argument instead of only the parameters, i.e., to have h(integrator, t) instead of h(p, t) and also f(u, h, integrator, t) instead of f(u, h, p, t). This would enable users to write generic history functions with correct output types (see the discussion in the PR) and would hopefully allow us to simplify the implementation in DelayDiffEq.
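To make the "correct output types" argument concrete, here is a minimal sketch under one assumption: that the problem is a history function whose return type should follow the type of the state `u` (e.g. `Float32` rather than `Float64`). `StateIntegrator`, `h_params`, and `h_integrator` are hypothetical names, not DelayDiffEq API.

```julia
# Hypothetical toy integrator; only the fields used below are modeled.
struct StateIntegrator{U,P}
    u::U   # current state
    p::P   # parameters
end

# History that only receives `p`: its element type is fixed by the literal
# and cannot adapt to the element type of the state.
h_params(p, t) = [0.0, 0.0]

# History that receives the integrator: its output follows the state's type.
h_integrator(integrator, t) = zero(integrator.u)

integrator = StateIntegrator(Float32[1, 2], nothing)
eltype(h_params(integrator.p, 0.0))    # Float64
eltype(h_integrator(integrator, 0.0))  # Float32
```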

According to @ChrisRackauckas,

> we should have a common arg for using integrator instead of p, and then we just need to make every package handle that well.

I think we should approach this issue slightly differently. A user already has to decide whether to pass around the integrator or only the parameters when implementing f (or h), i.e., this is a property not of the numerical algorithm but of the differential equation function. Hence I guess it would make sense to handle this issue by modifying the DiffEqFunctions instead of the individual algorithms. We could replace

```julia
abstract type AbstractDiffEqFunction{iip} <: Function end
```

with

```julia
abstract type AbstractDiffEqFunction{iip,unpackparams} <: Function end
```

and then define, e.g.,

```julia
(f::ODEFunction{true,unpackparams})(du, u, integrator, t) where unpackparams = unpackparams ? f.f(du, u, get_p(integrator), t) : f.f(du, u, integrator, t)
```

That way, we would only have to implement get_p for every integrator (returning integrator.p by default) and could always pass the integrator in every package.
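A self-contained sketch of how this dispatch could behave, using toy types instead of the real DiffEqBase ones. `AbstractToyDiffEqFunction`, `MyODEFunction`, `ParamIntegrator`, and this `get_p` are hypothetical stand-ins; the point is only that solvers always pass the integrator and the `unpackparams` type parameter decides what the user's function receives.

```julia
# Hypothetical stand-ins for the proposed types; not the DiffEqBase definitions.
abstract type AbstractToyDiffEqFunction{iip,unpackparams} <: Function end

struct MyODEFunction{iip,unpackparams,F} <: AbstractToyDiffEqFunction{iip,unpackparams}
    f::F
end

# `iip`: in-place or not; `unpackparams`: does the user's `f` expect `p` (true)
# or the full integrator (false)?
MyODEFunction{iip,unpackparams}(f) where {iip,unpackparams} =
    MyODEFunction{iip,unpackparams,typeof(f)}(f)

struct ParamIntegrator{P}
    p::P
end

# Default parameter accessor that every integrator would provide.
get_p(integrator) = integrator.p

# Solvers always pass the integrator; the type parameter decides what `f` sees.
(f::MyODEFunction{true,unpackparams})(du, u, integrator, t) where {unpackparams} =
    unpackparams ? f.f(du, u, get_p(integrator), t) : f.f(du, u, integrator, t)

(f::MyODEFunction{false,unpackparams})(u, integrator, t) where {unpackparams} =
    unpackparams ? f.f(u, get_p(integrator), t) : f.f(u, integrator, t)

integrator = ParamIntegrator(2.0)
f_p = MyODEFunction{false,true}((u, p, t) -> -p * u)           # expects p
f_int = MyODEFunction{false,false}((u, int, t) -> -int.p * u)  # expects the integrator
f_p(1.0, integrator, 0.0)    # -2.0
f_int(1.0, integrator, 0.0)  # -2.0

du = zeros(1)
g = MyODEFunction{true,true}((du, u, p, t) -> (du[1] = -p * u[1]; nothing))
g(du, [3.0], integrator, 0.0)
du[1]                        # -6.0
```

Since `unpackparams` is a type parameter, the branch depends only on the type of `f`, which is what should let the compiler remove it.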

@ChrisRackauckas
Member

Yes, this makes sense. I am a little worried about compile times, but maybe it all just quickly compiles away.

@devmotion
Member Author

devmotion commented Jul 2, 2019

Yes, hopefully the compiler is smart enough.

However, there's another issue: in the same way, we have to decide whether to pass the integrator to the tgrad, analytic, etc. functions. Of course, this could be ensured at the level of the DiffEqFunction, in the same way as in the example above, by using overloads such as f(Val{:analytic}, ...). But since we switched away from this form, I guess that's not a good idea 😄

Alternatively, one could define functions such as

```julia
function DiffEqBase.analytic(f::ODEFunction{iip,unpack}, u, integrator, t) where {iip,unpack}
    has_analytic(f) || error("analytical solution is not defined")

    unpack ? f.analytic(u, get_p(integrator), t) : f.analytic(u, integrator, t)
end
```

for all such overloads, but I don't know if this makes any difference.
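A runnable toy version of this per-field overload, assuming the function object stores `analytic` as a field that is `nothing` when absent. `AnalyticODEFunction`, `AnalyticIntegrator`, and the local `has_analytic`/`analytic` functions are hypothetical stand-ins, not the DiffEqBase functions of the same names.

```julia
# Hypothetical stand-ins; `has_analytic` and `analytic` here are local toy
# functions, not the DiffEqBase ones.
struct AnalyticODEFunction{iip,unpack,F,A}
    f::F
    analytic::A   # `nothing` if no analytical solution is provided
end

AnalyticODEFunction{iip,unpack}(f; analytic = nothing) where {iip,unpack} =
    AnalyticODEFunction{iip,unpack,typeof(f),typeof(analytic)}(f, analytic)

struct AnalyticIntegrator{P}
    p::P
end

get_p(integrator) = integrator.p

has_analytic(f::AnalyticODEFunction) = f.analytic !== nothing

# One such overload per optional field: the `unpack` type parameter decides
# whether the user's function receives `p` or the full integrator.
function analytic(f::AnalyticODEFunction{iip,unpack}, u, integrator, t) where {iip,unpack}
    has_analytic(f) || error("analytical solution is not defined")
    unpack ? f.analytic(u, get_p(integrator), t) : f.analytic(u, integrator, t)
end

# Exponential decay u' = -p u with solution u0 * exp(-p t).
decay = AnalyticODEFunction{false,true}((u, p, t) -> -p * u;
                                        analytic = (u0, p, t) -> u0 * exp(-p * t))
integrator = AnalyticIntegrator(2.0)
analytic(decay, 1.0, integrator, 0.0)  # 1.0
analytic(decay, 1.0, integrator, 1.0)  # exp(-2.0)
```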

I still like the idea of attacking this problem at the lowest level, but of course an alternative would be to explicitly define p before every function call (or chunk of function calls), e.g., by defining

```julia
function perform_step!(integrator, cache::BS3ConstantCache)
    p = unpack_params(integrator, integrator.f)
    # .....
end

unpack_params(integrator::ODEIntegrator, ::ODEFunction{iip,false}) where iip = integrator
unpack_params(integrator::ODEIntegrator, ::ODEFunction{iip,true}) where iip = get_p(integrator)
```
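A toy, runnable version of the unpack_params idea, assuming a single explicit Euler step in place of the BS3 step. `StepODEFunction`, `StepIntegrator`, `toy_perform_step`, and this `get_p` are hypothetical; `p` is resolved once per step and then reused for every call of `f` in that step.

```julia
# Hypothetical stand-ins for ODEFunction, ODEIntegrator and get_p.
struct StepODEFunction{iip,unpack,F}
    f::F
end
StepODEFunction{iip,unpack}(f) where {iip,unpack} = StepODEFunction{iip,unpack,typeof(f)}(f)

struct StepIntegrator{F,P}
    f::F
    p::P
    u::Float64
    t::Float64
    dt::Float64
end

get_p(integrator) = integrator.p

# Selected by dispatch on the function's `unpack` type parameter.
unpack_params(integrator::StepIntegrator, ::StepODEFunction{iip,false}) where {iip} = integrator
unpack_params(integrator::StepIntegrator, ::StepODEFunction{iip,true}) where {iip} = get_p(integrator)

# Toy explicit Euler step: `p` is resolved once at the top and reused
# for every call of `f` in the step.
function toy_perform_step(integrator)
    f = integrator.f
    p = unpack_params(integrator, f)
    return integrator.u + integrator.dt * f.f(integrator.u, p, integrator.t)
end

f_p = StepODEFunction{false,true}((u, p, t) -> -p * u)           # expects p
f_int = StepODEFunction{false,false}((u, int, t) -> -int.p * u)  # expects the integrator
toy_perform_step(StepIntegrator(f_p, 2.0, 1.0, 0.0, 0.1))    # ≈ 0.8
toy_perform_step(StepIntegrator(f_int, 2.0, 1.0, 0.0, 0.1))  # ≈ 0.8
```

Compared with wrapping the call operator of the function type, this keeps the DiffEqFunctions unchanged but requires touching every perform_step! method.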

@ChrisRackauckas
Member

We can also hack it with getproperty overloading
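One possible reading of this suggestion, sketched under the assumption that the overload intercepts `integrator.p`: unchanged solver code that reads `integrator.p` receives the full integrator whenever the user's function expects it. All names are hypothetical. A drawback shows up immediately: inside the user's function, `int.p` returns the integrator again, so the actual parameters must be reached with `getfield(int, :p)`.

```julia
# Hypothetical toy types; not OrdinaryDiffEq's ODEIntegrator.
struct PropFunction{unpackparams,F}
    f::F
end
PropFunction{unpackparams}(f) where {unpackparams} = PropFunction{unpackparams,typeof(f)}(f)

unpacks_params(::PropFunction{unpackparams}) where {unpackparams} = unpackparams

struct PropIntegrator{F,P}
    f::F
    p::P
end

# Intercept `integrator.p`: return the integrator itself when the user's
# function expects it, and the stored parameters otherwise.
function Base.getproperty(integrator::PropIntegrator, x::Symbol)
    if x === :p && !unpacks_params(getfield(integrator, :f))
        return integrator
    end
    return getfield(integrator, x)
end

# Unchanged "solver" code that only knows about `integrator.p`.
toy_step(integrator, u, t) = integrator.f.f(u, integrator.p, t)

int_p = PropIntegrator(PropFunction{true}((u, p, t) -> -p * u), 2.0)
# `int.p` would return the integrator again, so use `getfield` for the parameters.
int_full = PropIntegrator(PropFunction{false}((u, int, t) -> -getfield(int, :p) * u), 3.0)
toy_step(int_p, 1.0, 0.0)     # -2.0
toy_step(int_full, 1.0, 0.0)  # -3.0
```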

@devmotion
Member Author

I'm working on a prototype for ODEFunction and I still hope that not too many changes are necessary in OrdinaryDiffEq.

However, I'm not sure how to deal with the fact that p is used to construct the cache in https://github.com/JuliaDiffEq/OrdinaryDiffEq.jl/blob/master/src/solve.jl#L246 before the ODEIntegrator exists. As far as I can see, p is mostly (or only) used to construct the Jacobian w.r.t. u for the nonlinear solvers, in lines such as https://github.com/JuliaDiffEq/DiffEqBase.jl/blob/master/src/nlsolve/utils.jl#L195 that evaluate f.jac(uprev, p, t). If jac is given, we want to use it, but I don't know how to determine the type of its output if it expects the full integrator.

Can we somehow get around this problem by not caching W and instead passing it around once it has been created?
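A minimal sketch of one way to read this question, assuming a linear test problem u' = J u solved with implicit Euler: W = I - dt*J is built inside the step from a Jacobian that takes the integrator and is passed straight to the linear solve, so no cache needs to know its type before the integrator exists. `JacIntegrator`, `toy_jac`, and `implicit_euler_step` are hypothetical and do not reflect how OrdinaryDiffEq actually structures its caches.

```julia
using LinearAlgebra

# Hypothetical toy integrator; not OrdinaryDiffEq's ODEIntegrator.
struct JacIntegrator{P}
    p::P
    u::Vector{Float64}
    t::Float64
    dt::Float64
end

# User-supplied Jacobian that expects the full integrator instead of `p`.
toy_jac(u, integrator, t) = [-integrator.p 0.0; 0.0 (-2 * integrator.p)]

# Implicit Euler for the linear problem u' = J u: solve (I - dt*J) u_new = u.
# W is created here, where the integrator exists, and passed to the solve
# instead of being allocated in a cache beforehand.
function implicit_euler_step(integrator, jac)
    J = jac(integrator.u, integrator, integrator.t)
    W = I - integrator.dt * J
    return W \ integrator.u
end

integrator = JacIntegrator(1.0, [1.0, 1.0], 0.0, 0.5)
implicit_euler_step(integrator, toy_jac)  # ≈ [2/3, 0.5]
```

The trade-off is that W is rebuilt (and allocated) in every step instead of being reused from the cache.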

@devmotion
Member Author

Since passing around the integrator in OrdinaryDiffEq is not completely straightforward (at least it seems so to me), I started playing around with something that's more centered around the use case in DelayDiffEq. One idea was to use getproperty overloading such that all calls of @unpack f = integrator or integrator.f in OrdinaryDiffEq return an ODEFunction whose history is built on the integrator, similar to the following simple example:

```julia
using DelayDiffEq, DiffEqBase, Test

# Wraps a DDE function `f(u, h, p, t)` together with a history `h`
# so that it can be called like an ODE function `f(u, p, t)`.
struct ODEFunctionWrapper{iip,F,H} <: DiffEqBase.AbstractODEFunction{iip}
    f::F
    h::H
end

function wrap(prob::DDEProblem)
    ODEFunctionWrapper{isinplace(prob.f),typeof(prob.f),typeof(prob.h)}(prob.f, prob.h)
end

(f::ODEFunctionWrapper{false})(u, p, t) = f.f(u, f.h, p, t)
(f::ODEFunctionWrapper{true})(du, u, p, t) = f.f(du, u, f.h, p, t)

# Stand-in for an integrator: stores the wrapped function and a value `a`
# on which the history built in `getproperty` depends.
struct TestStruct{F,A}
    f::F
    a::A
end

function buildTestStruct(prob::DDEProblem, u, p, t)
    f = wrap(prob)
    a = f(u, p, t)

    TestStruct(f, a)
end

function buildTestStruct(prob::DDEProblem, du, u, p, t)
    f = wrap(prob)
    f(du, u, p, t)

    TestStruct(f, first(du))
end

# `test.f` returns a closure whose history is built on `test` itself,
# h(p, t) = t * test.a, instead of the history stored in the wrapper.
function Base.getproperty(test::TestStruct, x::Symbol)
    if x === :f
        f = getfield(test, :f)
        if isinplace(f)
            (du, u, p, t) -> f.f(du, u, (p, t) -> [t * test.a], p, t)
        else
            (u, p, t) -> f.f(u, (p, t) -> t * test.a, p, t)
        end
    else
        getfield(test, x)
    end
end

function calc(test::TestStruct, u, p, t)
    f = test.f
    f(u, p, t)
end

function calc!(test::TestStruct, du, u, p, t)
    f = test.f
    f(du, u, p, t)
    nothing
end

function f_ip(du, u, h, p, t)
    du[1] = h(p, t)[1] - u[1]
    nothing
end

f_scalar(u, h, p, t) = h(p, t) - u

function test()
    prob_ip = DDEProblem(f_ip, [1.0], (p, t) -> [0.0], (0.0, 10.0))
    prob_scalar = DDEProblem(f_scalar, 1.0, (p, t) -> 0.0, (0.0, 10.0))

    wrap_ip = wrap(prob_ip)
    wrap_scalar = wrap(prob_scalar)

    # The wrappers use the original (zero) history.
    a = [0.0]
    wrap_ip(a, [5.0], nothing, 0.0)
    @test a[1] == - 5.0
    wrap_ip(a, [5.0], nothing, 5.0)
    @test a[1] == - 5.0
    wrap_ip(a, [5.0], nothing, 10.0)
    @test a[1] == - 5.0

    @test wrap_scalar(5.0, nothing, 0.0) == - 5.0
    @test wrap_scalar(5.0, nothing, 5.0) == - 5.0
    @test wrap_scalar(5.0, nothing, 10.0) == - 5.0

    struct_ip = buildTestStruct(prob_ip, [0.0], [5.0], nothing, 4.0)
    @test struct_ip.a == -5.0

    struct_scalar = buildTestStruct(prob_scalar, 5.0, nothing, 4.0)
    @test struct_scalar.a == -5.0

    # Through `getproperty`, the history is t * a = -5t.
    b = [0.0]
    calc!(struct_ip, b, [5.0], nothing, 1.0)
    @test b[1] == -10.0
    calc!(struct_ip, b, [5.0], nothing, 4.0)
    @test b[1] == -25.0

    @test calc(struct_scalar, 5.0, nothing, 2.0) == -15.0
    @test calc(struct_scalar, 5.0, nothing, 6.0) == -35.0
end

test()
```

However, I'm not sure how this would affect performance, or whether it is possible at all.
