Skip to content

Commit

Permalink
add kernel weight for manellic products (#281)
Browse files Browse the repository at this point in the history
* add kernel weight for manellic products

* fix manellic tests, towards multiscale

* fix tests on manellic

* rm unused code
  • Loading branch information
dehann committed Apr 20, 2024
1 parent 2a1b07d commit 696d29c
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 21 deletions.
2 changes: 2 additions & 0 deletions src/entities/KernelEval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ abstract type AbstractKernel end
# TDB might already be covered in p.Σ.chol but having issues with SymPD (not particular to this AMP repo)
""" Manually maintained square root concentration matrix for faster compute, TODO likely duplicate of existing Distrubtions.jl functionality. """
sqrt_iΣ::iM = sqrt(inv(p.Σ))
""" Nonparametric weight value """
weight::Float64 = 1.0
end
3 changes: 3 additions & 0 deletions src/entities/ManellicTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,8 @@ struct ManellicTree{M,D<:AbstractVector,N,HL,HT}
segments::SizedVector{N,Set{Int}}
left_idx::MVector{N,Int}
right_idx::MVector{N,Int}

# workaround to overcome bug for StaticArrays `isdefined() != false` issue
_workaround_isdef_treekernel::Set{Int}
end

8 changes: 6 additions & 2 deletions src/services/KernelEval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@ end
# FIXME, REMOVE TYPE DISPLACEMENT
Base.eltype(mt::MvNormalKernel) = eltype(mt.p)

function MvNormalKernel::AbstractVector::AbstractArray)
function MvNormalKernel(
μ::AbstractVector,
Σ::AbstractArray,
weight::Real=1.0
)
_c = projectSymPosDef(Σ)
p=MvNormal(_c)
# NOTE, TBD, why not sqrt(inv(p.Σ)), this had an issue seemingly internal to PDMat.chol which breaks an already forced SymPD matrix to again be not SymPD???
sqrt_iΣ = sqrt(inv(_c))
MvNormalKernel(;μ, p, sqrt_iΣ)
MvNormalKernel(;μ, p, sqrt_iΣ, weight=float(weight))
end

Statistics.mean(m::MvNormalKernel) = m.μ # mean(m.p) # m.p.μ
Expand Down
109 changes: 91 additions & 18 deletions src/services/ManellicTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,28 @@ getWeights(
) = view(mt.weights, mt.permute)


# either leaf or tree kernel, if larger than N
leftIndex(
::ManellicTree,
krnIdx::Int=1
) = 2*krnIdx
rightIndex(
::ManellicTree,
krnIdx::Int
) = 2*krnIdx+1


"""
$SIGNATURES
Return leaf kernel associated with input data element `i` (i.e. `permuted=true`).
Else when set to `permuted=false` return the sorted leaf_kernel `i` (different from unsorted input data number).
"""
getKernel(
getKernelLeaf(
mt::ManellicTree,
i::Int,
permuted::Bool = true
) = mt.leaf_kernels[mt.permute[i]]
) = mt.leaf_kernels[permuted ? mt.permute[i] : i]

uniWT(mt::ManellicTree) = 1===length(union(diff(getWeights(mt))))

Expand Down Expand Up @@ -138,7 +149,7 @@ function Base.convert(
μ = P(src.μ)
p = MvNormal(_matType(M)(cov(src.p)))
sqrt_iΣ = iM(src.sqrt_iΣ)
MvNormalKernel{P,T,M,iM}(μ, p, sqrt_iΣ)
MvNormalKernel{P,T,M,iM}(μ, p, sqrt_iΣ, src.weight)
end


Expand Down Expand Up @@ -179,8 +190,9 @@ DeVNotes:
"""
function splitPointsEigen(
M::AbstractManifold,
r_PP::AbstractVector{P};
kernel = MvNormal,
r_PP::AbstractVector{P},
weights::AbstractVector{<:Real} = ones(length(r_PP)); # FIXME, make static vector unless large
kernel = MvNormalKernel,
kernel_bw = nothing,
) where {P <: AbstractArray}
#
Expand Down Expand Up @@ -251,8 +263,10 @@ function splitPointsEigen(
_flipmask_minormax!(imask, mask, ax_CC1; argminmax=argmin)
_flipmask_minormax!(mask, imask, ax_CC1; argminmax=argmax)

weight = sum(weights)

# return rotated coordinates and split mask
ax_CCp, mask, kernel(p, cv)
ax_CCp, mask, kernel(p, cv, weight)
end


Expand Down Expand Up @@ -286,7 +300,7 @@ function buildTree_Manellic!(
# according to current index permutation (i.e. sort data as you build the tree)
ido = view(mtree.permute, idc)
# split the slice of order-permuted data
ax_CCp, mask, knl = splitPointsEigen(M, view(mtree.data, ido); kernel, kernel_bw=_kernel_bw)
ax_CCp, mask, knl = splitPointsEigen(M, view(mtree.data, ido), view(mtree.weights, ido); kernel, kernel_bw=_kernel_bw)
imask = xor.(mask, true)

# sort the data as 'small' and 'big' elements either side of the eigen split
Expand Down Expand Up @@ -319,6 +333,7 @@ function buildTree_Manellic!(
if index < N
_knl = convert(eltype(mtree.tree_kernels), knl)
mtree.tree_kernels[index] = _knl # HyperEllipse(knl.μ, knl.Σ.mat)
push!(mtree._workaround_isdef_treekernel, index)
mtree.segments[index] = Set(ido)
end

Expand Down Expand Up @@ -390,6 +405,7 @@ function buildTree_Manellic!(
SizedVector{len,Set{Int}}(undef),
MVector{len,Int}(undef),
MVector{len,Int}(undef),
Set{Int}()
)

#
Expand Down Expand Up @@ -551,41 +567,94 @@ end
## WIP Sequential Gibbs Product development


getKernelLeafAsTreeKer(mtr::ManellicTree{M,D,N,HL,HT}, idx::Int, permuted::Bool=true) where {M,D,N,HL,HT} = convert(HT,getKernelLeaf(mtr, idx % N, permuted))

function getKernelTree(
mtr::ManellicTree{M,D,N},
currIdx::Int,
) where {M,D,N}
# must return sorted given name signature "Tree"
permuted = true

if currIdx < N
return mtr.tree_kernels[currIdx]
else
return getKernelLeafAsTreeKer(mtr, currIdx, permuted)
end
end


# function getKernelsTreeLevelIdxs(
# mtr::ManellicTree{M,D,N},
# level::Int,
# currIdx::Int = 1
# ) where {M,D,N}

# # go to children if idx level too high
# _idxlevel(idx) = floor(Int,log2(idx)) + 1
# _exists(_i) = _i < N ? (_i in mtr._workaround_isdef_treekernel) : true # N from numPoints anyway # isdefined(mtr.leaf_kernels,_i)

# #traverse tree
# if level == _idxlevel(currIdx) && _exists(currIdx)
# return Int[currIdx;]
# end

# @warn "getKernelsTreeLevelIdxs still has corner case issues when tree_kernel[N]" maxlog=10

# # explore left and right
# return vcat(
# getKernelsTreeLevelIdxs(mtr,level,_getleft(currIdx)),
# getKernelsTreeLevelIdxs(mtr,level,_getright(currIdx))
# )
# end

# function getKernelsTreeLevel(
# mtr::ManellicTree{M,D,N},
# level::Int
# ) where {M,D,N}
# # kernels of that level
# levelIdxs = getKernelsTreeLevelIdxs(mtr,level)
# getKernelTree.(Ref(mtr),levelIdxs)
# end



function sampleProductSeqGibbsLabel(
M::AbstractManifold,
proposals::AbstractVector,
treeLevel::Int = 1, # reserved for future use
MC = 3,
)
#
# how many incoming proposals
d = length(proposals)

# # how many points per proposal
## TODO sample at multiscale levels on the tree, starting at treeLevel=1

# how many points per proposal at this depth level of the belief tree
Ns = proposals .|> getPoints .|> length

# TODO upgrade to multiscale
# start with random selection of labels
best_labels = [rand(1:n) for n in Ns]
latest_labels = [rand(1:n) for n in Ns]

# pick the next leave-out proposal
for _ in MC, O in 1:d
# select a label from the not-selected proposal densities
sublabels = Tuple{Int,Int}[]
for s in setdiff(1:d, O)
slbl = best_labels[s]
slbl = latest_labels[s]
push!(sublabels, (s,slbl))
end
# prop = proposals[s]

# TODO upgrade to map and tuples
components = map(sl->getKernel(proposals[sl[1]], sl[2], true), sublabels)
components = map(sl->getKernelLeaf(proposals[sl[1]], sl[2], true), sublabels)

# calc product of Gaussians from currently selected LOO-proposals
newO = calcProductGaussians(M, [components...])

# evaluate new sampling weights of points in out component
# NOTE getPoints returns the sorted (permuted) list of data
evat = getPoints(proposals[O]) # FIXME how should partials be handled here?
smw = zeros(length(evat))
# FIXME, use multipoint evaluation such as NN (not just one point at a time)
Expand All @@ -600,19 +669,22 @@ function sampleProductSeqGibbsLabel(
p = Categorical(smw)

# update label-distribution of out-proposal from product of selected LOO-proposal components
best_labels[O] = rand(p)
latest_labels[O] = rand(p)
end

# # recursively sample a layer deeper for each selected label, or fix that sample if that label is a leaf kernel
# for (i,l) in enumerate(latest_labels)

# end

#

return best_labels
return latest_labels
end


function sampleProductSeqGibbsLabels(
M::AbstractManifold,
proposals::AbstractVector,
treeLevel::Int = 1, # reserved for future use
MC = 3,
N::Int = round(Int, mean(length.(getPoints.(proposals))))
)
Expand All @@ -621,7 +693,7 @@ function sampleProductSeqGibbsLabels(
posterior_labels = Vector{NTuple{d,Int}}(undef,N)

for i in 1:N
posterior_labels[i] = tuple(sampleProductSeqGibbsLabel(M,proposals,treeLevel,MC)...)
posterior_labels[i] = tuple(sampleProductSeqGibbsLabel(M,proposals,MC)...)
end

posterior_labels
Expand All @@ -638,10 +710,11 @@ function calcProductKernelLabels(
post = []

for lbs in lbls
# FIXME different tree or leaf kernels would need different lists
props = MvNormalKernel[]
for (i,lb) in enumerate(lbs)
# selection of labels was done against sorted list of particles, hence false
push!(props, getKernel(proposals[i],lb,false))
push!(props, getKernelLeaf(proposals[i],lb,false))
end
push!(post,calcProductGaussians(M,props))
end
Expand Down
53 changes: 52 additions & 1 deletion test/manellic/testManellicTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ M = TranslationGroup(2)
r_CC, R, pidx, r_CV = testEigenCoords(α);
ax_CCp, mask, knl = splitPointsEigen(M, r_CC)
@test sum(mask) == (length(r_CC) ÷ 2)
@test knl isa MvNormal
@test knl isa ApproxManifoldProducts.MvNormalKernel
Mr = SpecialOrthogonal(2)
@test isapprox( α, vee(Mr, Identity(Mr), log_lie(Mr, R))[1] ; atol=0.1)

Expand Down Expand Up @@ -113,6 +113,29 @@ end
@testset "ManellicTree construction 1D" begin
##

M = TranslationGroup(1)
# already sorted list
pts = [[1.],[2.],[4.],[7.],[11.],[16.],[22.]]
bw = [1.0]
mtree = ApproxManifoldProducts.buildTree_Manellic!(M, pts; kernel_bw=bw,kernel=AMP.MvNormalKernel)

@test 7 == length( intersect( mtree.segments[1], Set(1:7)) )
@test 4 == length( intersect( mtree.segments[2], Set(1:4)) )
@test 3 == length( intersect( mtree.segments[3], Set(5:7)) )
@test 2 == length( intersect( mtree.segments[4], Set(1:2)) )
@test 2 == length( intersect( mtree.segments[5], Set(3:4)) )
@test 2 == length( intersect( mtree.segments[6], Set(5:6)) )

@test isapprox( mean(M,pts), mean(mtree.tree_kernels[1]); atol=1e-6)
@test isapprox( mean(M,pts[1:4]), mean(mtree.tree_kernels[2]); atol=1e-6)
@test isapprox( mean(M,pts[5:7]), mean(mtree.tree_kernels[3]); atol=1e-6)
@test isapprox( mean(M,pts[1:2]), mean(mtree.tree_kernels[4]); atol=1e-6)
@test isapprox( mean(M,pts[3:4]), mean(mtree.tree_kernels[5]); atol=1e-6)
@test isapprox( mean(M,pts[5:6]), mean(mtree.tree_kernels[6]); atol=1e-6)


## additional test datasets

function testMDEConstr(
pts::AbstractVector{<:AbstractVector{<:Real}},
permref = sortperm(pts, by=s->getindex(s,1));
Expand Down Expand Up @@ -183,6 +206,7 @@ for i in 1:10
testMDEConstr( _pts; lseg=1:4,rseg=5:8 )
end


##
end

Expand Down Expand Up @@ -457,6 +481,31 @@ g = ApproxManifoldProducts.calcProductGaussians(M, [g1; g2])
end


# @testset "Test utility functions for multi-scale product sampling" begin
# ##

# M = TranslationGroup(1)

# pts = [randn(1).-1 for _ in 1:3]
# p1 = ApproxManifoldProducts.buildTree_Manellic!(M, pts; kernel_bw=[0.1;;], kernel=ApproxManifoldProducts.MvNormalKernel)

# @test 1 == length(ApproxManifoldProducts.getKernelsTreeLevelIdxs(p1, 1))
# @test 2 == length(ApproxManifoldProducts.getKernelsTreeLevelIdxs(p1, 2))
# @test 4 == length(ApproxManifoldProducts.getKernelsTreeLevelIdxs(p1, 3))

# @test 64 == length(ApproxManifoldProducts.getKernelsTreeLevelIdxs(p1, 7))
# @test 128 == length(ApproxManifoldProducts.getKernelsTreeLevelIdxs(p1, 8))

# # @enter
# ApproxManifoldProducts.getKernelsTreeLevelIdxs(p1, 2)
# ApproxManifoldProducts.getKernelsTreeLevelIdxs(p1, 3)



# ##
# end


@testset "Product of two Manellic beliefs, Sequential Gibbs" begin
##

Expand All @@ -468,6 +517,7 @@ p1 = ApproxManifoldProducts.buildTree_Manellic!(M, pts; kernel_bw=[0.1;;], kerne
pts = [randn(1).+1 for _ in 1:128]
p2 = ApproxManifoldProducts.buildTree_Manellic!(M, pts; kernel_bw=[0.1;;], kernel=ApproxManifoldProducts.MvNormalKernel)

##

lbls = ApproxManifoldProducts.sampleProductSeqGibbsLabels(M, [p1; p2])

Expand All @@ -480,6 +530,7 @@ mtr = ApproxManifoldProducts.buildTree_Manellic!(M, pts; kernel_bw, kernel=Appro

@test isapprox( 0, mean(mtr.tree_kernels[1])[1]; atol=0.75)


##
end

Expand Down

0 comments on commit 696d29c

Please sign in to comment.