-
-
Notifications
You must be signed in to change notification settings - Fork 194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding Training Strategies to dae_solvers.jl #838
base: master
Are you sure you want to change the base?
Conversation
Added in StochasticTraining strategy for generate_loss function
Added WeightedIntervalTraining
src/dae_solve.jl
Outdated
autodiff && throw(ArgumentError("autodiff not supported for WeightedIntervalTraining.")) | ||
minT = tspan[1] | ||
maxT = tspan[2] | ||
|
||
weights = strategy.weights ./ sum(strategy.weights) | ||
|
||
N = length(weights) | ||
points = strategy.points | ||
|
||
difference = (maxT - minT) / N | ||
|
||
data = Float64[] | ||
for (index, item) in enumerate(weights) | ||
temp_data = rand(1, trunc(Int, points * item)) .* difference .+ minT .+ | ||
((index - 1) * difference) | ||
data = append!(data, temp_data) | ||
end | ||
|
||
ts = data | ||
function loss(θ, _) | ||
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p, differential_vars)) | ||
end | ||
return loss |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the formatting is off here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I will look into it.
src/dae_solve.jl
Outdated
function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p, differential_vars::AbstractVector) | ||
autodiff && throw(ArgumentError("autodiff not supported for StochasticTraining.")) | ||
function loss(θ, _) | ||
ts = adapt(parameterless_type(θ), | ||
[(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)]) | ||
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p, differential_vars)) | ||
end | ||
return loss | ||
end | ||
|
||
|
||
function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,differential_vars::AbstractVector) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like a repeat of the ODE code. Is it actually necessary to repeat or could it be refactored so DAEs and ODEs use the same function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I did draw inspiration from the ODE code. I will look into refactoring it so that the same function will be used.
I added the test function for weighted interval training
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have refactoed the generate_loss function for ode_solve. I just merged all the functions into one main function and added the conditional statements to carry out the different tasks based on the abstract type of the strategy.
No, that was not the issue. The issue was to refactor common computations in |
In the previous commit, I merged all the generate_loss function into one function, but I refactored the code incorrectly. I need to refactor each generate loss function between ode_solve and dae_solve
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have added the WeightedIntervalTraining Strategy in dae_solve.jl and the relevant test in NNDAE_test.jl, which the strategy passes. I would like to push that to be merged. My code has been formatted using JuliaFormatter too.
I think we should add rest of the strategies before merging. |
Added in StochasticTraining strategy for generate_loss function
Checklist
contributor guidelines, in particular the SciML Style Guide and
COLPRAC.
Additional context
Add any other context about the problem here.