Skip to content

Commit

Permalink
improved manikde!_manellic construction (#285)
Browse files Browse the repository at this point in the history
* improved manikde!_manellic construction

* rm extraneous fields and code ManellicTree
  • Loading branch information
dehann committed Apr 30, 2024
1 parent 9250da2 commit 5b1a9ff
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 39 deletions.
6 changes: 3 additions & 3 deletions src/entities/ManellicTree.jl
Expand Up @@ -26,11 +26,11 @@ struct ManellicTree{M,D<:AbstractVector,N,HL,HT}
leaf_kernels::SizedVector{N,HL}
tree_kernels::SizedVector{N,HT}
segments::SizedVector{N,Set{Int}}
left_idx::MVector{N,Int}
right_idx::MVector{N,Int}
# left_idx::MVector{N,Int}
# right_idx::MVector{N,Int}

# workaround to overcome bug for StaticArrays `isdefined() != false` issue
_workaround_isdef_treekernel::Set{Int}
_workaround_isdef_leafkernel::Set{Int}
_workaround_isdef_treekernel::Set{Int}
end

4 changes: 1 addition & 3 deletions src/services/KernelEval.jl
Expand Up @@ -7,9 +7,6 @@ function projectSymPosDef(c::AbstractMatrix)
issymmetric(_c) ? _c : project(SymmetricPositiveDefinite(s[1]),_c,_c)
end

# FIXME ON FIRE, REMOVE TYPE DISPLACEMENT
Base.eltype(mt::MvNormalKernel) = eltype(mt.p)

function MvNormalKernel(
μ::AbstractVector,
Σ::AbstractArray,
Expand All @@ -26,6 +23,7 @@ Statistics.mean(m::MvNormalKernel) = m.μ # mean(m.p) # m.p.μ
Statistics.cov(m::MvNormalKernel) = cov(m.p) # note also about m.sqrt_iΣ
Statistics.std(m::MvNormalKernel) = sqrt(cov(m))

updateKernelBW(k::MvNormalKernel,_bw) = (p=MvNormal(_bw); MvNormalKernel(;μ=k.μ,p,weight=k.weight))

function evaluate(
M::AbstractManifold,
Expand Down
49 changes: 34 additions & 15 deletions src/services/ManellicTree.jl
Expand Up @@ -413,8 +413,6 @@ function buildTree_Manellic!(
npts = high - low + 1
mid_idx = low + sum(imask) - 1

# @info "BUILD" index low sum(mask) mid_idx high _getleft(index) _getright(index)

lft = mid_idx <= low ? low : leftIndex(mtree, index)
rgt = high <= mid_idx+1 ? high : rightIndex(mtree, index)

Expand All @@ -431,7 +429,7 @@ function buildTree_Manellic!(

if index < N
_knl = convert(eltype(mtree.tree_kernels), knl)
# FIXME use consolidate getKernelTree instead
# set tree kernel
mtree.tree_kernels[index] = _knl
push!(mtree._workaround_isdef_treekernel, index)
mtree.segments[index] = Set(ido)
Expand Down Expand Up @@ -495,14 +493,11 @@ function buildTree_Manellic!(
r_PP,
MVector{N,Float64}(weights),
MVector{N,Int}(1:N),
lkern, # MVector{N,lknlT}(undef),
lkern,
SizedVector{N,tknlT}(undef),
# SizedVector{N,tknlT}(undef),
SizedVector{N,Set{Int}}(undef),
MVector{N,Int}(undef),
MVector{N,Int}(undef),
_workaround_isdef_leafkernel,
Set{Int}(),
_workaround_isdef_leafkernel
)

#
Expand All @@ -517,15 +512,39 @@ function buildTree_Manellic!(

# manual reset leaves in the order discovered
permute!(tosort_leaves.leaf_kernels, tosort_leaves.permute)
# dupl = deepcopy(tosort_leaves.leaf_kernels)
# for (k,i) in enumerate(tosort_leaves.permute)
# tosort_leaves[i] = dupl.leaf_kernels[k]
# end

return tosort_leaves
end


function updateBandwidths(
mtr::ManellicTree{M,D,N,HL},
bws
) where {M,D,N,HL}
#
_getBW(s::Float64,::Int) = [s;;]
_getBW(s::AbstractVector{<:Real},::Int) = s
_getBW(s::AbstractMatrix{<:Real},::Int) = s
_getBW(s::AbstractVector{<:AbstractArray},_i::Int) = s[_i]

_leaf_kernels = SizedVector{N,HL}(undef)
for (i,lk) in enumerate(mtr.leaf_kernels)
_leaf_kernels[i] = updateKernelBW(lk,_getBW(bws,i))
end
ManellicTree(
mtr.manifold,
mtr.data,
mtr.weights,
mtr.permute,
_leaf_kernels,
mtr.tree_kernels,
mtr.segments,
mtr._workaround_isdef_leafkernel,
mtr._workaround_isdef_treekernel,
)
end


"""
$SIGNATURES
Expand Down Expand Up @@ -582,9 +601,9 @@ function evaluate(
pt,
LOO::Bool = false,
) where {M,D,N,HL}
# force function barrier, just to be sure dyndispatch is limited
_F() = getfield(ApproxManifoldProducts,HL.name.name)
_F_ = _F()
# # force function barrier, just to be sure dyndispatch is limited
# _F() = getfield(ApproxManifoldProducts,HL.name.name)
# _F_ = _F()

pts = getPoints(mt)
w = getWeights(mt)
Expand Down
34 changes: 16 additions & 18 deletions src/services/ManifoldKernelDensity.jl
Expand Up @@ -70,16 +70,18 @@ function ManifoldKernelDensity(
arr[:,j] = vee(M, ϵ, log(M, ϵ, vecP[j]))
end

manis = convert(Tuple, M)
# find or have the bandwidth
_bw = bw === nothing ? getKDEManifoldBandwidths(arr, manis ) : bw
# FIXME ON FIRE REMOVE LEGACY
manis = convert(Tuple, M)
# find or have the bandwidth
_bw = isnothing(bw) ? getKDEManifoldBandwidths(arr, manis ) : bw
# NOTE workaround for partials and user did not specify a bw
if bw === nothing && partial !== nothing
if isnothing(bw) && !isnothing(partial)
mask = ones(Int, length(_bw)) .== 1
mask[partial] .= false
_bw[mask] .= 1.0
end
addopT, diffopT, _, _ = buildHybridManifoldCallbacks(manis)
# FIXME ON FIRE REMOVE LEGACY
addopT, diffopT, _, _ = buildHybridManifoldCallbacks(manis)
bel = belmodel(arr,_bw,addopT,diffopT)
# bel = KernelDensityEstimate.kde!(arr, collect(_bw), addopT, diffopT)
return ManifoldKernelDensity(M, bel, partial, u0, infoPerCoord)
Expand All @@ -101,6 +103,7 @@ manikde!(

#


function manikde!_manellic(
M::AbstractManifold,
pts::AbstractVector;
Expand All @@ -118,7 +121,7 @@ function manikde!_manellic(
# Cost function to optimize
_cost(_pts, σ) = begin
# FIXME avoid rebuilding tree at each optim iteration!!!
mtr = buildTree_Manellic!(M, _pts; kernel_bw=reshape([σ;],manifold_dimension(M),1), kernel=MvNormalKernel)
mtr = buildTree_Manellic!(M, _pts; kernel_bw=reshape(σ,manifold_dimension(M),1), kernel=MvNormalKernel)
entropy(mtr)
end

Expand All @@ -127,27 +130,22 @@ function manikde!_manellic(
# set lower and upper bounds for Golden section optimization
lcov, ucov = getBandwidthSearchBounds(mtree)
res = Optim.optimize(
(s)->_cost(pts,s^2),
(s)->_cost(pts,[s^2;]),
lcov[1], ucov[1], Optim.GoldenSection()
)
best_cov = [Optim.minimizer(res);;]

best_cov = [Optim.minimizer(res);]

# reuse (heavy lift parts of) earlier tree build
# return tree with correct bandwidth
# TODO avoid tree rebuild somehow
manikde!(
M,
pts;
bw=best_cov,
belmodel = (a,b,aF,dF) -> ApproxManifoldProducts.buildTree_Manellic!(
M,
pts;
kernel_bw=b,
kernel=AMP.MvNormalKernel
)
belmodel = (ignore...) -> updateBandwidths(mtree, best_cov)
)
end



## ==========================================================================================
## a few utilities
## ==========================================================================================
Expand Down Expand Up @@ -305,7 +303,7 @@ function Base.show(io::IO, mkd::ManifoldKernelDensity{M,B,L,P}) where {M,B,L,P}
try
# mn = mean(mkd.manifold, getPoints(mkd, false))
mn = mean(mkd)
if mn isa ProductRepr
if mn isa ProductRepr # TODO UPDATE to ArrayPartition only, discontinued use of ProductRepr long ago.
println(io)
for prt in mn.parts
println(io, " ", round.(prt,digits=4))
Expand Down

0 comments on commit 5b1a9ff

Please sign in to comment.