Allow running findmap from a given starting point
sefffal committed Apr 24, 2024
1 parent f017e88 commit ce132b6
Showing 1 changed file with 25 additions and 15 deletions.
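With this commit, findmap accepts an optional starting point instead of always drawing one via guess_starting_position. A minimal usage sketch follows, assuming this is the package's exported findmap; the system definition and the θ_start vector are hypothetical placeholders, and only the starting_position, N, and verbosity keywords come from this change:

# Hypothetical model; substitute your own system definition.
model = LogDensityModel(system)

# Default behaviour: draw a starting position from N random guesses.
chain = findmap(model; N=100_000, verbosity=1)

# New behaviour: optimize from a caller-supplied parameter vector given in the
# model's natural (untransformed) space; findmap applies model.link itself.
chain = findmap(model; starting_position=θ_start, verbosity=1)

# The single-row Chains result can be inspected like any posterior sample,
# e.g. chain[:logpost] and chain[:loglike].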
40 changes: 25 additions & 15 deletions src/optimization.jl
@@ -10,39 +10,49 @@ single row.
Returning a Chains object is a bit weird, but this way
it can be handled the same as our posteriors, plotted, etc.
"""
function findmap(model::LogDensityModel,N=100_000;verbosity=0)
θ′ = _findmap(model,N;verbosity)
logpost = model.ℓπcallback(model.link(θ′))
nt = (; logpost, model.arr2nt(θ′)...)
function findmap(model::LogDensityModel;starting_position=nothing,N=100_000,verbosity=0)
if isnothing(starting_position)
starting_position, _ = guess_starting_position(model.system,N)
end

θ_t′ = _findmap(model,model.link(starting_position);verbosity)
θ′ = model.invlink(θ_t′)
# Evaluate log post and log like
logpost = model.ℓπcallback(θ_t′)
resolved_namedtuple = model.arr2nt(θ′)
# Add the log posterior to the output.
# Also recompute the log-likelihood and add that too.
ln_like = make_ln_like(model.system, resolved_namedtuple)
loglike = ln_like(model.system, resolved_namedtuple)
nt = (; logpost, loglike, model.arr2nt(θ′)...)
return result2mcmcchain(
[nt],
Dict(:internals => [:logpost])
Dict(:internals => [:logpost, :loglike])
)
end

# Returns the raw parameter vector
function _findmap(model::LogDensityModel,N=100_000;verbosity=0)
function _findmap(model::LogDensityModel,initial_θ_t;verbosity=0)
func = OptimizationFunction(
(θ,model)->-model.ℓπcallback(θ),
grad=(G,θ,model)->G.=.-model.∇ℓπcallback(θ)[2],
)
verbosity > 1 && @info "Guessing starting position" N
θ0, _ = guess_starting_position(model.system,N)

# Start with Simulated Annealing
prob = OptimizationProblem(func, θ0, model)
# # Start with Simulated Annealing
prob = OptimizationProblem(func, initial_θ_t, model)
verbosity > 1 && @info "Simulated annealing optimization"
sol = solve(prob, SimulatedAnnealing(), iterations=100_000, x_tol=0)
θ0 = sol.u

# Then iterate with quasi-Newton
prob = OptimizationProblem(func, sol.u, model)
verbosity > 1 && @info "LBFGS optimization"
sol = solve(prob, LBFGS(), g_tol=1e-12, iterations=10000, allow_f_increases=true)
θ0 = sol.u
sol = solve(prob, LBFGS(), g_tol=1e-12, iterations=100000, allow_f_increases=true)
θ_map2 = sol.u

θ′ = model.invlink(θ0)
return θ′
return θ_map2

# logpost = model.ℓπcallback(model.link(θ′))
# if sol.retcode == ReturnCode.Success && isfinite(logpost)