Skip to content

Commit

Permalink
add dev manikde!_manellic for 1D only (#284)
Browse files Browse the repository at this point in the history
* add dev manikde!_manellic for 1D only

* test fix
  • Loading branch information
dehann committed Apr 30, 2024
1 parent 0f4b8db commit 9250da2
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 13 deletions.
12 changes: 7 additions & 5 deletions src/CommonUtils.jl
Expand Up @@ -123,10 +123,10 @@ DevNotes:
"""
function calcProductGaussians(
M::AbstractManifold,
μ_::Union{<:AbstractVector{P},<:NTuple{N,P}}, # point type commonly known as P
μ_::Union{<:AbstractVector{P},<:NTuple{N,P}}, # point type commonly known as P (actually on-manifold)
Σ_::Union{Nothing,<:AbstractVector{S},<:NTuple{N,S}};
dim::Integer=manifold_dimension(M),
Λ_ = inv.(Σ_),
Λ_ = inv.(Σ_), # TODO these probably need to be transported to common tangent space `u0` -- FYI @Affie 24Q2
) where {N,P,S<:AbstractMatrix{<:Real}}
#
# calc sum of covariances
Expand Down Expand Up @@ -176,15 +176,17 @@ calcProductGaussians(
EXPERIMENTAL: On-manifold product of Gaussians.
DevNotes
- CHECK make sure this product is properly on manifold, Manifolds.jl likely already has solutions:
- https://juliamanifolds.github.io/Manifolds.jl/stable/features/distributions.html
- FIXME is parallel transport needed when multiplying with covariances from difffent tangent spaces?
"""
function calcProductGaussians(
M::AbstractManifold,
comps::AbstractVector{<:MvNormalKernel},
)
#
μ_ = mean.(comps)
Σ_ = cov.(comps)
# CHECK this should be on-manifold for points
μ_ = mean.(comps) # This is a ArrayPartition which IS DEFINITELY ON MANIFOLD (we dispatch on mean)
Σ_ = cov.(comps) # on tangent

# FIXME is parallel transport needed here for covariances from different tangent spaces?

Expand Down
2 changes: 1 addition & 1 deletion src/services/KernelEval.jl
Expand Up @@ -30,7 +30,7 @@ Statistics.std(m::MvNormalKernel) = sqrt(cov(m))
function evaluate(
M::AbstractManifold,
ekr::MvNormalKernel,
p
p # on manifold point
)
#
dim = manifold_dimension(M)
Expand Down
7 changes: 4 additions & 3 deletions src/services/ManellicTree.jl
Expand Up @@ -98,7 +98,7 @@ function getKernelTree(
leafIdxs .+= N
bws = [cov(getKernelTree(mtr,lidx,false)) for lidx in leafIdxs]
# FIXME is a parallel transport needed between different kernel covariances that each exist in different tangent spaces
mean_bw = mean(bws)
mean_bw = mean(bws) # FIXME upgrade to on-manifold mean
# corrected cov varies from root (only Monte Carlo cov est) to leaves (only selected bandwdith)
nC = (1-λ)*cov(raw_ker) + λ*mean_bw
# return a new kernel with cov_continuation, of tree kernel type
Expand Down Expand Up @@ -573,6 +573,7 @@ DevNotes:
- use geometric computing for faster evaluation
- Dual tree evaluations
- Holmes, M.P., Gray, A.G. and Isbell Jr, C.L., 2010. Fast kernel conditional density estimation: A dual-tree Monte Carlo approach. Computational statistics & data analysis, 54(7), pp.1707-1718.
- Curtin, R., March, W., Ram, P., Anderson, D., Gray, A. and Isbell, C., 2013, May. Tree-independent dual-tree algorithms. In International Conference on Machine Learning (pp. 1435-1443). PMLR.
- Fast kernels
- Parallel transport shortcuts?
"""
Expand Down Expand Up @@ -768,6 +769,8 @@ Notes:
- Advise, 2<=MC to ensure multiscale works during decent transitions (TBD obsolete requirement)
- To force sequential Gibbs on leaves only, use:
`label_pools = [[(length(getPoints(prop))+1):(2*length(getPoints(prop)));] for prop in proposals]`
- Taken from: Sudderth, E.B., Ihler, A.T., Isard, M., Freeman, W.T. and Willsky, A.S., 2010.
Nonparametric belief propagation. Communications of the ACM, 53(10), pp.95-103.
"""
function sampleProductSeqGibbsBTLabel(
M::AbstractManifold,
Expand Down Expand Up @@ -802,8 +805,6 @@ function sampleProductSeqGibbsBTLabel(
# construct new label pool for children in multiscale
child_label_pools, all_leaves = generateLabelPoolRecursive(proposals, labels_sampled)

# @info "WHY STOP" child_label_pools all_leaves

# recursively call sampling down the multiscale tree ("pyramid") -- aka homotopy
# limit recursion to MAX_RECURSE_DEPTH
if 0<MAX_RECURSE_DEPTH && !all_leaves
Expand Down
46 changes: 46 additions & 0 deletions src/services/ManifoldKernelDensity.jl
Expand Up @@ -101,6 +101,52 @@ manikde!(

#

function manikde!_manellic(
M::AbstractManifold,
pts::AbstractVector;
bw=ones(manifold_dimension(M),1),
)
#

mtree = ApproxManifoldProducts.buildTree_Manellic!(
M,
pts;
kernel_bw=bw,
kernel=AMP.MvNormalKernel
)

# Cost function to optimize
_cost(_pts, σ) = begin
# FIXME avoid rebuilding tree at each optim iteration!!!
mtr = buildTree_Manellic!(M, _pts; kernel_bw=reshape([σ;],manifold_dimension(M),1), kernel=MvNormalKernel)
entropy(mtr)
end

# optimize for best LOOCV bandwidth
# FIXME switch to RLM (or other Manopt) techinque instead
# set lower and upper bounds for Golden section optimization
lcov, ucov = getBandwidthSearchBounds(mtree)
res = Optim.optimize(
(s)->_cost(pts,s^2),
lcov[1], ucov[1], Optim.GoldenSection()
)
best_cov = [Optim.minimizer(res);;]

# return tree with correct bandwidth
# TODO avoid tree rebuild somehow
manikde!(
M,
pts;
bw=best_cov,
belmodel = (a,b,aF,dF) -> ApproxManifoldProducts.buildTree_Manellic!(
M,
pts;
kernel_bw=b,
kernel=AMP.MvNormalKernel
)
)
end


## ==========================================================================================
## a few utilities
Expand Down
14 changes: 10 additions & 4 deletions test/manellic/testManellicTree.jl
Expand Up @@ -418,12 +418,18 @@ end


# TODO
# @testset "Manellic tree bandwidth optimize n-dim RLM" begin
# ##
@testset "Manellic tree all up construction with bandwith optimization" begin
##


# ##
# end
M = TranslationGroup(1)
# pts = [[0.;],[0.1],[0.2;],[0.3;]]
pts = [1*randn(1) for _ in 1:64]

mkd = ApproxManifoldProducts.manikde!_manellic(M,pts)

##
end



Expand Down

0 comments on commit 9250da2

Please sign in to comment.