marg_MAP option to use LBFGS hessians #48

Draft: wants to merge 2 commits into master
105 changes: 96 additions & 9 deletions src/maximization.jl
@@ -230,21 +230,33 @@ OptimKit._add!(η::Field, ξ::Field, β) = η .+= β .* ξ
Compute the maximum a posteriori (i.e. "MAP") estimate of the marginal posterior,
$\mathcal{P}(\phi,\theta\,|\,d)$.

Keyword arguments (same as MAP_joint unless otherwise stated):
* hess_method: "no-update" (the default) keeps the fixed approximate Hessian
  throughout; "lbfgs-hessian" updates the Hessian with the L-BFGS algorithm,
  without the line search.
* α: scale the "no-update" approximate Hessian by this factor. All
  hess_methods use this scaled Hessian for the first step.
* nsteps_with_meanfield_update: number of steps for which the mean-field is
  explicitly computed (afterwards the previous mean-field is reused).

"""
MAP_marg(ds::DataSet; kwargs...) = MAP_marg(NamedTuple(), ds; kwargs...)
function MAP_marg(
θ,
ds :: DataSet;
ϕstart = nothing,
ϕtol = nothing,
lbfgs_rank = 5,
Nϕ = :qe,
nsteps = 10,
nsteps_with_meanfield_update = 4,
conjgrad_kwargs = (tol=1e-1,nsteps=500),
α = 0.2,
hess_method="no-update",
weights = :unlensed,
Nsims = 50,
Nbatch = 1,
progress::Bool = true,
history_keys = (:ϕ,),
aggressive_gc = (fieldinfo(ds.d).Nx >= 512) && (fieldinfo(ds.d).Ny >= 512)
)

@@ -261,31 +273,106 @@ function MAP_marg(
Hϕ⁻¹ = (Nϕ == nothing) ? Cϕ : pinv(pinv(Cϕ) + pinv(Nϕ))

ϕ = (ϕstart != nothing) ? ϕstart : zero(diag(Cϕ))
history = []
state = nothing
lastϕ = nothing
lastg = nothing
lastHg = nothing
diffϕ = nothing
H = nothing
pbar = Progress(nsteps, (progress ? 0 : Inf), "MAP_marg: ")
ProgressMeter.update!(pbar)

niter = 1
while true
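# iterate until the change in ϕ falls below ϕtol (if given) or nsteps iterations are reached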
aggressive_gc && cuda_gc()
g, state = δlnP_δϕ(
ϕ, θ, ds;
use_previous_MF = niter>nsteps_with_meanfield_update,
progress = false, return_state = true, previous_state = state,
Nsims, Nbatch, weights, conjgrad_kwargs, aggressive_gc
)
if (hess_method == "lbfgs-hessian" && isnothing(lastg))
H = initHessian(g,lbfgs_rank)
end
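
# first step (no previous gradient) or hess_method == "no-update": take an α-scaled step with
# the fixed approximate Hessian Hϕ⁻¹; otherwise ("lbfgs-hessian") apply the L-BFGS-updated
# inverse Hessian to the gradient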

if (isnothing(lastg) || hess_method=="no-update")
ϕ += T(α)*Hϕ⁻¹*g
lastHg = deepcopy(T(α)*Hϕ⁻¹*g)
else
Hg, H = get_lbfgs_Hg(g, lastg, lastHg, H,
(_,η)->(Hϕ⁻¹*η))
ϕ += Hg
lastHg = deepcopy(Hg)
end
lastg = deepcopy(g)
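
# convergence diagnostic: rms change in ϕ between steps, whitened by sqrt(Cϕ) and low-pass filtered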

if !isnothing(lastϕ)
diffϕ=sum(unbatch(norm(LowPass(1000) * (sqrt(ds.Cϕ) \ (ϕ - lastϕ))) / sqrt(2length(ϕ))))
end

push!(history, select((;g,ϕ,lastHg=Map(lastHg),diffϕ), history_keys))
@kimmywu (Collaborator, Author) commented on Jan 21, 2021:

@marius311: For some unknown reason, converting to Map using Map(lastHg) is needed for the output lastHg to be scaled correctly when hess_method="lbfgs-hessian". Otherwise, it is orders of magnitude off (and it looks like it's coming from a scale factor) when passed to history. It has the correct amplitude when applied to ϕ in the code.

@marius311 (Owner) replied:

Hmm I don't really remember tbh but glancing at

ϕ, = @⌛ optimize(
objective,
Map(ϕ),
OptimKit.LBFGS(
lbfgs_rank;
maxiter = nsteps,
verbosity = verbosity[1],
linesearch = OptimKit.HagerZhangLineSearch(verbosity=verbosity[2], maxiter=5)
);
finalize!,
inner = (_,ξ1,ξ2)->sum(unbatch(dot(ξ1,ξ2))),
precondition = (_,η)->Map(Hϕ⁻¹*η),
)
looks like I also have some Maps. I think in theory you could get rid of that by defining more of the things OptimKit needs, like retract, inner (that one you're already defining), scale!, add!, and transport!, as mentioned in their readme, although my guess is that performance-wise the extra Map conversions don't really matter, so it's probably fine.
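As a minimal sketch of those hooks for a flat parameter space (illustrative names; signatures assumed from the OptimKit readme, with retract(x, η, α) returning the new point and its local tangent):

# illustrative sketch only; check the signatures against the OptimKit readme
my_retract(x, η, α) = (x + α * η, η)              # flat space: the tangent is unchanged
my_inner(x, ξ1, ξ2) = sum(unbatch(dot(ξ1, ξ2)))   # batched inner product, as used above
my_scale!(η, β) = (η .*= β)
my_add!(η, ξ, β) = (η .+= β .* ξ)
my_transport!(ξ, x, η, α, x′) = ξ                 # trivial parallel transport in flat space

These would be passed via the corresponding keyword arguments of OptimKit.optimize, in the same way inner and precondition are passed in the snippet above.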

@kimmywu (Collaborator, Author) replied:

I did see that you pass Map-converted field variables to the optimization in places in MAP_joint. I tested both, and in my case they yield the same results with or without passing a Map-converted field, so I figured I'd just go without it to reduce the back-and-forth conversion.

Agreed that it doesn't slow down the code. But it does mean storing a larger vector (Map vs Fourier), and it's not the same (Fourier) type as the rest of the return keys. So I wanted to see whether you already know of similar peculiar behavior, or whether this is a corner case that I've run into.

next!(pbar, showvalues=[
("step",i),
("step",niter),
("Ncg (data)", length(state.gQD.history)),
("Ncg (sims)", i<=nsteps_with_meanfield_update ? length(first(state.gQD_sims).history) : "0 (MF not updated)"),
("Ncg (sims)", niter<=nsteps_with_meanfield_update ? length(first(state.gQD_sims).history) : "0 (MF not updated)"),
("α",α)
])

if ((!isnothing(lastϕ) && !isnothing(ϕtol) && diffϕ < ϕtol) ||
niter >= nsteps
)
break
else
lastϕ = deepcopy(ϕ)
niter += 1
end

end
ProgressMeter.finish!(pbar)

set_distributed_dataset(nothing) # free memory, which got used inside δlnP_δϕ

return ϕ, history
end

function initHessian(g, lbfgs_rank)
# build an empty L-BFGS inverse-Hessian approximation whose tangent and scalar types
# match the gradient g returned by δlnP_δϕ
TangentType = typeof(g)
ScalarType = typeof(sum(unbatch(dot(g,g))))

H = OptimKit.LBFGSInverseHessian(lbfgs_rank,
TangentType[], TangentType[], ScalarType[])
return H
end


function get_lbfgs_Hg(g::Field, gprev::Field, ηprev::Field,
H::OptimKit.LBFGSInverseHessian,
precondition; #(_,η)->(Hϕ⁻¹*η)
inner = (_,ξ1,ξ2)->sum(unbatch(dot(ξ1,ξ2))),
scale! = OptimKit._scale!,
add! = OptimKit._add!)
# follow the conventions of OptimKit.LBFGS to minimize confusion

y = g .- gprev
s = ηprev
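# (s, y) is the standard L-BFGS secant pair: the previous step and the change in gradient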

innersy = inner(0,s,y)
innerss = inner(0,s,s)

norms = sqrt(innerss)
ρ = innerss/innersy
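# append the normalized (s, y, ρ) triple to the limited-memory inverse-Hessian approximation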
OptimKit.push!(H, (scale!(s, 1/norms), scale!(y, 1/norms), ρ))

Hg = H(g, ξ->precondition(0, ξ), (ξ1, ξ2)->inner(0, ξ1, ξ2), add!, scale!)

η = scale!(Hg, -1)

return η, H
end
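If OptimKit's LBFGSInverseHessian applies the standard two-loop recursion (as its name suggests), each stored triple corresponds to the usual BFGS inverse-Hessian update written in terms of the normalized vectors $\hat{s} = s/\|s\|$ and $\hat{y} = y/\|s\|$:

$H \leftarrow (I - \rho\,\hat{s}\hat{y}^\dagger)\, H\, (I - \rho\,\hat{y}\hat{s}^\dagger) + \rho\,\hat{s}\hat{s}^\dagger$, with $\rho = 1/\langle\hat{s},\hat{y}\rangle = \langle s,s\rangle/\langle s,y\rangle$,

which is algebraically identical to the textbook update in terms of the unnormalized $s$ and $y$.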