Skip to content

Commit

Permalink
MKD replacement towards #1010, and
Browse files Browse the repository at this point in the history
  • Loading branch information
dehann committed Mar 7, 2021
1 parent 179f61f commit 1a66e57
Show file tree
Hide file tree
Showing 12 changed files with 138 additions and 82 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Expand Up @@ -40,7 +40,7 @@ TimeZones = "f269a46b-ccf7-5d73-abea-4c690281aa53"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[compat]
ApproxManifoldProducts = "0.2"
ApproxManifoldProducts = "0.2.2"
BSON = "0.2, 0.3"
Combinatorics = "1.0"
DataStructures = "0.16, 0.17, 0.18"
Expand Down
13 changes: 9 additions & 4 deletions src/AdditionalUtils.jl
Expand Up @@ -184,7 +184,8 @@ Build an approximate density `[Y|X,DX,.]=[X|Y,DX][DX|.]` as proposed by the cond
Notes
- Assume both are on circular manifold, `manikde!(pts, (:Circular,))`
"""
function approxConvCircular(pX::BallTreeDensity, pDX::BallTreeDensity; N::Int=100)
function approxConvCircular(pX::Union{<:BallTreeDensity,<:ManifoldKernelDensity},
pDX::Union{<:BallTreeDensity,<:ManifoldKernelDensity}; N::Int=100)
#

# building basic factor graph
Expand All @@ -198,15 +199,19 @@ function approxConvCircular(pX::BallTreeDensity, pDX::BallTreeDensity; N::Int=10
approxConv(tfg,:s1s2f1,:s2)
end

function approxConvCircular(pX::BallTreeDensity, pDX::SamplableBelief; N::Int=100)
function approxConvCircular(pX::Union{<:BallTreeDensity,<:ManifoldKernelDensity},
pDX::SamplableBelief; N::Int=100)
#
pts = reshape(rand(pDX, N), 1, :)
pC = manikde!(pts, Sphere1)
approxConvCircular(pX, pC)
end


function approxConvCircular(pX::SamplableBelief, pDX::BallTreeDensity; N::Int=100)
pts = reshape(rand(pX, N), 1, :)
function approxConvCircular(pX::SamplableBelief,
pDX::Union{<:BallTreeDensity,<:ManifoldKernelDensity}; N::Int=100)
#
pts = reshape(rand(pX, N), 1, :)
pC = manikde!(pts, Sphere1)
approxConvCircular(pC, pDX)
end
Expand Down
4 changes: 2 additions & 2 deletions src/BeliefTypes.jl
Expand Up @@ -18,7 +18,7 @@ struct NonparametricMessage <: MessageType end
struct ParametricMessage <: MessageType end


const SamplableBelief = Union{Distributions.Distribution, KDE.BallTreeDensity, AMP.ManifoldKernelDensity, AliasingScalarSampler, FluxModelsDistribution}
const SamplableBelief = Union{Distributions.Distribution, KDE.BallTreeDensity, AMP.ManifoldKernelDensity, AMP.ManifoldKernelDensity, AliasingScalarSampler, FluxModelsDistribution}

abstract type PackedSamplableBelief end

Expand Down Expand Up @@ -48,7 +48,7 @@ struct TreeBelief{T <: InferenceVariable}
# only populated during up as solvableDims for each variable in clique, #910
solvableDim::Float64
end
TreeBelief( p::BallTreeDensity,
TreeBelief( p::Union{<:BallTreeDensity, <:ManifoldKernelDensity},
inferdim::Real=0,
variableType::T=ContinuousScalar(),
manifolds=getManifolds(variableType),
Expand Down
38 changes: 23 additions & 15 deletions src/CompareUtils.jl
Expand Up @@ -8,9 +8,17 @@ import DistributedFactorGraphs: compare, compareAllSpecial
# the functions with IIF-specific parameters.
# To extend these, import the relevant DFG compareX function and overload it.

function compare(p1::BallTreeDensity, p2::BallTreeDensity)::Bool
return compareAll(p1.bt,p2.bt, skip=[:calcStatsHandle; :data]) &&
compareAll(p1,p2, skip=[:calcStatsHandle; :bt])
function compare( p1::Union{<:BallTreeDensity,<:ManifoldKernelDensity},
p2::Union{<:BallTreeDensity,<:ManifoldKernelDensity} )
#
return compareAll(p1.bt,p2.bt, skip=[:calcStatsHandle; :data]) &&
compareAll(p1,p2, skip=[:calcStatsHandle; :bt])
end
function Base.isapprox( p1::Union{<:BallTreeDensity, <:ManifoldKernelDensity},
p2::Union{<:BallTreeDensity, <:ManifoldKernelDensity};
atol=1e-6)
#
mmd(p1,p2) < atol
end

function compareAllSpecial(A::T1, B::T2;
Expand All @@ -22,18 +30,18 @@ end



function compare(c1::TreeClique,
c2::TreeClique )
#
TP = true
TP = TP && c1.id == c2.id
# data
@warn "skipping ::TreeClique compare of data"
# TP = TP && compare(c1.data, c2.data)
function compare( c1::TreeClique,
c2::TreeClique )
#
TP = true
TP = TP && c1.id == c2.id
# data
@warn "skipping ::TreeClique compare of data"
# TP = TP && compare(c1.data, c2.data)

# attributes
@warn "only comparing keys of TreeClique attributes"
TP = TP && collect(keys(c1.attributes)) == collect(keys(c2.attributes))
# attributes
@warn "only comparing keys of TreeClique attributes"
TP = TP && collect(keys(c1.attributes)) == collect(keys(c2.attributes))

return TP
return TP
end
13 changes: 12 additions & 1 deletion src/Deprecated.jl
Expand Up @@ -61,6 +61,17 @@ mutable struct DebugCliqMCMC
end


##==============================================================================
## Deprecate as part of Manifolds.jl consolidation
##==============================================================================


# FIXME, much consolidation required here
Base.convert(::Type{<:ManifoldsBase.Manifold}, ::InstanceType{ContinuousScalar}) = AMP.Euclid
Base.convert(::Type{<:ManifoldsBase.Manifold}, ::InstanceType{ContinuousEuclid{1}}) = AMP.Euclid
Base.convert(::Type{<:ManifoldsBase.Manifold}, ::InstanceType{ContinuousEuclid{2}}) = AMP.Euclid2
Base.convert(::Type{<:ManifoldsBase.Manifold}, ::InstanceType{ContinuousEuclid{3}}) = AMP.Euclid3
Base.convert(::Type{<:ManifoldsBase.Manifold}, ::InstanceType{ContinuousEuclid{4}}) = AMP.Euclid4


##==============================================================================
Expand Down Expand Up @@ -235,7 +246,7 @@ DevNotes
- FIXME Integrate with `manifoldProduct`, see #1010
"""
function productpartials!(pGM::Array{Float64,2},
dummy::BallTreeDensity,
dummy::Union{<:BallTreeDensity,<:ManifoldKernelDensity},
partials::Dict{Int, Vector{BallTreeDensity}},
manis::Tuple )
#
Expand Down
10 changes: 10 additions & 0 deletions src/DispatchPackedConversions.jl
@@ -1,4 +1,14 @@

struct PackedManifoldKernelDensity <: PackedSamplableBelief
json::String
end

Base.convert(::Type{<:SamplableBelief}, ::Type{<:PackedManifoldKernelDensity}) = ManifoldKernelDensity
Base.convert(::Type{<:PackedSamplableBelief}, ::Type{<:ManifoldKernelDensity}) = PackedManifoldKernelDensity
Base.convert(::Type{<:PackedSamplableBelief}, mkd::ManifoldKernelDensity) = convert(String, mkd)
Base.convert(::Type{<:SamplableBelief}, mkd::PackedManifoldKernelDensity) = convert(ManifoldKernelDensity, mkd.json)



function packmultihypo(fnc::CommonConvWrapper{T}) where {T<:FunctorInferenceType}
@warn "packmultihypo is deprecated in favor of Vector only operations"
Expand Down
28 changes: 18 additions & 10 deletions src/FGOSUtils.jl
Expand Up @@ -78,19 +78,31 @@ function clampBufferString(st::AbstractString, max::Int, len::Int=minimum([max,l
end


# export setSolvable!

manikde!(pts::AbstractArray{Float64,2}, vartype::Union{InstanceType{<:InferenceVariable}, InstanceType{<:FunctorInferenceType}}) = manikde!(pts, getManifolds(vartype))
manikde!(pts::AbstractArray{Float64,1}, vartype::Type{ContinuousScalar}) = manikde!(reshape(pts,1,:), getManifolds(vartype))

# extend convenience function
function manikde!(pts::AbstractArray{Float64,2},
bws::Vector{Float64},
variableType::Union{InstanceType{InferenceVariable}, InstanceType{FunctorInferenceType}} )
#
manikde!(pts, bws, getManifolds(variableType))
addopT, diffopT, getManiMu, getManiLam = buildHybridManifoldCallbacks(manifolds)
bel = KernelDensityEstimate.kde!(pts, bws, addopT, diffopT)
ampmani = convert(Manifold, variableType)
return ManifoldKernelDensity(ampmani, bel)
# manikde!(pts, bws, getManifolds(variableType))
end

function manikde!(pts::AbstractArray{Float64,2},
vartype::Union{InstanceType{<:InferenceVariable}, InstanceType{<:FunctorInferenceType}})
# = manikde!(pts, getManifolds(vartype))
#
addopT, diffopT, getManiMu, getManiLam = buildHybridManifoldCallbacks(getManifolds(vartype))
bel = KernelDensityEstimate.kde!(pts, addopT, diffopT)
ampmani = convert(Manifold, vartype)
return ManifoldKernelDensity(ampmani, bel)
end

manikde!(pts::AbstractArray{Float64,1}, vartype::Type{<:ContinuousScalar}) = manikde!(reshape(pts,1,:), vartype) #, getManifolds(vartype))



"""
$SIGNATURES
Expand Down Expand Up @@ -561,10 +573,6 @@ const setVariablePosteriorEstimates! = setPPE!
# Starting integration with Manifolds.jl, via ApproxManifoldProducts.jl first
## ============================================================================

# FIXME, much consolidation required here
convert(::Type{<:ManifoldsBase.Manifold}, ::InstanceType{ContinuousEuclid}) = AMP.Euclid
convert(::Type{<:ManifoldsBase.Manifold}, ::InstanceType{ContinuousScalar}) = AMP.Euclid



"""
Expand Down
10 changes: 5 additions & 5 deletions src/FactorGraph.jl
Expand Up @@ -160,7 +160,7 @@ function setValKDE!(vd::VariableNodeData,
end

function setValKDE!(vd::VariableNodeData,
p::BallTreeDensity,
p::Union{<:BallTreeDensity,<:ManifoldKernelDensity},
setinit::Bool=true,
inferdim::Union{Float32, Float64, Int32, Int64}=0 )
#
Expand Down Expand Up @@ -212,7 +212,7 @@ function setValKDE!(v::DFGVariable,
nothing
end
function setValKDE!(v::DFGVariable,
p::BallTreeDensity,
p::Union{<:BallTreeDensity,<:ManifoldKernelDensity},
setinit::Bool=true,
inferdim::Union{Float32, Float64, Int32, Int64}=0;
solveKey::Symbol=:default )
Expand All @@ -222,7 +222,7 @@ function setValKDE!(v::DFGVariable,
end
function setValKDE!(dfg::G,
sym::Symbol,
p::BallTreeDensity,
p::Union{<:BallTreeDensity,<:ManifoldKernelDensity},
setinit::Bool=true,
inferdim::Union{Float32, Float64, Int32, Int64}=0;
solveKey::Symbol=:default ) where G <: AbstractDFG
Expand Down Expand Up @@ -928,14 +928,14 @@ DevNotes
- TODO better document graphinit and treeinit.
"""
function initManual!( variable::DFGVariable,
ptsArr::BallTreeDensity)
ptsArr::Union{<:BallTreeDensity,<:ManifoldKernelDensity})
#
setValKDE!(variable, ptsArr, true)
return nothing
end
function initManual!( dfg::AbstractDFG,
label::Symbol,
belief::BallTreeDensity)
belief::Union{<:BallTreeDensity,<:ManifoldKernelDensity})
#
variable = getVariable(dfg, label)
initManual!(variable, belief)
Expand Down
8 changes: 6 additions & 2 deletions src/Factors/LinearRelative.jl
Expand Up @@ -26,7 +26,7 @@ end
LinearRelative(::UniformScaling=LinearAlgebra.I) = LinearRelative{1}(MvNormal(zeros(1), diagm(ones(1))))
LinearRelative(nm::Distributions.ContinuousUnivariateDistribution) = LinearRelative{1, typeof(nm)}(nm)
LinearRelative(nm::MvNormal) = LinearRelative{length(nm.μ), typeof(nm)}(nm)
LinearRelative(nm::BallTreeDensity) = LinearRelative{Ndim(nm), typeof(nm)}(nm)
LinearRelative(nm::Union{<:BallTreeDensity,<:ManifoldKernelDensity}) = LinearRelative{Ndim(nm), typeof(nm)}(nm)

getDimension(::InstanceType{LinearRelative{N,<:SamplableBelief}}) where {N} = N
getManifolds(::InstanceType{LinearRelative{N,<:SamplableBelief}}) where {N} = tuple([:Euclid for i in 1:N]...)
Expand All @@ -42,12 +42,16 @@ getSample(cf::CalcFactor{<:LinearRelative}, N::Int=1) = (reshape(rand(cf.factor.
# new and simplified interface for both nonparametric and parametric
function (s::CalcFactor{<:LinearRelative})(z, x1, x2)
# TODO convert to distance(distance(x2,x1),z) # or use dispatch on `-` -- what to do about `.-`
# v0.21+, should return residual
# v0.21+, should return residual
return z .- (x2 .- x1)
end



convert(::Type{<:ManifoldsBase.Manifold}, ::InstanceType{LinearRelative{1}}) = AMP.Euclid
convert(::Type{<:ManifoldsBase.Manifold}, ::InstanceType{LinearRelative{2}}) = AMP.Euclid2



"""
$(TYPEDEF)
Expand Down
87 changes: 48 additions & 39 deletions src/SerializingDistributions.jl
Expand Up @@ -22,45 +22,6 @@ end
convert(::Type{<:SamplableBelief}, obj::PackedUniform) = return Uniform(obj.a, obj.b)



# TODO stop-gap string storage of Distrubtion types, should be upgraded to more efficient storage
function normalfromstring(str::AbstractString)
meanstr = match(r"μ=[+-]?([0-9]*[.])?[0-9]+", str).match
mean = split(meanstr, '=')[2]
sigmastr = match(r"σ=[+-]?([0-9]*[.])?[0-9]+", str).match
sigma = split(sigmastr, '=')[2]
Normal{Float64}(parse(Float64,mean), parse(Float64,sigma))
end

function mvnormalfromstring(str::AbstractString)
means = split(split(split(str, 'μ')[2],']')[1],'[')[end]
mean = Float64[]
for ms in split(means, ',')
push!(mean, parse(Float64, ms))
end
sigs = split(split(split(str, 'Σ')[2],']')[1],'[')[end]
sig = Float64[]
for ms in split(sigs, ';')
for m in split(ms, ' ')
length(m) > 0 ? push!(sig, parse(Float64, m)) : nothing
end
end
len = length(mean)
sigm = reshape(sig, len,len)
MvNormal(mean, sigm)
end

function categoricalfromstring(str::AbstractString)
# pstr = match(r"p=\[", str).match
psubs = split(str, '=')[end]
psubs = split(psubs, '[')[end]
psubsub = split(psubs, ']')[1]
pw = split(psubsub, ',')
p = parse.(Float64, pw)
return Categorical(p ./ sum(p))
end


# NOTE SEE EXAMPLE IN src/Flux/FluxModelsSerialization.jl
function _extractDistributionJson(jsonstr::AbstractString, checkJson::AbstractVector{<:AbstractString})
# Assume first word after split is the type
Expand Down Expand Up @@ -93,6 +54,8 @@ function convert(::Type{<:SamplableBelief}, str::Union{<:PackedSamplableBelief,<
# TODO this is the new direction for serializing (pack/unpack) of <:Samplable objects
# NOTE uses intermediate consolidation keyword search pattern `SamplableTypeJSON`
return _extractDistributionJson(str, checkJson)
elseif occursin(r"_type", str) && occursin(r"ManifoldKernelDensity", str)
return convert(ManifoldKernelDensity, str)
elseif startswith(str, "DiagNormal")
# Diags are internally squared, so only option here is to sqrt on input.
return mvnormalfromstring(str)
Expand Down Expand Up @@ -121,4 +84,50 @@ end




## DEPRECATE BELOW ========================================================================




# TODO stop-gap string storage of Distrubtion types, should be upgraded to more efficient storage
function normalfromstring(str::AbstractString)
meanstr = match(r"μ=[+-]?([0-9]*[.])?[0-9]+", str).match
mean = split(meanstr, '=')[2]
sigmastr = match(r"σ=[+-]?([0-9]*[.])?[0-9]+", str).match
sigma = split(sigmastr, '=')[2]
Normal{Float64}(parse(Float64,mean), parse(Float64,sigma))
end

function mvnormalfromstring(str::AbstractString)
means = split(split(split(str, 'μ')[2],']')[1],'[')[end]
mean = Float64[]
for ms in split(means, ',')
push!(mean, parse(Float64, ms))
end
sigs = split(split(split(str, 'Σ')[2],']')[1],'[')[end]
sig = Float64[]
for ms in split(sigs, ';')
for m in split(ms, ' ')
length(m) > 0 ? push!(sig, parse(Float64, m)) : nothing
end
end
len = length(mean)
sigm = reshape(sig, len,len)
MvNormal(mean, sigm)
end

function categoricalfromstring(str::AbstractString)
# pstr = match(r"p=\[", str).match
psubs = split(str, '=')[end]
psubs = split(psubs, '[')[end]
psubsub = split(psubs, ']')[1]
pw = split(psubsub, ',')
p = parse.(Float64, pw)
return Categorical(p ./ sum(p))
end




#
5 changes: 3 additions & 2 deletions src/SolverUtilities.jl
Expand Up @@ -33,8 +33,9 @@ function mmd( p1::AbstractMatrix{<:Real},
mmd(p1, p2, manis, bw=bw)
end

function mmd( p1::BallTreeDensity,
p2::BallTreeDensity,
# TODO move to AMP?
function mmd( p1::Union{<:BallTreeDensity,<:ManifoldKernelDensity},
p2::Union{<:BallTreeDensity,<:ManifoldKernelDensity},
nodeType::Union{InstanceType{InferenceVariable},InstanceType{FunctorInferenceType}};
bw::AbstractVector{<:Real}=[0.001;])
#
Expand Down

0 comments on commit 1a66e57

Please sign in to comment.