HamiltonianProblem not supported #53

Closed
adam-coogan opened this issue Jun 26, 2019 · 2 comments · Fixed by #370

@adam-coogan

Hi, thanks for the awesome package! Unfortunately, I'm having trouble using it for my work. I want to integrate a set of equations of motion defined by a Hamiltonian H(p, x) that is too complex to differentiate by hand. The code below sets up the problem:

```julia
using DifferentialEquations
using DiffEqFlux
using Flux

function H(p, x, params)  # complicated (randomly selected) Hamiltonian
    return params[1] * (1 + x[1]^2)^(1/3) * (1 + p[1]^2)^(1/4)
end

function getprob(params)  # set up problem
    tspan = (0.0, 1.0)
    x0, p0 = [0.01], [0.05]
    return HamiltonianProblem(H, p0, x0, tspan, params)
end
```
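For reference, `HamiltonianProblem(H, p0, x0, tspan, params)` integrates Hamilton's equations in the momentum p and position x. For the example H above, writing θ₁ = `params[1]`, the partial derivatives are still short enough to work out by hand, which gives something concrete to check automatic differentiation against:

$$
\dot{p} = -\frac{\partial H}{\partial x}, \qquad \dot{x} = \frac{\partial H}{\partial p}, \qquad H(p, x, \theta) = \theta_1 (1 + x^2)^{1/3} (1 + p^2)^{1/4}
$$

$$
\frac{\partial H}{\partial p} = \frac{\theta_1 \, p \, (1 + x^2)^{1/3}}{2 \, (1 + p^2)^{3/4}}, \qquad
\frac{\partial H}{\partial x} = \frac{2 \, \theta_1 \, x \, (1 + p^2)^{1/4}}{3 \, (1 + x^2)^{2/3}}
$$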

Now I want to differentiate solutions of the equations of motion with respect to the parameter vector. The function below sets this up and tries to run the ODE solver:

```julia
function testsolve(p0=[0.1])
    prob = getprob(p0)  # set up problem

    p = param(p0)  # initialize parameter

    function predict_rd()
        Tracker.collect(diffeq_rd(p, prob, Tsit5(), saveat=0.1))
    end

    println("test: ", predict_rd())  # try running the diff eq solver
end
```

When I run `testsolve()`, I get the error

```
ERROR: MethodError: *(::Tracker.TrackedReal{Float64}, ::ForwardDiff.Dual{ForwardDiff.Tag{DiffEqPhysics.PhysicsTag,Float64},Float64,1}) is ambiguous. Candidates:
  *(a::Tracker.TrackedReal, b::Real) in Tracker at /Users/acoogan/.julia/packages/Tracker/RRYy6/src/lib/real.jl:94
  *(x::Real, y::ForwardDiff.Dual{Ty,V,N} where N where V) where Ty in ForwardDiff at /Users/acoogan/.julia/packages/ForwardDiff/N0wMF/src/dual.jl:140
Possible fix, define
  *(::Tracker.TrackedReal, ::ForwardDiff.Dual{Ty,V,N} where N where V)
```

followed by a massive stack trace.

It seems like there's a conflict between the datatypes used to define the equations of motion and the parameters tracked by Flux. Is there a way to fix this?

I also could not get this to work using Zygote instead of ForwardDiff to differentiate H. Hopefully I'm just missing something about how to use nested automatic differentiation?
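The error shows a `ForwardDiff.Dual` tagged with `DiffEqPhysics.PhysicsTag` (created inside `HamiltonianProblem` to differentiate H) meeting a `Tracker.TrackedReal`. For a Hamiltonian as simple as the example one, a possible way around this is to write the equations of motion out by hand so that no inner `Dual` is ever created and the outer AD only sees ordinary arithmetic. The sketch below does that with the state ordered as u = [p, x] and checks the hand-derived partials against central finite differences, using only Base Julia. The names `dHdp`, `dHdx`, `hamiltonian_rhs!` and `central_diff` are hypothetical, and passing `hamiltonian_rhs!` to `ODEProblem` in place of the `HamiltonianProblem` is an untested assumption.

```julia
# Example Hamiltonian from the issue: H(p, x, θ) = θ₁ (1 + x²)^(1/3) (1 + p²)^(1/4)
H(p, x, θ) = θ[1] * (1 + x^2)^(1/3) * (1 + p^2)^(1/4)

# Hand-derived partial derivatives (hypothetical helper names)
dHdp(p, x, θ) = θ[1] * (1 + x^2)^(1/3) * p / (2 * (1 + p^2)^(3/4))
dHdx(p, x, θ) = θ[1] * (1 + p^2)^(1/4) * 2x / (3 * (1 + x^2)^(2/3))

# In-place right-hand side with state u = [p, x]:
#   dp/dt = -∂H/∂x,  dx/dt = ∂H/∂p
function hamiltonian_rhs!(du, u, θ, t)
    p, x = u[1], u[2]
    du[1] = -dHdx(p, x, θ)
    du[2] = dHdp(p, x, θ)
    return nothing
end

# Central finite difference, used only to sanity-check the hand-derived partials
central_diff(f, z; h=1e-6) = (f(z + h) - f(z - h)) / (2h)

θ = [0.1]
p0, x0 = 0.05, 0.01
fd_p = central_diff(p -> H(p, x0, θ), p0)  # numerical ∂H/∂p
fd_x = central_diff(x -> H(p0, x, θ), x0)  # numerical ∂H/∂x

du = zeros(2)
hamiltonian_rhs!(du, [p0, x0], θ, 0.0)
println("dp/dt = ", du[1], ", dx/dt = ", du[2])
```

Because the right-hand side contains only elementary operations on `θ`, a tracked parameter vector would flow through it without any nested dual numbers. This only helps when H is simple enough to differentiate by hand, which is exactly the case the issue is trying to avoid, so it is a stopgap rather than a general fix.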

@ChrisRackauckas
Member

As an update, we now have certain combinations of ForwardDiff and Zygote working, as in JuliaDiff/SparseDiffTools.jl@26a3fc0. You have to be careful with tags, but then it works as long as you avoid the Zygote issues related to JuliaLang/julia#265. The adjoint is almost set up to use Zygote for the vector-Jacobian products (see SciML/SciMLSensitivity.jl#71); once that is in place, the Zygote adjoint may simply fix this.

@ChrisRackauckas
Member

If you want to manually work around it, you can do something like:

```julia
using DiffEqFlux,Flux,Zygote,Random,Plots,OrdinaryDiffEq,DiffEqSensitivity

Random.seed!(42)
# NN for Hamiltonian. Tiny for debugging
neural_H = Chain(
        Dense(2,1,tanh)
        )
p2,re = Flux.destructure(neural_H)
ps = Flux.params(p2)
function neural_hamiltonian!(du,u,p,t)
    # Gradient of the learned Hamiltonian with respect to the state u
    dH = Zygote.gradient(u -> re(p)(u)[1],u)[1]
    # Commented line below trains as normal
    #du[1] = re(p)(u)[1]
    # Treat u as (position, momentum): du[1] = ∂H/∂u[2], du[2] = -∂H/∂u[1]
    du[1] = dH[2]
    du[2] = -dH[1]
end

tspan = (0.0f0,10.0f0)
dsize = 10
t = range(tspan[1],tspan[2],length=dsize)
u0 = Float32[1.0,1.0]
neural_ode_prob = ODEProblem(neural_hamiltonian!,u0,tspan,p2)
function predict_adjoint()
    Array(concrete_solve(neural_ode_prob,Tsit5(),u0,p2,saveat=t,sensealg=InterpolatingAdjoint(autojacvec=false)))
end
# ode_data is the dataset I'm trying to fit.
ode_data = ones(2,10)
loss() = sum((ode_data .- predict_adjoint()).^2)
cb = function()
    #println(ps)
    display(loss())
end
dummydata = Iterators.repeated((),10)
opt = ADAM(0.03)
Flux.train!(loss,ps,dummydata,opt,cb=cb)
```
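When adapting this workaround, it is worth checking the sign convention of the learned vector field f: along a Hamiltonian flow, dH/dt = ∇H · f = 0. The stdlib-only sketch below uses a hand-written tanh model shaped like `Dense(2,1,tanh)`, H(u) = tanh(w · u + b), with made-up weights `w` and `b` (hypothetical values, no training and no Flux involved), and compares the Hamiltonian field (∂H/∂u₂, -∂H/∂u₁) with a field whose signs are arranged differently.

```julia
using LinearAlgebra: dot

# Hypothetical weights standing in for a Dense(2, 1, tanh) layer
w = [0.7, -1.3]
b = 0.2

H_toy(u) = tanh(dot(w, u) + b)

# Analytic gradient: ∇H = (1 - tanh(z)^2) * w, with z = w ⋅ u + b
gradH(u) = (1 - tanh(dot(w, u) + b)^2) .* w

# Hamilton's equations with u = (position, momentum)
hamiltonian_field(u) = (g = gradH(u); [g[2], -g[1]])

# Signs placed differently, for contrast; not Hamiltonian in general
other_field(u) = (g = gradH(u); [-g[1], g[2]])

u = [1.0, 1.0]
rate_ham = dot(gradH(u), hamiltonian_field(u))   # dH/dt along the Hamiltonian field
rate_other = dot(gradH(u), other_field(u))       # dH/dt along the other field
println("H(u) = ", H_toy(u))
println("dH/dt (Hamiltonian) = ", rate_ham)
println("dH/dt (other)       = ", rate_other)
```

The Hamiltonian field gives dH/dt equal to zero up to rounding, while the alternative arrangement changes H along its trajectories, so a learned model trained with the wrong signs would not conserve its own energy.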
