Restart and store trace #1031

Open
sallyc1997 opened this issue Feb 28, 2023 · 2 comments

@sallyc1997

Hi! I'm using this package for research that involves estimating a complicated model, and I've been using the NelderMead() algorithm. From time to time, the estimation stops due to a memory issue. I have three questions:

  • I set store_trace=true. Could this be the cause of the memory issue?
  • Is it storing to a file? Where would it be stored? Or is the trace only stored at the end of the execution?
  • Is there any way I can store the trace and restart from it?

Thank you very much for your help!!

@pkofod
Member

pkofod commented Aug 7, 2023

  • I set store_trace=true. Could this be the cause of the memory issue?

Yes.

  • Is it storing to a file? Where would it be stored? Or is the trace only stored at the end of the execution?

No, it isn't, and you can't really make it do that.

  • Is there any way I can store the trace and restart from it?

You can only do it by running for a number of iterations and then starting again from the final point... sorry. Instead of using the trace system, you can also store information in your objective function and control it that way. Then you can save it to a file and not worry about memory.
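A minimal sketch of the second suggestion, assuming the objective takes a vector `x` and returns a `Float64`. The hypothetical `LoggedObjective` wrapper (not part of Optim) appends every evaluation to a plain-text file instead of keeping a trace in memory, and the equally hypothetical `best_from_log` reads the file back to recover the best point seen so far, which can serve as the starting point for the next run.

```julia
# Hypothetical wrapper (not part of Optim): appends every evaluation to a
# plain-text log instead of keeping a trace in memory.
struct LoggedObjective{F}
    f::F             # the user's objective, assumed to map a vector to a Float64
    logfile::String  # one line per call: f(x), then the entries of x, tab-separated
end

function (lo::LoggedObjective)(x)
    fx = lo.f(x)
    # Append mode: earlier lines survive if the process dies later on.
    open(lo.logfile, "a") do io
        println(io, fx, '\t', join(x, '\t'))
    end
    return fx
end

# Scan the log and return the best value, the point that produced it,
# and the number of logged evaluations.
function best_from_log(logfile)
    bestf, bestx, nevals = Inf, Float64[], 0
    for line in eachline(logfile)
        fields = parse.(Float64, split(line, '\t'))
        nevals += 1
        if fields[1] < bestf
            bestf, bestx = fields[1], fields[2:end]
        end
    end
    return bestf, bestx, nevals
end

# Small demonstration with a toy objective (minimum at x = [1, 1]).
logfile = tempname()
f = LoggedObjective(x -> sum(abs2, x .- 1), logfile)
f([0.0, 0.0]); f([1.0, 2.0]); f([1.0, 1.0])
bestf, bestx, n = best_from_log(logfile)
```

With Optim, `f` would be passed in place of the plain objective, e.g. `optimize(f, x0, NelderMead(), Optim.Options(iterations = 1000))`, and after a crash `bestx` becomes the new `x0`. Opening the file on every call costs a little time, which should be negligible next to an expensive model evaluation. Restarting Nelder-Mead from a single point builds a fresh simplex, so this is not identical to continuing the interrupted run.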

@KnutAM

KnutAM commented Aug 7, 2023

Is there any way I can store the trace and restart from it?

In case it is helpful, I hacked together a way to do this, but it uses some Optim internals and serializes the state of the optimizer. Not sure how well it generalizes...

using Optim, UUIDs, Serialization

mutable struct SaveStateWrapper{F,S}
    const obj::F # Objective function 
    const optimstate::S
    const filename::String
    const num_calls_per_save::Int
    num_calls_since_save::Int
end
function SaveStateWrapper(obj, optimstate; num_calls_per_save)
    filename = string(uuid1())*".state"
    @info "Creating SaveStateWrapper" filename
    return SaveStateWrapper(obj, optimstate, filename, num_calls_per_save, 0)
end

function save_optim_state(filename, state)
    tmp_file = filename*"_tmp"
    isfile(tmp_file) && rm(tmp_file)
    isfile(filename) && mv(filename, tmp_file)
    serialize(filename, state)
    isfile(tmp_file) && rm(tmp_file)
end

function (ssw::SaveStateWrapper)(args...; kwargs...)
    ssw.num_calls_since_save += 1
    if ssw.num_calls_since_save > ssw.num_calls_per_save
        save_optim_state(ssw.filename, ssw.optimstate)
        ssw.num_calls_since_save = 0
    end
    ssw.obj(args...; kwargs...)
end

function optimize_with_restart(obj, x0, method, options; 
        inplace = true, autodiff = :finite, # Optim settings
        num_calls_per_save=10,               # wrapper settings
        state=nothing
        )
    if state===nothing
        the_state = Optim.initial_state(method, options, Optim.promote_objtype(method, x0, autodiff, inplace, obj), x0)
    else
        the_state = state
    end
    wrapped_obj = SaveStateWrapper(obj, the_state; num_calls_per_save)
    real_obj = Optim.promote_objtype(method, x0, autodiff, inplace, wrapped_obj)
    return Optim.optimize(real_obj, x0, method, options, the_state)
end
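`save_optim_state` above keeps the previous file as `*_tmp` while the new one is written, so a crash during `serialize` leaves the main `.state` file half-written and the good copy under the temporary name. A common variation, sketched here with hypothetical names (`save_state_swap`, `load_state_swap`), writes to the temporary file first and only then moves it over the target, so the main file stays complete during the slow `serialize` step. It uses only `Serialization` and `Base.mv`; whether the final `mv` is a single atomic rename may depend on the Julia version and filesystem, so the loader falls back to the temporary copy.

```julia
using Serialization

# Hypothetical alternative to `save_optim_state`: serialize to a temporary
# file first, then move it over the target. The existing snapshot is not
# touched while `serialize` runs.
function save_state_swap(filename, state)
    tmp_file = filename * ".tmp"
    serialize(tmp_file, state)            # slow step; `filename` is still intact
    mv(tmp_file, filename; force = true)  # replace the previous snapshot
    return filename
end

# Load the snapshot; if the process died inside `mv`, only the temporary
# copy may be left, so fall back to it.
function load_state_swap(filename)
    isfile(filename) && return deserialize(filename)
    return deserialize(filename * ".tmp")
end
```

In `SaveStateWrapper`, `save_state_swap` would replace the call to `save_optim_state`, and `load_state_swap` the call to `deserialize` before restarting.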

The following test shows that it works. You first need to run the first part, note the *.state filename it prints, and put that filename into `state_file` to test the restart.

# Create a special objective function around `sum(x)`, that will 
# 1) throw an error when I want to
# 2) record the history of objective values. 
struct MyObj
    vals::Vector{Float64}
    fail_at::Int
end
MyObj(;fail_at=10) = MyObj(Float64[], fail_at)
function (m::MyObj)(x)
    o = sum(x)
    push!(m.vals, o)
    length(m.vals) >= m.fail_at && error("Planned failure")
    return o
end

# o1 will fail after 10 function calls
o1 = MyObj()
try
    r = optimize_with_restart(o1, ones(4), NelderMead(), Optim.Options(); num_calls_per_save=1)    
catch e
    println(e) # Simulate failure
end

# Replace with the filename printed when SaveStateWrapper was created
state_file = "de16fbc0-352e-11ee-3458-2b759117f9c2.state" 

# o2 will fail after 20 calls, but should restart from about where o1 left off
state = deserialize(state_file)
o2 = MyObj(;fail_at=20)
try
    r2 = optimize_with_restart(o2, ones(4), NelderMead(), Optim.Options(); num_calls_per_save=1, state=state)
catch e
    println(e)  # Simulate failure
end

offset = 3 # Not sure why this isn't zero or one...
restarted_trace = append!(copy(o1.vals), o2.vals[(1+offset):end])

# o3 runs from beginning without interruption (up to 30 function calls)
o3 = MyObj(;fail_at=30)
try
    r3 = optimize_with_restart(o3, ones(4), NelderMead(), Optim.Options(); num_calls_per_save=1)
catch e
    println(e)  # Simulate failure
end

for i in 1:length(restarted_trace)
    println("$i: ", o3.vals[i], ", ", restarted_trace[i], ". Same? ", o3.vals[i] ≈ restarted_trace[i])
end
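Instead of hard-coding `offset = 3`, the offset can be found by trying small values until the spliced trace agrees with the uninterrupted one. `find_splice_offset` below is a hypothetical helper written for this comparison; it only reports which offset makes the traces line up, not why the restarted run repeats those evaluations.

```julia
# Hypothetical helper: smallest offset k such that
#   vcat(first_run, second_run[(1 + k):end])
# agrees (within `rtol`) with `reference` on every overlapping index.
# At least one value from `second_run` is always compared; returns
# `nothing` if no offset up to `maxoffset` fits.
function find_splice_offset(first_run, second_run, reference; maxoffset = 10, rtol = 1e-10)
    # Need at least one reference value beyond the first run to compare against.
    length(reference) > length(first_run) || return nothing
    for k in 0:min(maxoffset, length(second_run) - 1)
        spliced = vcat(first_run, second_run[(1 + k):end])
        n = min(length(spliced), length(reference))
        if all(isapprox(spliced[i], reference[i]; rtol = rtol) for i in 1:n)
            return k
        end
    end
    return nothing
end
```

For the test above this would be called as `find_splice_offset(o1.vals, o2.vals, o3.vals)`.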
