Skip to content

Commit

Permalink
Merge pull request #39 from JuliaRobotics/feature/hybridmanifolds
Browse files Browse the repository at this point in the history
Feature/hybridmanifolds
  • Loading branch information
dehann committed Feb 6, 2019
2 parents 4172f54 + 5d55294 commit 30aeff2
Show file tree
Hide file tree
Showing 9 changed files with 601 additions and 376 deletions.
173 changes: 119 additions & 54 deletions src/BallTree01.jl
Expand Up @@ -63,7 +63,7 @@ getIndexOf(bt::BallTree, i::Int) = bt.permutation[i]
function swap!(data, _i::Int, _j::Int)
return data.swapHandle(data, _i, _j)
end
function calcStats!(data, root::Int, addop=+, diffop=-)
function calcStats!(data, root::Int, addop=(+,), diffop=(-,))
#@show "Fancy calcStats"
return data.calcStatsHandle(data, root , addop, diffop )
end
Expand Down Expand Up @@ -106,26 +106,29 @@ end

# Find the dimension along which the leaves between low and high
# inclusive have the greatest variance
function most_spread_coord(bt::BallTree, low::Int, high::Int, addop=+, diffop=-)
function most_spread_coord(bt::BallTree, low::Int, high::Int, addop=(+,), diffop=(-,))
#BallTree::index dimension, point, max_dim;
#double mean, variance, max_variance;
#println("most_spread_coord -- low, high = $((low, high))")
max_variance = 0
max_dim = 1
w = 1.0/(high-low)

for dimension = 1:bt.dims
for dimension in 1:bt.dims
# compute mean
mean = 0
for point = (bt.dims*(low-1) + dimension):bt.dims:(bt.dims*(high-1))
mean = addop(mean, bt.centers[point])
# scale each value by 1/N
mean = addop[dimension](mean, w*bt.centers[point])
end

# TODO: ensure this operation stays on-manifold for general cases
mean /= (high - low)

# now that we have the mean, compute variance
variance = 0
for point in (bt.dims*(low-1) + dimension):bt.dims:(bt.dims*(high-1))
variance += (diffop(bt.centers[point], mean))^2 # * (bt.centers[point] - mean);
variance += (diffop[dimension](bt.centers[point], mean))^2 # * (bt.centers[point] - mean);
end

# update variance if needed
if (variance > max_variance)
max_variance = variance;
max_dim = dimension;
Expand All @@ -136,33 +139,55 @@ function most_spread_coord(bt::BallTree, low::Int, high::Int, addop=+, diffop=-
return max_dim;
end

# straight from CLR, the unrandomized partition algorithm for
# quicksort. Partitions the leaves from low to high inclusive around
# """
# $SIGNATURES
#
# unrandomized partition algorithm for quicksort.
# Partitions the leaves from low to high inclusive around
# a random pivot in the given dimension. Does not affect non-leaf
# nodes, but does relabel the leaves from low to high.
function partition!(bt::BallTree, dimension::Int, low::Int, high::Int)
pivot = low; # not randomized, could set pivot to a random element

while (low < high)
while(bt.centers[bt.dims*(high-1) + dimension] >= bt.centers[bt.dims*(pivot-1) + dimension])
high-=1
end
while(bt.centers[bt.dims*(low-1) + dimension] < bt.centers[bt.dims*(pivot-1) + dimension])
low+=1
end

bt.swapHandle(bt, low, high)
pivot = high
end

return high;
end

#
# straight from CLR (? Intro to algorithms - Leierson, Rivest),
# """
# function partition!(bt::BallTree, dimension::Int, low::Int, high::Int, diffop)
# pivot = low; # not randomized, could set pivot to a random element
#
# while (low < high)
# while(bt.centers[bt.dims*(high-1) + dimension] >= bt.centers[bt.dims*(pivot-1) + dimension])
# high-=1
# end
# while(bt.centers[bt.dims*(low-1) + dimension] < bt.centers[bt.dims*(pivot-1) + dimension])
# low+=1
# end
#
# bt.swapHandle(bt, low, high)
# pivot = high
# end
#
# return high;
# end
# function partition!(bt::BallTree, dimension::Int, low::Int, high::Int, diffop)
# pivot = low; # not randomized, could set pivot to a random element
#
# while (low < high)
# while diffop(bt.centers[bt.dims*(high-1) + dimension], bt.centers[bt.dims*(pivot-1) + dimension]) >= 0.0
# high-=1
# end
# while diffop(bt.centers[bt.dims*(low-1) + dimension], bt.centers[bt.dims*(pivot-1) + dimension]) < 0.0
# low+=1
# end
#
# bt.swapHandle(bt, low, high)
# pivot = high
# end
#
# return high;
# end

# Function to partition the data into two (equal-sized or near as possible)
# sets, one of which is uniformly greater than the other in the given
# sets, one of which is uniformly "greater" than the other in the given
# dimension.
function select!(bt::BallTree, dimension::Int, position::Int, low::Int, high::Int)
function select!(bt::BallTree, dimension::Int, position::Int, low::Int, high::Int, diffop::T1) where {T1 <: Tuple}
m = 0
r = 0
i = 0
Expand All @@ -171,7 +196,7 @@ function select!(bt::BallTree, dimension::Int, position::Int, low::Int, high::In
swap!(bt.data, r, low)
m = low;
for i in (low):high
if (bt.centers[dimension+bt.dims*(i-1)] < bt.centers[dimension+bt.dims*(low-1)])
if diffop[dimension](bt.centers[dimension+bt.dims*(i-1)], bt.centers[dimension+bt.dims*(low-1)]) < 0.0
m+=1
swap!(bt.data, m, i);
end
Expand All @@ -183,10 +208,45 @@ function select!(bt::BallTree, dimension::Int, position::Int, low::Int, high::In
return nothing
end

"""
$SIGNATURES
Return "smaller" and "larger" of two child nodes.
"""
function getMiniMaxi(bt::BallTree,
leftI::Int,
rightI::Int,
d::Int,
addop,
diffop )::Tuple{Float64, Float64}
#

mini = 0.0
maxi = 0.0

# whos the most positive
a = addop[d]( center(bt, leftI, d), rangeB(bt, leftI, d) )
b = addop[d]( center(bt, rightI, d), rangeB(bt, rightI, d) )
if (a > b)
maxi = addop[d]( center(bt, leftI, d), rangeB(bt, leftI, d) )
else
maxi = addop[d]( center(bt, rightI, d), rangeB(bt, rightI, d) )
end

c = diffop[d](center(bt, leftI, d), rangeB(bt, leftI, d) )
c2 = diffop[d](center(bt, rightI, d), rangeB(bt, rightI, d))
if c < c2
mini = diffop[d]( center(bt, leftI, d), rangeB(bt, leftI, d) )
else
mini = diffop[d]( center(bt, rightI, d), rangeB(bt, rightI, d) )
end

return mini, maxi
end

# Calculate the statistics of level "root" based on the statistics of
# its left and right children.
function calcStatsBall!(bt::BallTree, root::Int, addop=+, diffop=-)
function calcStatsBall!(bt::BallTree, root::Int, addop=(+,), diffop=(-,))
#println("calcStatsBall! -- root=$(root)")
Ni = 0
NiL = 0
Expand All @@ -197,6 +257,8 @@ function calcStatsBall!(bt::BallTree, root::Int, addop=+, diffop=-)
leftI = left(bt, root)
rightI=right(bt, root)

# @show round.(bt.centers, digits=3)

# nothing to do if this isn't a parent node
if (!(validIndex(bt, leftI)) || !(validIndex(bt, rightI)))
return nothing
Expand All @@ -206,24 +268,24 @@ function calcStatsBall!(bt::BallTree, root::Int, addop=+, diffop=-)
maxi = 0.
mini = 0.
for d=1:bt.dims
a = addop( center(bt, leftI, d), rangeB(bt, leftI, d) )
b = addop( center(bt, rightI, d), rangeB(bt, rightI, d) )
#@show (d, leftI, rightI, a, b)
if (a > b)
maxi = addop( center(bt, leftI, d), rangeB(bt, leftI, d) )
else
maxi = addop( center(bt, rightI, d), rangeB(bt, rightI, d) )
end

if diffop(center(bt, leftI, d), rangeB(bt, leftI, d)) < diffop(center(bt, rightI, d), rangeB(bt, rightI, d))
mini = diffop( center(bt, leftI, d), rangeB(bt, leftI, d) )
else
mini = diffop( center(bt, rightI, d), rangeB(bt, rightI, d) )
end
# get which child is mini or maxi
mini, maxi = getMiniMaxi(bt, leftI, rightI, d, addop, diffop)

#@show (d, mini,maxi);
bt.centers[(root-1)*bt.dims+d] = addop(maxi, mini) / 2.0;
bt.ranges[(root-1)*bt.dims+d] = diffop(maxi, mini) / 2.0;
# @show (root-1)*bt.dims+d
# @show maxi, mini

# TODO implicit Euclidean comparison (not right!)
# computing the parent node halfspan
# bt.ranges[(root-1)*bt.dims+d] = diffop[d](maxi, mini) / 2.0; # Basic Euclidean
halfspan = diffop[d](maxi, mini) / 2.0; # Better on-manifold
bt.ranges[(root-1)*bt.dims+d] = halfspan

# Computing the parent node center
# @show "naive Euclidean mean", addop[d](maxi, mini) / 2.0
thecenter = addop[d](mini, halfspan)
bt.centers[(root-1)*bt.dims+d] = thecenter # addop[d](maxi, mini) / 2.0;
end

# if the left ball is the same as the right ball (should only
Expand All @@ -234,6 +296,9 @@ function calcStatsBall!(bt::BallTree, root::Int, addop=+, diffop=-)
else
bt.weights[root] = bt.weights[leftI]
end
# error("finishing with root=$root")


return nothing
end

Expand All @@ -245,8 +310,8 @@ function buildBall!(bt::BallTree,
low::Int,
high::Int,
root::Int,
addop=+,
diffop=- )
addop=(+,),
diffop=(-,) )::Nothing
global NO_CHILD
#println("buildBall! -- (low, high, root)=$((low, high, root))")
# special case for N=1 trees
Expand All @@ -272,7 +337,7 @@ function buildBall!(bt::BallTree,
# error).
split = (floor(Int,(low + high) / 2))
#@show coord, split, low, high
select!(bt, coord, split, low, high)
select!(bt, coord, split, low, high, diffop)

# an alternative is to use partition, but that doesn't deal well
# with repeated numbers and it doesn't split into balanced sets.
Expand Down Expand Up @@ -314,11 +379,11 @@ end

# Public method to build the tree, just calls the private method with
# the proper starting arguments.
function buildTree!(bt::BallTree, addop=+, diffop=-)
function buildTree!(bt::BallTree, addop=(+,), diffop=(-,))
global NO_CHILD
#println("buildTree!(::BallTree) -- is running")
i=bt.num_points
for j in 1:bt.num_points
@inbounds for j in 1:bt.num_points
for k in 1:bt.dims
bt.ranges[i*bt.dims+k] = 0
end
Expand All @@ -339,8 +404,8 @@ end
function makeBallTree(_pointsMatrix::Array{Float64,2},
_weights::Array{Float64,1},
suppressBuildTree=false,
addop=+,
diffop=- )
addop=(+,),
diffop=(-,) )
# get fields from input arguments
Nd = size(_pointsMatrix,1);
Np = size(_pointsMatrix,2);
Expand Down

0 comments on commit 30aeff2

Please sign in to comment.