Skip to content

Commit

Permalink
Merge pull request #85 from CliMA/glw/perfect-catke-example
Browse files Browse the repository at this point in the history
Starts generalizing example utils for generic perfect model calibration
  • Loading branch information
navidcy committed Dec 6, 2021
2 parents 4658e4c + 09ef19a commit 871977f
Show file tree
Hide file tree
Showing 9 changed files with 252 additions and 109 deletions.
1 change: 1 addition & 0 deletions .github/workflows/Documenter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jobs:
- uses: julia-actions/setup-julia@latest
with:
version: 1.6
show-versioninfo: true
- name: Install dependencies
run: |
julia --color=yes --project -e 'using Pkg; Pkg.instantiate()'
Expand Down
85 changes: 47 additions & 38 deletions examples/intro_to_inverse_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ using JLD2

examples_path = joinpath(pathof(OceanTurbulenceParameterEstimation), "..", "..", "examples")
include(joinpath(examples_path, "intro_to_observations.jl"))
data_path = generate_free_convection_synthetic_observations()
data_path = generate_synthetic_observations()
observations = OneDimensionalTimeSeries(data_path, field_names=:b, normalize=ZScore)

# # Building an "ensemble simulation"
Expand All @@ -45,6 +45,34 @@ observations = OneDimensionalTimeSeries(data_path, field_names=:b, normalize=ZSc
# simulation to find optimal parameters by minimizing the discrepency between
# the observations and the forward map.

"""
extract_perfect_parameters(observations, Nensemble)
Extract parameters from a batch of "perfect" observations.
"""
function extract_perfect_parameters(observations, Nensemble)
Nbatch = length(observations)
Qᵘ, Qᵇ, N², f = [zeros(Nensemble, Nbatch) for i = 1:4]

Nz = first(observations).grid.Nz
Hz = first(observations).grid.Hz
Lz = first(observations).grid.Lz
Δt = first(observations).metadata.parameters.Δt

for (j, obs) in enumerate(observations)
Qᵘ[:, j] .= obs.metadata.parameters.Qᵘ
Qᵇ[:, j] .= obs.metadata.parameters.Qᵇ
N²[:, j] .= obs.metadata.parameters.
f[:, j] .= obs.metadata.coriolis.f
end

file = jldopen(first(observations).path)
closure = file["serialized/closure"]
close(file)

return Qᵘ, Qᵇ, N², f, Δt, Lz, Nz, Hz, closure
end

"""
build_ensemble_simulation(observations; Nensemble=1)
Expand All @@ -53,72 +81,53 @@ ensemble of column models designed to reproduce `observations`.
"""
function build_ensemble_simulation(observations; Nensemble=1)

Nz = observations.grid.Nz
Hz = observations.grid.Hz
Lz = observations.grid.Lz
f₀ = observations.metadata.coriolis.f

file = jldopen(observations.path)

convective_κz = file["closure/convective_κz"]
background_κz = file["closure/background_κz"]
convective_νz = file["closure/convective_νz"]
background_νz = file["closure/background_νz"]

Δt = file["parameters"].Δt
observations isa Vector || (observations = [observations]) # Singleton batch
Nbatch = length(observations)

u_bcs = file["timeseries/u/serialized/boundary_conditions"]
b_bcs = file["timeseries/b/serialized/boundary_conditions"]
Qᵘ, Qᵇ, N², f, Δt, Lz, Nz, Hz, closure = extract_perfect_parameters(observations, Nensemble)

close(file)

column_ensemble_size = ColumnEnsembleSize(Nz=Nz, ensemble=(Nensemble, 1), Hz=Hz)
column_ensemble_size = ColumnEnsembleSize(Nz=Nz, ensemble=(Nensemble, Nbatch), Hz=Hz)
ensemble_grid = RectilinearGrid(size = column_ensemble_size, topology = (Flat, Flat, Bounded), z = (-Lz, 0))

ensemble_grid = RectilinearGrid(size = column_ensemble_size,
topology = (Flat, Flat, Bounded),
z = (-Lz, 0))
coriolis_ensemble = [FPlane(f=f[i, j]) for i = 1:Nensemble, j=1:Nbatch]
closure_ensemble = [deepcopy(closure) for i = 1:Nensemble, j=1:Nbatch]

closure = ConvectiveAdjustmentVerticalDiffusivity(; convective_κz, background_κz, convective_νz, background_νz)
u_bcs = FieldBoundaryConditions(top = FluxBoundaryCondition(Qᵘ))
b_bcs = FieldBoundaryConditions(top = FluxBoundaryCondition(Qᵇ), bottom = GradientBoundaryCondition(N²))

## Generate an ensemble of closures
Nex = ensemble_grid.Nx
Ney = ensemble_grid.Ny
tracers = first(observations).metadata.parameters.tracers

closure_ensemble = [deepcopy(closure) for i = 1:Nex, j = 1:Ney]

ensemble_model = HydrostaticFreeSurfaceModel(grid = ensemble_grid,
tracers = :b,
tracers = tracers,
buoyancy = BuoyancyTracer(),
boundary_conditions = (; u=u_bcs, b=b_bcs),
coriolis = FPlane(f=f₀),
coriolis = coriolis_ensemble,
closure = closure_ensemble)

ensemble_simulation = Simulation(ensemble_model; Δt=Δt, stop_time=observations.times[end])

optimal_parameters = (; convective_κz, background_κz, convective_νz, background_νz)
ensemble_simulation = Simulation(ensemble_model; Δt=Δt, stop_time=first(observations).times[end])

return ensemble_simulation, optimal_parameters
return ensemble_simulation, closure
end

# The following illustrations uses a simple ensemble simulation with two ensemble members:

ensemble_simulation, θ= build_ensemble_simulation(observations; Nensemble=3)
ensemble_simulation, closure= build_ensemble_simulation(observations; Nensemble=3)

# # Free parameters
#
# We construct some prior distributions for our free parameters. We found that it often helps to
# constrain the prior distributions so that neither very high nor very low values for diffusivities
# can be drawn out of the distribution.

priors = (convective_κz = lognormal_with_mean_std(0.3, 0.5),
priors = (convective_κz = lognormal_with_mean_std(0.3, 0.05),
background_κz = lognormal_with_mean_std(2.5e-4, 0.25e-4))

free_parameters = FreeParameters(priors)

# We also take the opportunity to collect a named tuple of the optimal parameters

θ★ = (convective_κz = θ.convective_κz,
background_κz = θ.background_κz)
θ★ = (convective_κz = closure.convective_κz,
background_κz = closure.background_κz)

# ## Visualizing the priors
#
Expand Down
62 changes: 25 additions & 37 deletions examples/intro_to_observations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,59 +24,48 @@ using CairoMakie
#
# We define a utility function for constructing synthetic observations,

function generate_free_convection_synthetic_observations(name = "convective_adjustment";
Nz = 32,
Lz = 64,
Qᵇ = +1e-8,
Qᵘ = -1e-5,
Δt = 10.0,
f₀ = 1e-4,
= 1e-6)
data_path = name * ".jld2"

if isfile(data_path)
return data_path
end
default_closure = ConvectiveAdjustmentVerticalDiffusivity(; convective_κz = 1.0,
convective_νz = 0.9,
background_κz = 1e-4,
background_νz = 1e-5)

convective_κz = 1.0
convective_νz = 0.9
background_κz = 1e-4
background_νz = 1e-5
function generate_synthetic_observations(name = "convective_adjustment"; Nz = 32, Lz = 64,
Qᵇ = +1e-8, Qᵘ = -1e-5, f₀ = 1e-4, N² = 1e-6,
Δt = 10.0, stop_time = 12hours,
tracers = :b, closure = default_closure)

grid = RectilinearGrid(size=32, z=(-64, 0), topology=(Flat, Flat, Bounded))
closure = ConvectiveAdjustmentVerticalDiffusivity(; convective_κz, background_κz, convective_νz, background_νz)
data_path = name * ".jld2"
isfile(data_path) && return data_path

grid = RectilinearGrid(size=Nz, z=(-Lz, 0), topology=(Flat, Flat, Bounded))
u_bcs = FieldBoundaryConditions(top = FluxBoundaryCondition(Qᵘ))
b_bcs = FieldBoundaryConditions(top = FluxBoundaryCondition(Qᵇ), bottom = GradientBoundaryCondition(N²))

model = HydrostaticFreeSurfaceModel(grid = grid,
tracers = :b,
buoyancy = BuoyancyTracer(),
boundary_conditions = (; u=u_bcs, b=b_bcs),
coriolis = FPlane(f=f₀),
closure = closure)

set!(model, b = (x, y, z) ->* z)

simulation = Simulation(model; Δt, stop_time=12hours)
model = HydrostaticFreeSurfaceModel(; grid, tracers, closure,
buoyancy = BuoyancyTracer(),
boundary_conditions = (; u=u_bcs, b=b_bcs),
coriolis = FPlane(f=f₀))

init_with_parameters(file, model) = file["parameters"] = (; Qᵇ, Qᵘ, Δt)
set!(model, b = (x, y, z) ->* z)
simulation = Simulation(model; Δt, stop_time)
init_with_parameters(file, model) = file["parameters"] = (; Qᵇ, Qᵘ, Δt, N², tracers=(:b, :e))

simulation.output_writers[:fields] = JLD2OutputWriter(model, merge(model.velocities, model.tracers),
schedule = TimeInterval(4hour),
schedule = TimeInterval(stop_time/3),
prefix = name,
array_type = Array{Float64},
field_slicer = nothing,
init = init_with_parameters,
force = true)

run!(simulation)

return data_path
end

# and invoke it:

data_path = generate_free_convection_synthetic_observations()
data_path = generate_synthetic_observations()

# # Specifying observations
#
Expand Down Expand Up @@ -111,8 +100,8 @@ observations = OneDimensionalTimeSeries(data_path, field_names=(:u, :v, :b), nor

fig = Figure()

ax_b = Axis(fig[1, 1], xlabel = "Buoyancy [m s⁻²]")
ax_u = Axis(fig[1, 2], xlabel = "Velocities [m s⁻¹]")
ax_b = Axis(fig[1, 1], xlabel = "Buoyancy [10⁻⁴ m s⁻²]", ylabel = "Depth [m]")
ax_u = Axis(fig[1, 2], xlabel = "Velocities [m s⁻¹]", ylabel = "Depth [m]")

z = znodes(Center, observations.grid)

Expand All @@ -128,7 +117,7 @@ for i = 1:length(observations.times)
u_label = i == 1 ? "u, " * label : label
v_label = i == 1 ? "v, " * label : label

lines!(ax_b, interior(b)[1, 1, :], z; label, color=colorcycle[i])
lines!(ax_b, 1e4 * interior(b)[1, 1, :], z; label, color=colorcycle[i]) # convert units from m s⁻² to 10⁻⁴ m s⁻²
lines!(ax_u, interior(u)[1, 1, :], z; linestyle=:solid, color=colorcycle[i], label=u_label)
lines!(ax_u, interior(v)[1, 1, :], z; linestyle=:dash, color=colorcycle[i], label=v_label)
end
Expand All @@ -143,4 +132,3 @@ save("intro_to_observations.svg", fig)
# Hint: if using a REPL or notebook, try
# `using Pkg; Pkg.add("ElectronDisplay"); using ElectronDisplay; display(fig)`
# To see the figure in a window.

4 changes: 2 additions & 2 deletions examples/perfect_baroclinic_adjustment_calibration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,8 @@ xlims!(axmain, 350, 1350)
xlims!(axtop, 350, 1350)
ylims!(axmain, 650, 1750)
ylims!(axright, 650, 1750)
xlims!(axright, 0, 0.06)
ylims!(axtop, 0, 0.06)
xlims!(axright, 0, 0.025)
ylims!(axtop, 0, 0.025)

save("distributions_baroclinic_adjustment.svg", f); nothing #hide

Expand Down
138 changes: 138 additions & 0 deletions examples/perfect_catke_calibration.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# # Perfect CAKTE calibration with Ensemble Kalman Inversion

# ## Install dependencies

# ```julia
# using Pkg
# pkg"add OceanTurbulenceParameterEstimation, Oceananigans, Distributions, CairoMakie"
# ```

using OceanTurbulenceParameterEstimation, LinearAlgebra, CairoMakie
using Oceananigans.TurbulenceClosures.CATKEVerticalDiffusivities: CATKEVerticalDiffusivity, MixingLength, SurfaceTKEFlux

examples_path = joinpath(pathof(OceanTurbulenceParameterEstimation), "..", "..", "examples")
include(joinpath(examples_path, "intro_to_inverse_problems.jl"))

mixing_length = MixingLength(Cᴬu=0.1, Cᴬc=0.1, Cᴬe=0.1, Cᴷuʳ=0.0, Cᴷcʳ=0.0, Cᴷeʳ=0.0)
catke = CATKEVerticalDiffusivity(mixing_length=mixing_length)
data_path = generate_synthetic_observations("catke", closure=catke, tracers=(:b, :e), Δt=10.0)
observations = OneDimensionalTimeSeries(data_path, field_names=(:u, :v, :b, :e), normalize=ZScore)

ensemble_simulation, closure★ = build_ensemble_simulation(observations; Nensemble=50)

priors = (Cᴷu⁻ = lognormal_with_mean_std(0.01, 0.1),
Cᴷc⁻ = lognormal_with_mean_std(0.01, 0.1),
Cᴷe⁻ = lognormal_with_mean_std(0.01, 0.1),
Cᴸᵇ = lognormal_with_mean_std(0.2, 0.1),
Cᴰ = lognormal_with_mean_std(1.0, 0.5),
CᵂwΔ = lognormal_with_mean_std(1.0, 0.2))

free_parameters = FreeParameters(priors)

calibration = InverseProblem(observations, ensemble_simulation, free_parameters)

# # Ensemble Kalman Inversion
#
# Next, we construct an `EnsembleKalmanInversion` (EKI) object,
#
# The calibration is done here using Ensemble Kalman Inversion. For more information about the
# algorithm refer to
# [EnsembleKalmanProcesses.jl documentation](https://clima.github.io/EnsembleKalmanProcesses.jl/stable/ensemble_kalman_inversion/).

noise_variance = observation_map_variance_across_time(calibration)[1, :, 1] .+ 1e-5

eki = EnsembleKalmanInversion(calibration; noise_covariance = Matrix(Diagonal(noise_variance)))

# and perform few iterations to see if we can converge to the true parameter values.

iterate!(eki; iterations = 10)

# Last, we visualize the outputs of EKI calibration.

θ̅(iteration) = [eki.iteration_summaries[iteration].ensemble_mean...]
varθ(iteration) = eki.iteration_summaries[iteration].ensemble_var

weight_distances = [norm(θ̅(iter) - [θ★[1], θ★[2]]) for iter in 1:eki.iteration]
output_distances = [norm(forward_map(calibration, θ̅(iter))[:, 1] - y) for iter in 1:eki.iteration]
ensemble_variances = [varθ(iter) for iter in 1:eki.iteration]

f = Figure()

lines(f[1, 1], 1:eki.iteration, weight_distances, color = :red, linewidth = 2,
axis = (title = "Parameter distance",
xlabel = "Iteration",
ylabel = "|θ̅ₙ - θ★|"))

lines(f[1, 2], 1:eki.iteration, output_distances, color = :blue, linewidth = 2,
axis = (title = "Output distance",
xlabel = "Iteration",
ylabel = "|G(θ̅ₙ) - y|"))

ax3 = Axis(f[2, 1:2],
title = "Parameter convergence",
xlabel = "Iteration",
ylabel = "Ensemble variance",
yscale = log10)

for (i, pname) in enumerate(free_parameters.names)
ev = getindex.(ensemble_variances, i)
lines!(ax3, 1:eki.iteration, ev / ev[1], label = String(pname), linewidth = 2)
end

axislegend(ax3, position = :rt)

save("summary_catke_eki.svg", f); nothing #hide

# ![](summary_catke_eki.svg)

# And also we plot the the distributions of the various model ensembles for few EKI iterations to see
# if and how well they converge to the true diffusivity values.

f = Figure()

axtop = Axis(f[1, 1])

axmain = Axis(f[2, 1],
xlabel = "Cᴷu⁻ [m² s⁻¹]",
ylabel = "Cᴷc⁻ [m² s⁻¹]")

axright = Axis(f[2, 2])
scatters = []

for iteration in [1, 2, 3, 11]
## Make parameter matrix
parameters = eki.iteration_summaries[iteration].parameters
Nensemble = length(parameters)
Nparameters = length(first(parameters))
parameter_ensemble_matrix = [parameters[i][j] for i=1:Nensemble, j=1:Nparameters]

push!(scatters, scatter!(axmain, parameter_ensemble_matrix))
density!(axtop, parameter_ensemble_matrix[:, 1])
density!(axright, parameter_ensemble_matrix[:, 2], direction = :y)
end

vlines!(axmain, [θ★.Cᴷu⁻], color = :red)
vlines!(axtop, [θ★.Cᴷu⁻], color = :red)

hlines!(axmain, [θ★.Cᴷc⁻], color = :red)
hlines!(axright, [θ★.Cᴷc⁻], color = :red)

colsize!(f.layout, 1, Fixed(300))
colsize!(f.layout, 2, Fixed(200))
rowsize!(f.layout, 1, Fixed(200))
rowsize!(f.layout, 2, Fixed(300))

Legend(f[1, 2], scatters, ["Initial ensemble", "Iteration 1", "Iteration 2", "Iteration 10"],
position = :lb)

hidedecorations!(axtop, grid = false)
hidedecorations!(axright, grid = false)

xlims!(axmain, -0.25, 3.2)
xlims!(axtop, -0.25, 3.2)
ylims!(axmain, 5e-5, 35e-5)
ylims!(axright, 5e-5, 35e-5)

save("distributions_catke_eki.svg", f); nothing #hide

# ![](distributions_catke_eki.svg)

0 comments on commit 871977f

Please sign in to comment.