Enzyme isn't ready for use with Bolt #59

Open · xzackli opened this issue Feb 12, 2022 · 13 comments · Fixed by SciML/SciMLSensitivity.jl#695

Comments

@xzackli (Owner) commented Feb 12, 2022

I played with Enzyme a little bit, and I suspect it's not ready for use with our package. It can't differentiate simple ODEs at present. There's a fair amount of linear algebra in OrdinaryDiffEq, so it's probably tripping up on BLAS.

# this will crash
using OrdinaryDiffEq
using Enzyme

f(u,p,t) = 1.01*u
function test(u0)
    tspan = (0.0,1.0)
    prob = ODEProblem(f,u0,tspan)
    sol = solve(prob,Rodas4(),reltol=1e-8,abstol=1e-8)
    return sol(1.0)
end

autodiff(test, Active(1.0)) # expected (g, 1.0); crashes instead
@xzackli (Owner) commented Feb 12, 2022

https://gist.github.com/xzackli/7c8819f3e7b43f16481e5d909c0b7764

@xzackli (Owner) commented Feb 12, 2022

However, maybe iterative methods like #57 will fare better? Either way, our background and RECFAST are basically the example above -- we need gradients through ODE solves.
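For reference, differentiating through this same solve already works with ForwardDiff (which we already use) -- a minimal sketch, written vector-valued so ForwardDiff.gradient applies; the analytic answer is d/du0 of u0*exp(1.01) = exp(1.01) ≈ 2.7456:

using OrdinaryDiffEq
using ForwardDiff

f(u,p,t) = 1.01 .* u
function test_fd(u0)
    tspan = (0.0,1.0)
    prob = ODEProblem(f,u0,tspan)
    sol = solve(prob,Rodas4(),reltol=1e-8,abstol=1e-8)
    return sum(sol(1.0))
end

# dual numbers propagate through the stiff solve
ForwardDiff.gradient(test_fd, [1.0])  # ≈ [exp(1.01)]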

@jmsull (Collaborator) commented Feb 12, 2022

I am surprised such a simple example fails, given that Enzyme has been applied to ODEs before. I guess they were not using Rodas4, so they weren't hitting BLAS (or whatever the issue is)? Apparently BLAS accounts for Enzyme's edge cases "99% of the time" -- I haven't tried to read all the details, but we can leave Enzyme aside for now.

For the iterative methods we still need to solve the ODE part with DE solvers, so this problem is not going away.

@marius311 (Collaborator)

Have you guys messed around with https://github.com/JuliaDiff/Diffractor.jl yet? I think that will also eventually have pretty well-optimized scalar forward and reverse mode (scalar meaning it should work well through loops / scalar indexing, unlike e.g. Zygote).

(sorry to interject I have major FOMO seeing you both do cool stuff here 😁 )

@xzackli (Owner) commented Feb 13, 2022

@jmsull that's a really interesting thread! I'm glad to learn there are CS people at the Julia Lab working on Enzyme BLAS support. Yeah, let's just leave this aside for a bit.

@marius311 I'm going to wait until their first tagged release, but it's exciting stuff! My understanding is that Diffractor needs compiler improvements from Julia 1.8, which makes it a bit harder to play with. At the very least, it will be nice to use a forward-mode AD package that supports ChainRules.

@marius311 (Collaborator)

Yeah, I think you'd want 1.8 (so Julia#master at the moment), but I figured I'd mention it since Enzyme is still pretty early too. FWIW I have played with it doing basic stuff and it is definitely working, but certainly not ready to actually depend on yet. Out of curiosity, is the ForwardDiff stuff already used here not good enough in some way?

@jmsull (Collaborator) commented Feb 13, 2022

@marius311 Interested to try it out, then -- we tried Enzyme even though it is early because the developers (and others) recommended it to us at the AD workshop -- but happy to try Diffractor as well if you say it's working.

The paper I linked above concludes that

"Our results show a strong performance advantage for automatic differentiation based discrete sensitivity analysis for forward-mode sensitivity analysis on sufficiently small systems, and an advantage for continuous adjoint sensitivity analysis for sufficiently large systems." (cf Fig. 2)

So we (or at least I) thought going to reverse mode might show performance gains, since this is a large ODE system (at least with high ell_max).

@xzackli (Owner) commented Feb 13, 2022

Just to add on to Jamie's comment: my understanding is that for n ODEs and p parameters, forward-mode AD scales like O(np), whereas adjoint methods scale like O(n+p) but with a large overhead. arXiv:1812.01892 finds that for small problems (roughly n + p < 50-100) the overhead of adjoint methods isn't worth it and forward-mode AD still wins (it also shows Enzyme having some remarkable performance characteristics). Since the hierarchy tends to require fairly large systems to accurately solve for the transfer functions, this is a problem. Also, we have some ambition to involve ML, so p can become large.

This does have some nice implications for the AD properties of the hierarchy-less method, since it reduces the system size so much.
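To put toy numbers on that scaling (illustrative sizes only, not Bolt's actual n and p):

# forward mode integrates the n states plus n*p sensitivity states: O(np);
# an adjoint solve runs n states forward, n adjoint states backward,
# plus a p-dimensional quadrature: O(n+p), with a larger constant factor
n, p = 1000, 50
forward_states = n*(p + 1)  # 51000 coupled ODEs
adjoint_states = 2n + p     # 2050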

@xzackli (Owner) commented Feb 13, 2022

It occurs to me from this discussion that perhaps people were saying we should use Enzyme only for the ODE-internal vjps, i.e. as in the sensitivity docs, and use something like Zygote on the outside?
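Something like this is what I have in mind -- a sketch only, with a made-up toy RHS (and note the EnzymeVJP path is exactly what's broken for stiff solvers at the moment):

using DiffEqSensitivity, OrdinaryDiffEq, Zygote

f!(du,u,p,t) = (du .= p .* u; nothing)  # toy in-place RHS
u0, p0 = [1.0], [-0.5]

# Enzyme supplies only the ODE-internal vjps, selected via the sensealg;
# Zygote performs the reverse pass over the outer program
sensealg = QuadratureAdjoint(autojacvec=EnzymeVJP())
loss(p) = sum(solve(ODEProblem(f!,u0,(0.0,1.0),p),Tsit5(),sensealg=sensealg))
Zygote.gradient(loss, p0)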

@xzackli (Owner) commented Feb 14, 2022

Here's a somewhat realistic demo. Consider a future hierarchy-less situation where we have a small, stiff ODE system (+ some iterative methods) and some number of parameters we want to sample or optimize over. The rober example from the SciML docs isn't a bad placeholder.

using DiffEqSensitivity, OrdinaryDiffEq, ForwardDiff, Zygote, BenchmarkTools

function rober(du,u,p,t)
    y₁,y₂,y₃ = u
    k₁,k₂,k₃ = p[1], p[2], p[3]
    du[1] = -k₁*y₁+k₃*y₂*y₃
    du[2] =  k₁*y₁-k₂*y₂^2-k₃*y₂*y₃
    du[3] =  k₂*y₂^2 + sum(p)  # sum(p) makes the solution depend on all parameters
    nothing
end

function run_benchmarks()

    function sum_of_solution_fwd(x)
        _prob = ODEProblem(rober,x[1:3],(0.0,1e5),x[4:end])
        sum(solve(_prob,Rodas5(),reltol=1e-6,abstol=1e-6))
    end

    function sum_of_solution_CASA(x)
        sensealg = QuadratureAdjoint()  # change me, lots of choices here (arXiv:1812.01892)
        _prob = ODEProblem(rober,x[1:3],(0.0,1e5),x[4:end])
        sum(solve(_prob,Rodas5(),reltol=1e-6,abstol=1e-6,sensealg=sensealg))
    end

    u0 = [1.0,0.0,0.0]
    p = rand(256)  # change me, the number of parameters
    
    @btime ForwardDiff.gradient($sum_of_solution_fwd,[$u0;$p])
    @btime Zygote.gradient($sum_of_solution_CASA,[$u0; $p])

    nothing
end

run_benchmarks()

#   38.065 ms (18168 allocations: 4.53 MiB)
#   18.599 ms (117672 allocations: 21.47 MiB)

Note that the title of this issue still appears to be correct: using Enzyme for the vjp is still broken:

sensealg = QuadratureAdjoint(autojacvec=EnzymeVJP())

@ChrisRackauckas

Found this issue because of your talk. Here's the workaround for right now:

using SciMLSensitivity, OrdinaryDiffEq, ForwardDiff, Zygote, BenchmarkTools

function rober(du,u,p,t)
    y₁,y₂,y₃ = u
    k₁,k₂,k₃ = p[1], p[2], p[3]
    du[1] = -k₁*y₁+k₃*y₂*y₃
    du[2] =  k₁*y₁-k₂*y₂^2-k₃*y₂*y₃
    du[3] =  k₂*y₂^2 + sum(p)
    nothing
end

function run_benchmarks()

    function sum_of_solution_fwd(x)
        _prob = ODEProblem(rober,x[1:3],(0.0,1e5),x[4:end])
        sum(solve(_prob,Rodas5(),reltol=1e-6,abstol=1e-6))
    end

    function sum_of_solution_CASA(x)
        sensealg = QuadratureAdjoint(autojacvec=EnzymeVJP())
        _prob = ODEProblem(rober,x[1:3],(0.0,1e5),x[4:end])
        # autodiff=false builds the solver's Jacobian with finite differences,
        # avoiding the ForwardDiff/EnzymeVJP mixing described below
        sum(solve(_prob,Rodas5(autodiff=false),reltol=1e-6,abstol=1e-6,sensealg=sensealg))
    end

    u0 = [1.0,0.0,0.0]
    p = rand(256)  # change me, the number of parameters

    @btime ForwardDiff.gradient($sum_of_solution_fwd,[$u0;$p])
    @btime Zygote.gradient($sum_of_solution_CASA,[$u0; $p])

    nothing
end

run_benchmarks()

# 11.490 ms (25068 allocations: 5.52 MiB)
# 2.956 ms (11024 allocations: 9.73 MiB)

The issue is mixing the forward-mode Jacobian for the nonlinear solver with the reverse-mode vjp. This isn't too hard to fix, I think; I'll make this into a test case.

@xzackli (Owner) commented Jul 25, 2022

Thanks @ChrisRackauckas, this is exciting stuff -- I think there will be a substantial (order of magnitude?) performance improvement for us. We'll play with the workaround for now and watch for those changes to SciMLSensitivity.

I'd love for this example to make it into the tests. It's basically a cartoon for the problem we're trying to solve: stiff evolution being compared to data with a likelihood evaluation.

@ChrisRackauckas

Fixed in SciMLSensitivity v7.2.0, and this is now a test. Let me know if you run into anything else. For reference, the test set for mixing stiff solvers with adjoints is https://github.com/SciML/SciMLSensitivity.jl/blob/master/test/stiff_adjoints.jl; it just had a blind spot there.
