A faster algorithm for weighted sampling with replacement when k < n by reservoir sampling? #928

Closed
Tortar opened this issue Apr 18, 2024 · 1 comment


Tortar commented Apr 18, 2024

Some time ago I designed an algorithm (described in https://arxiv.org/abs/2403.20256) that I thought would be useful for sampling from data streams. It turns out to be considerably faster than the current algorithm in StatsBase in some cases where the number of items in the sample is smaller than the number of items in the population. This certainly needs careful inspection. Here it is:

using Random, StatsBase, Distributions

function weighted_reservoir_sample(rng, a, ws, n)
    m = min(length(a), n)
    view_w_f_n = @view ws[1:m]
    w_sum = sum(view_w_f_n)
    reservoir = sample(rng, (@view a[1:m]), Weights(view_w_f_n, w_sum), n)
    length(a) <= n && return reservoir
    w_skip = skip(rng, w_sum, n)
    @inbounds for i in n+1:length(a)
        w_el = ws[i]
        w_sum += w_el
        if w_sum > w_skip
            p = w_el/w_sum
            z = (1-p)^(n-3)
            q = rand(rng, Uniform(z*(1-p)*(1-p)*(1-p),1.0))
            k = choose(n, p, q, z)
            for j in 1:k
                r = rand(rng, j:n)
                reservoir[r] = a[i]
                reservoir[r], reservoir[j] = reservoir[j], reservoir[r]
            end 
            w_skip = skip(rng, w_sum, n)
        end
    end
    return shuffle!(rng, reservoir)
end

function skip(rng, w_sum::AbstractFloat, m)
    q = rand(rng)^(1/m)
    return w_sum/q
end

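function choose(n, p, q, z)
    # z enters as (1-p)^(n-3); the branches below accumulate the Binomial(n, p) CDF
    m = 1-p
    s = z
    # P(K <= 1) = (1-p)^n + n*p*(1-p)^(n-1)
    z = s*m*m*(m + n*p)
    z > q && return 1
    # add P(K = 2) = binomial(n, 2) * p^2 * (1-p)^(n-2)
    z += n*p*(n-1)*p*s*m/2
    z > q && return 2
    # add P(K = 3) = binomial(n, 3) * p^3 * (1-p)^(n-3)
    z += n*p*(n-1)*p*(n-2)*p*s/6
    z > q && return 3
    # fall back to the generic quantile for larger counts
    return quantile(Binomial(n, p), q)
end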
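
The two helper functions are terse, so here is one reading of what they compute. This is an interpretation of the code above, not a statement taken from the paper; the symbols W, W', U and K are introduced here only for illustration. If each of the n reservoir slots independently keeps its item with probability W_old/W_new when an element arrives, then skip solves for the total weight W' at which the first replacement happens, and choose walks the Binomial(n, p) CDF, given that at least one slot is replaced.

```latex
% skip: probability that no slot is replaced while the total weight grows from W to W'
\Pr[\text{no replacement}] = \left(\frac{W}{W'}\right)^{n} = U,
\quad U \sim \mathrm{Uniform}(0,1)
\;\Longrightarrow\;
W' = \frac{W}{U^{1/n}}

% choose: with m = 1-p and s = (1-p)^{n-3}, and q drawn uniformly from ((1-p)^n, 1)
F(1) = (1-p)^{n} + n p (1-p)^{n-1} = s\, m^{2} (m + n p)
F(2) = F(1) + \binom{n}{2} p^{2} (1-p)^{n-2} = F(1) + \frac{n(n-1)}{2}\, p^{2}\, s\, m
F(3) = F(2) + \binom{n}{3} p^{3} (1-p)^{n-3} = F(2) + \frac{n(n-1)(n-2)}{6}\, p^{3}\, s
```

Since q is drawn above (1-p)^n = P(K = 0), the returned count is never 0. The three closed-form branches appear to cover the common small counts without calling the generic quantile routine.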

Benchmarking with

rng = Xoshiro(42);
a = collect(1:10^7);
wv(el) = rand() < 0.1 ? 10 * rand() : rand()
ws = Weights(wv.(a));

weighted_reservoir_sample(rng, a, ws, 1);
weighted_reservoir_sample(rng, a, ws, 10^4);
sample(rng, a, ws, 1);
sample(rng, a, ws, 10^4);

for i in 0:7
    t1 = @elapsed weighted_reservoir_sample(rng, a, ws, 10^i);
    t2 = @elapsed sample(rng, a, ws, 10^i);
    println("sample with 10^$i items with population of 10^7 items: $(t2/t1)")
end

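shows the following speedup relative to the current implementation (the ratio of the time taken by sample to the time taken by weighted_reservoir_sample, so values above 1 mean the new method is faster):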

sample with 10^0 items with population of 10^7 items: 0.8358660101066833
sample with 10^1 items with population of 10^7 items: 5.248531783569411
sample with 10^2 items with population of 10^7 items: 19.3146914281279
sample with 10^3 items with population of 10^7 items: 17.139903421544233
sample with 10^4 items with population of 10^7 items: 10.72330908054339
sample with 10^5 items with population of 10^7 items: 3.2609968862521956
sample with 10^6 items with population of 10^7 items: 0.8949282382149918
sample with 10^7 items with population of 10^7 items: 0.9909354681929494

On the dev version, after #927, the improvement is even more pronounced for sample sizes from 10^2 to 10^6:

sample with 10^0 items with population of 10^7 items: 0.5206986178221531
sample with 10^1 items with population of 10^7 items: 4.426968640113709
sample with 10^2 items with population of 10^7 items: 33.29267938056488
sample with 10^3 items with population of 10^7 items: 31.532953274019665
sample with 10^4 items with population of 10^7 items: 23.297611639617777
sample with 10^5 items with population of 10^7 items: 6.736645208235632
sample with 10^6 items with population of 10^7 items: 1.2902524581091546
sample with 10^7 items with population of 10^7 items: 0.9597508690177733

FWIW, this passes all my tests in https://github.com/JuliaDynamics/StreamSampling.jl, which also try to assess whether the sample is really random. What do you think of using this method in the cases where it is faster?
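
As a rough illustration of what such a randomness check can look like, here is a minimal sketch that compares empirical selection frequencies with the normalized weights. It is not the test suite from StreamSampling.jl. The helper name max_freq_gap, the tiny population, and the repetition count are all assumptions made for this example. It is run against StatsBase's sample, but any sampler with the same (rng, a, ws, n) signature could be passed in.

```julia
using Random, StatsBase

# Hypothetical helper: largest absolute gap between the empirical selection
# frequency of each item and its normalized weight.
# Assumes the items are the indices 1:length(a), so they can index `counts`.
function max_freq_gap(sampler, rng, a, ws, n; reps=1_000)
    counts = zeros(Int, length(a))
    for _ in 1:reps
        for x in sampler(rng, a, ws, n)
            counts[x] += 1
        end
    end
    expected = ws ./ sum(ws)          # target probabilities
    observed = counts ./ (n * reps)   # empirical frequencies
    return maximum(abs.(observed .- expected))
end

rng = Xoshiro(42)
a = collect(1:5)
ws = Weights([1.0, 2.0, 3.0, 4.0, 10.0])

# Sampling with replacement, so each draw should follow the weights.
gap = max_freq_gap(sample, rng, a, ws, 3)
println(gap)
```

A small gap is only a necessary condition. It says nothing about dependence between draws, so a proper test would also need something like a chi-squared test on joint counts.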


Tortar commented Apr 22, 2024

It's probably necessary to publish it in a peer-reviewed journal before anything else.

@Tortar Tortar closed this as completed Apr 22, 2024