Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some issues with sampling #879

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
204 changes: 73 additions & 131 deletions src/sampling.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,39 @@
using Base: mightalias

if isdefined(Base, :require_one_based_indexing) # TODO: use this directly once we require Julia 1.2+
using Base: require_one_based_indexing
else
require_one_based_indexing(xs...) =
any((!) ∘ isone ∘ firstindex, xs) && throw(ArgumentError("non 1-based arrays are not supported"))
end

function _validate_sample_inputs(input::AbstractArray, output::AbstractArray, replace::Bool)
mightalias(input, output) &&
throw(ArgumentError("destination array must not share memory with the source array"))
require_one_based_indexing(input, output)
n = length(input)
k = length(output)
if !replace && k > n
throw(DimensionMismatch("cannot draw $k samples of $n values without replacement"))
ararslan marked this conversation as resolved.
Show resolved Hide resolved
end
return (n, k)
end

function _validate_sample_inputs(input::AbstractArray, weights::AbstractWeights,
output::AbstractArray, replace::Bool)
mightalias(output, weights) &&
throw(ArgumentError("destination array must not share memory with weights array"))
_validate_sample_inputs(input, weights)
return _validate_sample_inputs(input, output, replace)
end

function _validate_sample_inputs(input::AbstractArray, weights::AbstractWeights)
require_one_based_indexing(weights)
n = length(input)
nw = length(weights)
nw == n || throw(DimensionMismatch("source and weight arrays must have the same length, got $n and $nw"))
return n
end

###########################################################
#
Expand All @@ -10,16 +46,15 @@
### Algorithms for sampling with replacement

function direct_sample!(rng::AbstractRNG, a::UnitRange, x::AbstractArray)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
s = Sampler(rng, 1:length(a))
n, k = _validate_sample_inputs(a, x, true)
s = Sampler(rng, 1:n)
b = a[1] - 1
if b == 0
for i = 1:length(x)
for i = 1:k
@inbounds x[i] = rand(rng, s)
end
else
for i = 1:length(x)
for i = 1:k
@inbounds x[i] = b + rand(rng, s)
end
end
Expand All @@ -36,12 +71,9 @@
This algorithm consumes `k` random numbers.
"""
function direct_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
s = Sampler(rng, 1:length(a))
for i = 1:length(x)
n, k = _validate_sample_inputs(a, x, true)
s = Sampler(rng, 1:n)
for i = 1:k
ararslan marked this conversation as resolved.
Show resolved Hide resolved
@inbounds x[i] = a[rand(rng, s)]
end
return x
Expand All @@ -61,11 +93,7 @@

# order results of a sampler that does not order automatically
function sample_ordered!(sampler!, rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
n, k = length(a), length(x)
n, k = _validate_sample_inputs(a, x, true)
# todo: if eltype(x) <: Real && eltype(a) <: Real,
# in some cases it might be faster to check
# issorted(a) to see if we can just sort x
Expand Down Expand Up @@ -140,13 +168,7 @@
"""
function knuths_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray;
initshuffle::Bool=true)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
n = length(a)
k = length(x)
k <= n || error("length(x) should not exceed length(a)")
n, k = _validate_sample_inputs(a, x, false)

# initialize
for i = 1:k
Expand Down Expand Up @@ -200,13 +222,7 @@
It is ``O(n)`` for initialization, plus ``O(k)`` for random shuffling
"""
function fisher_yates_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
n = length(a)
k = length(x)
k <= n || error("length(x) should not exceed length(a)")
n, k = _validate_sample_inputs(a, x, false)

inds = Vector{Int}(undef, n)
for i = 1:n
Expand Down Expand Up @@ -240,13 +256,7 @@
drastically, resulting in poorer performance.
"""
function self_avoid_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
n = length(a)
k = length(x)
k <= n || error("length(x) should not exceed length(a)")
n, k = _validate_sample_inputs(a, x, false)

s = Set{Int}()
sizehint!(s, k)
Expand Down Expand Up @@ -282,13 +292,7 @@
The outputs are ordered.
"""
function seqsample_a!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
n = length(a)
k = length(x)
k <= n || error("length(x) should not exceed length(a)")
n, k = _validate_sample_inputs(a, x, false)

i = 0
j = 0
Expand Down Expand Up @@ -324,13 +328,7 @@
The outputs are ordered.
"""
function seqsample_c!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
n = length(a)
k = length(x)
k <= n || error("length(x) should not exceed length(a)")
n, k = _validate_sample_inputs(a, x, false)

i = 0
j = 0
Expand Down Expand Up @@ -370,13 +368,7 @@
The outputs are ordered.
"""
function seqsample_d!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
N = length(a)
n = length(x)
n <= N || error("length(x) should not exceed length(a)")
N, n = _validate_sample_inputs(a, x, false)

i = 0
j = 0
Expand Down Expand Up @@ -485,10 +477,7 @@
"""
function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray;
replace::Bool=true, ordered::Bool=false)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
k = length(x)
n, k = _validate_sample_inputs(a, x, replace)
k == 0 && return x

if replace # with replacement
Expand All @@ -499,8 +488,6 @@
end

else # without replacement
k <= n || error("Cannot draw more samples without replacement.")

if ordered
if n > 10 * k * k
seqsample_c!(rng, a, x)
Expand Down Expand Up @@ -582,8 +569,7 @@
(defaults to `Random.GLOBAL_RNG`).
"""
function sample(rng::AbstractRNG, wv::AbstractWeights)
1 == firstindex(wv) ||
throw(ArgumentError("non 1-based arrays are not supported"))
require_one_based_indexing(wv)
t = rand(rng) * sum(wv)
n = length(wv)
i = 1
Expand All @@ -596,7 +582,10 @@
end
sample(wv::AbstractWeights) = sample(Random.GLOBAL_RNG, wv)

sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) = a[sample(rng, wv)]
function sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights)
_validate_sample_inputs(a, wv)
return a[sample(rng, wv)]

Check warning on line 587 in src/sampling.jl

View check run for this annotation

Codecov / codecov/patch

src/sampling.jl#L587

Added line #L587 was not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's weird that this line isn't tested.

end
sample(a::AbstractArray, wv::AbstractWeights) = sample(Random.GLOBAL_RNG, a, wv)

"""
Expand All @@ -613,15 +602,8 @@
"""
function direct_sample!(rng::AbstractRNG, a::AbstractArray,
wv::AbstractWeights, x::AbstractArray)
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
Base.mightalias(x, wv) &&
throw(ArgumentError("output array x must not share memory with weights array wv"))
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths."))
for i = 1:length(x)
_, k = _validate_sample_inputs(a, wv, x, true)
for i = 1:k
x[i] = a[sample(rng, wv)]
end
return x
Expand Down Expand Up @@ -702,14 +684,7 @@
for building the alias table, and then ``O(1)`` to draw each sample. It consumes ``2 k`` random numbers.
"""
function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray)
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
Base.mightalias(x, wv) &&
throw(ArgumentError("output array x must not share memory with weights array wv"))
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths."))
n, k = _validate_sample_inputs(a, wv, x, true)

# create alias table
ap = Vector{Float64}(undef, n)
Expand All @@ -718,7 +693,7 @@

# sampling
s = Sampler(rng, 1:n)
for i = 1:length(x)
for i = 1:k
j = rand(rng, s)
x[i] = rand(rng) < ap[j] ? a[j] : a[alias[j]]
end
Expand All @@ -740,15 +715,8 @@
"""
function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
wv::AbstractWeights, x::AbstractArray)
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
Base.mightalias(x, wv) &&
throw(ArgumentError("output array x must not share memory with weights array wv"))
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths."))
k = length(x)
n, k = _validate_sample_inputs(a, wv, x, false)
k > 0 || return x

w = Vector{Float64}(undef, n)
copyto!(w, wv)
Expand Down Expand Up @@ -786,20 +754,13 @@
"""
function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
wv::AbstractWeights, x::AbstractArray)
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
Base.mightalias(x, wv) &&
throw(ArgumentError("output array x must not share memory with weights array wv"))
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
k = length(x)
n, k = _validate_sample_inputs(a, wv, x, false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this function really part of the official API and needs checks of the arguments? IIRC I had never intended it to be called by any user directly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And if users use the (IMO) intended sample API then the arguments are already checked I assume.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't have thought it was intended to be user-facing at all except, as pointed out in #876, it's included in the manual. 😕

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it's not exported, is it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not, no. That said, there are implementations of three different Efraimidis-Spirakis algorithms (A, A-Res, and AExpJ), only one of which (AExpJ) is actually used internally by a function like sample. That suggests to me that there was the intention of use of these outside of the context sample but I could very well be mistaken as I don't know the history.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docs were added in #254.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was actually summer 2016 and you reviewed it

LOL amazing. My brain runs the GC often so 7 years ago is long gone.

I definitely buy the argument that the separate, non-exported functions that each implement specific algorithms should not be considered user-facing and thus shouldn't need to perform the same kind of safety checks as those intended to be called directly. What gets me nervous is that there's nothing saying they aren't user-facing, hence issues like #876 and #877. Perhaps we could add admonitions to the docstrings, e.g.

!!! note
    This function is not intended to be called directly and is not considered
    part of the package's API.

?

A bit tangential to this discussion but in the future we could do something for sampling algorithms as is done for sorting algorithms in Base: each algorithm gets a type that subtypes some abstract sampling algorithm type then the user may select a particular algorithm via a keyword argument to sample, e.g. sample(x, wv; alg=EfraimidisAExpJ()), and internally that dispatches to use e.g. efraimidis_aexpj_wsample_norep! (after doing any appropriate argument checking 😄).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, passing types via an alg keyword argument would be the best API.

Better perform checks anyway, except if this means we run checks twice when called from sample. Is that the case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

except if this means we run checks twice when called from sample. Is that the case?

Currently yes. I can add a flag to the internal checking function that makes it a no-op if called from sample but perhaps that's more complex than necessary.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How expensive are the checks? Is there a noticeable performance difference between calling sample and the internal function?

The alg keyword argument seems a reasonable suggestion for future refactorings.

For the time being I would prefer adding a warning or note to the docstrings of these internal functions. I think it was a mistake to add them to the docs at all (also based on the initial + follow-up PRs), so I would be fine even with just removing them from the docs. They're not exported and IMO have never been part of the official API (or at least they were not supposed to be).

k > 0 || return x

# calculate keys for all items
keys = randexp(rng, n)
for i in 1:n
@inbounds keys[i] = wv.values[i]/keys[i]
@inbounds keys[i] = wv[i]/keys[i]
devmotion marked this conversation as resolved.
Show resolved Hide resolved
end

# return items with largest keys
Expand Down Expand Up @@ -827,15 +788,7 @@
"""
function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
wv::AbstractWeights, x::AbstractArray)
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
Base.mightalias(x, wv) &&
throw(ArgumentError("output array x must not share memory with weights array wv"))
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
k = length(x)
n, k = _validate_sample_inputs(a, wv, x, false)
devmotion marked this conversation as resolved.
Show resolved Hide resolved
k > 0 || return x

# initialize priority queue
Expand All @@ -844,7 +797,7 @@
s = 0
@inbounds for _s in 1:n
s = _s
w = wv.values[s]
w = wv[s]
devmotion marked this conversation as resolved.
Show resolved Hide resolved
w < 0 && error("Negative weight found in weight vector at index $s")
if w > 0
i += 1
Expand All @@ -859,7 +812,7 @@
@inbounds threshold = pq[1].first

@inbounds for i in s+1:n
w = wv.values[i]
w = wv[i]
w < 0 && error("Negative weight found in weight vector at index $i")
w > 0 || continue
key = w/randexp(rng)
Expand Down Expand Up @@ -900,15 +853,7 @@
function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
wv::AbstractWeights, x::AbstractArray;
ordered::Bool=false)
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
Base.mightalias(x, wv) &&
throw(ArgumentError("output array x must not share memory with weights array wv"))
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
k = length(x)
n, k = _validate_sample_inputs(a, wv, x, false)
k > 0 || return x

# initialize priority queue
Expand All @@ -917,7 +862,7 @@
s = 0
@inbounds for _s in 1:n
s = _s
w = wv.values[s]
w = wv[s]
w < 0 && error("Negative weight found in weight vector at index $s")
if w > 0
i += 1
Expand All @@ -933,7 +878,7 @@
X = threshold*randexp(rng)

@inbounds for i in s+1:n
w = wv.values[i]
w = wv[i]
w < 0 && error("Negative weight found in weight vector at index $i")
w > 0 || continue
X -= w
Expand Down Expand Up @@ -968,10 +913,8 @@

function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray;
replace::Bool=true, ordered::Bool=false)
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
k = length(x)
n, k = _validate_sample_inputs(a, wv, x, replace)
k > 0 || return x

if replace
if ordered
Expand All @@ -991,7 +934,6 @@
end
end
else
k <= n || error("Cannot draw $k samples from $n samples without replacement.")
efraimidis_aexpj_wsample_norep!(rng, a, wv, x; ordered=ordered)
end
return x
Expand Down