Skip to content

Commit

Permalink
Use a faster and safer implementation of alias_sample! (#927)
Browse files Browse the repository at this point in the history
* Use a faster implementation for alias_sample!

* add invalid weights tests

* deprecate make_alias_table!
  • Loading branch information
LilithHafner committed Apr 12, 2024
1 parent 60fb5cd commit be809a8
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 73 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["JuliaStats"]
version = "0.34.3"

[deps]
AliasTables = "66dad0bd-aa9a-41b7-9441-69ab47430ed8"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -17,6 +18,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"

[compat]
AliasTables = "1"
DataAPI = "1"
DataStructures = "0.10, 0.11, 0.12, 0.13, 0.14, 0.17, 0.18"
LinearAlgebra = "<0.0.1, 1"
Expand Down
62 changes: 62 additions & 0 deletions src/deprecates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,65 @@ end
@deprecate stdm(x::AbstractArray{<:Real}, w::AbstractWeights, m::AbstractArray{<:Real}, dim::Int; corrected::Union{Bool, Nothing}=nothing) std(x, w, dim, mean=m, corrected=corrected) false
@deprecate varm(x::AbstractArray{<:Real}, w::AbstractWeights, m::AbstractArray{<:Real}, dim::Int; corrected::Union{Bool, Nothing}=nothing) var(x, w, dim, mean=m, corrected=corrected) false
@deprecate varm!(R::AbstractArray, x::AbstractArray{<:Real}, w::AbstractWeights, m::AbstractArray{<:Real}, dim::Int; corrected::Union{Bool, Nothing}=nothing) var!(R, x, w, dim, mean=m, corrected=corrected) false

### This was never part of the public API
### Deprecated April 2024
function make_alias_table!(w::AbstractVector, wsum,
                           a::AbstractVector{Float64},
                           alias::AbstractVector{Int})
    Base.depwarn("make_alias_table! is both internal and deprecated, use AliasTables.jl instead", :make_alias_table!)
    # Fill `a` and `alias` with the acceptance probabilities and alias indices
    # derived from the weights `w`.
    #
    # Arguments:
    #
    #   w [in]:         input weights
    #   wsum [in]:      pre-computed sum(w)
    #
    #   a [out]:        acceptance probabilities
    #   alias [out]:    alias table
    #
    # Note: a and w can be the same array, then that array will be
    #       overwritten inplace by acceptance probabilities
    #
    # Note: alias[i] is only written for indices that get paired with a
    #       "large" entry; every other entry of alias is left untouched.
    #       The weights and wsum are used as given, without validation.
    #
    # Throws ArgumentError if w, a or alias is not 1-based, and
    # DimensionMismatch if their lengths differ.
    #
    # Returns nothing
    #

    # The loops below index with 1:n under @inbounds, which is only safe for
    # 1-based arrays, so reject anything else up front.
    1 == firstindex(w) == firstindex(a) == firstindex(alias) ||
        throw(ArgumentError("non 1-based arrays are not supported"))

    n = length(w)
    length(a) == length(alias) == n ||
        throw(DimensionMismatch("Inconsistent array lengths."))

    # Rescale the weights so that their mean is 1.
    ac = n / wsum
    for i in 1:n
        @inbounds a[i] = w[i] * ac
    end

    # Index stacks, preallocated to the maximal possible size.
    larges = Vector{Int}(undef, n)
    smalls = Vector{Int}(undef, n)
    kl = 0  # actual number of larges
    ks = 0  # actual number of smalls

    # Entries exactly equal to 1.0 need no alias and go on neither stack.
    for i in 1:n
        @inbounds ai = a[i]
        if ai > 1.0
            larges[kl += 1] = i  # push to larges
        elseif ai < 1.0
            smalls[ks += 1] = i  # push to smalls
        end
    end

    # Pair one small with one large entry at a time: the large entry donates
    # the mass the small one is missing and is then re-classified.
    while kl > 0 && ks > 0
        s = smalls[ks]; ks -= 1  # pop from smalls
        l = larges[kl]; kl -= 1  # pop from larges
        @inbounds alias[s] = l
        @inbounds al = a[l] = (a[l] - 1.0) + a[s]
        if al > 1.0
            larges[kl += 1] = l  # push to larges
        else
            smalls[ks += 1] = l  # push to smalls
        end
    end

    # this loop should be redundant, except for rounding
    for i in 1:ks
        @inbounds a[smalls[i]] = 1.0
    end
    return nothing
end
82 changes: 9 additions & 73 deletions src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#
###########################################################

using AliasTables
using Random: Sampler

if VERSION < v"1.3.0-DEV.565"
Expand Down Expand Up @@ -635,65 +636,6 @@ end
direct_sample!(a::AbstractArray, wv::AbstractWeights, x::AbstractArray) =
direct_sample!(default_rng(), a, wv, x)

function make_alias_table!(w::AbstractVector, wsum,
                           a::AbstractVector{Float64},
                           alias::AbstractVector{Int})
    # Fill `a` and `alias` with the acceptance probabilities and alias indices
    # derived from the weights `w`.
    #
    # Arguments:
    #
    #   w [in]:         input weights
    #   wsum [in]:      pre-computed sum(w)
    #
    #   a [out]:        acceptance probabilities
    #   alias [out]:    alias table
    #
    # Note: a and w can be the same array, then that array will be
    #       overwritten inplace by acceptance probabilities
    #
    # Note: alias[i] is only written for indices that get paired with a
    #       "large" entry; every other entry of alias is left untouched.
    #       The weights and wsum are used as given, without validation.
    #
    # Throws ArgumentError if w, a or alias is not 1-based, and
    # DimensionMismatch if their lengths differ.
    #
    # Returns nothing
    #

    # The loops below index with 1:n under @inbounds, which is only safe for
    # 1-based arrays, so reject anything else up front.
    1 == firstindex(w) == firstindex(a) == firstindex(alias) ||
        throw(ArgumentError("non 1-based arrays are not supported"))

    n = length(w)
    length(a) == length(alias) == n ||
        throw(DimensionMismatch("Inconsistent array lengths."))

    # Rescale the weights so that their mean is 1.
    ac = n / wsum
    for i in 1:n
        @inbounds a[i] = w[i] * ac
    end

    # Index stacks, preallocated to the maximal possible size.
    larges = Vector{Int}(undef, n)
    smalls = Vector{Int}(undef, n)
    kl = 0  # actual number of larges
    ks = 0  # actual number of smalls

    # Entries exactly equal to 1.0 need no alias and go on neither stack.
    for i in 1:n
        @inbounds ai = a[i]
        if ai > 1.0
            larges[kl += 1] = i  # push to larges
        elseif ai < 1.0
            smalls[ks += 1] = i  # push to smalls
        end
    end

    # Pair one small with one large entry at a time: the large entry donates
    # the mass the small one is missing and is then re-classified.
    while kl > 0 && ks > 0
        s = smalls[ks]; ks -= 1  # pop from smalls
        l = larges[kl]; kl -= 1  # pop from larges
        @inbounds alias[s] = l
        @inbounds al = a[l] = (a[l] - 1.0) + a[s]
        if al > 1.0
            larges[kl += 1] = l  # push to larges
        else
            smalls[ks += 1] = l  # push to smalls
        end
    end

    # this loop should be redundant, except for rounding
    for i in 1:ks
        @inbounds a[smalls[i]] = 1.0
    end
    return nothing
end

"""
alias_sample!([rng], a::AbstractArray, wv::AbstractWeights, x::AbstractArray)
Expand All @@ -704,29 +646,23 @@ Build an alias table, and sample therefrom.
Reference: Walker, A. J. "An Efficient Method for Generating Discrete Random Variables
with General Distributions." *ACM Transactions on Mathematical Software* 3 (3): 253, 1977.
Noting `k=length(x)` and `n=length(a)`, this algorithm takes ``O(n \\log n)`` time
for building the alias table, and then ``O(1)`` to draw each sample. It consumes ``2 k`` random numbers.
Noting `k=length(x)` and `n=length(a)`, this algorithm takes ``O(n)`` time
for building the alias table, and then ``O(1)`` to draw each sample. It consumes ``k`` random numbers.
"""
function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray)
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
Base.mightalias(x, wv) &&
throw(ArgumentError("output array x must not share memory with weights array wv"))
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
1 == firstindex(a) == firstindex(wv) ||
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths."))
length(wv) == length(a) || throw(DimensionMismatch("Inconsistent lengths."))

# create alias table
ap = Vector{Float64}(undef, n)
alias = Vector{Int}(undef, n)
make_alias_table!(wv, sum(wv), ap, alias)
at = AliasTable(wv)

# sampling
s = Sampler(rng, 1:n)
for i = 1:length(x)
j = rand(rng, s)
x[i] = rand(rng) < ap[j] ? a[j] : a[alias[j]]
for i in eachindex(x)
j = rand(rng, at)
x[i] = a[j]
end
return x
end
Expand Down
3 changes: 3 additions & 0 deletions test/wsampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ for wv in (
check_wsample_wrep(a, (4, 7), wv, 5.0e-3; ordered=false)
end

@test_throws ArgumentError alias_sample!(rand(10), weights(fill(0, 10)), rand(10))
@test_throws ArgumentError alias_sample!(rand(100), weights(randn(100)), rand(10))

for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int})
r = rev ? reverse(4:7) : (4:7)
r = T===Int ? r : T.(r)
Expand Down

0 comments on commit be809a8

Please sign in to comment.