Median (and presumably all quantile computation) could be much faster for large inputs #154

Open
LilithHafner opened this issue Oct 26, 2023 · 1 comment


LilithHafner (Contributor) commented Oct 26, 2023

The idea is to take a random sample to quickly find values that almost certainly (99% chance) bracket the target value(s), then make one efficient pass over the whole input, counting the values that fall below the bracketed range and explicitly storing only those that fall within it. If the median does not fall within the bracketed range, try again with a new random sample, up to three times (a 99.9999% success rate if the randomness is good), and fall back to the existing median code if all three attempts fail. If the median does fall within the stored subset, find the exact target value(s) within that subset.
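The quoted success rate follows if each attempt is treated as independent and failing with probability about 1%:

$$
\Pr(\text{all three attempts fail}) \approx 0.01^3 = 10^{-6}
\quad\Longrightarrow\quad
\Pr(\text{success}) \approx 1 - 10^{-6} = 99.9999\%
$$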

Here's a naive implementation that is about 4x faster for large inputs and allocates O(n^(2/3)) memory instead of O(n):

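```julia
using Statistics

function my_median(v::AbstractVector)
    # Small inputs: the existing implementation is faster.
    length(v) < 2^12 && return median(v)
    # Note: NaN values are not handled here (median(v) would return NaN).
    k = round(Int, length(v)^(1/3))            # sample size is k^2 ≈ n^(2/3)
    # Sample ranks that bracket the middle of the sample with ~99% probability.
    lo_i = floor(Int, middle(1, k^2) - 1.3k)
    hi_i = ceil(Int, middle(1, k^2) + 1.3k)
    @assert 1 <= lo_i
    for _ in 1:3
        sample = rand(v, k^2)                  # random sample, with replacement
        middle_of_sample = partialsort!(sample, lo_i:hi_i)
        lo_x, hi_x = first(middle_of_sample), last(middle_of_sample)
        # One pass over v: count values below lo_x, keep values in [lo_x, hi_x).
        number_below = 0
        middle_of_v = similar(v, 0)
        sizehint!(middle_of_v, 3k^2)
        for x in v
            a = x < lo_x
            b = x < hi_x
            number_below += Int(a)
            if a != b
                push!(middle_of_v, x)
            end
        end
        # Rank of the median inside middle_of_v (a half-integer when length(v) is even).
        # Uses length(v) rather than firstindex/lastindex so offset axes don't shift the rank.
        target = (length(v) + 1) / 2 - number_below
        if isinteger(target)
            target_i = Int(target)
            checkbounds(Bool, middle_of_v, target_i) && return middle(partialsort!(middle_of_v, target_i))
        else
            target_lo = floor(Int, target)
            target_hi = ceil(Int, target)
            checkbounds(Bool, middle_of_v, target_lo:target_hi) && return middle(partialsort!(middle_of_v, target_lo:target_hi))
        end
        # The median fell outside [lo_x, hi_x): retry with a fresh sample.
    end
    # Three unlucky samples in a row: fall back to the exact standard algorithm.
    median(v)
end
```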
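A back-of-the-envelope reading of the constants `k^2`, `1.3k` and `3k^2` in the code above (an editorial estimate, not part of the original comment). With m = k^2 values sampled with replacement, the number B of sampled values below the true median is binomial with standard deviation k/2, so the ±1.3k window is about ±2.6 standard deviations, and the expected number of stored values stays under the `sizehint!` of 3k^2:

$$
\begin{aligned}
m &= k^2 \approx n^{2/3}, \qquad B \sim \mathrm{Binomial}\left(m, \tfrac{1}{2}\right), \qquad \sigma_B = \tfrac{\sqrt{m}}{2} = \tfrac{k}{2},\\
1.3k &= 2.6\,\sigma_B \;\Longrightarrow\; \Pr\left(|B - m/2| \le 1.3k\right) \approx \Pr(|Z| \le 2.6) \approx 0.99,\\
\mathbb{E}\left[\text{stored values}\right] &\approx n \cdot \frac{2.6k}{m} = \frac{2.6\,n}{k} \approx 2.6\,n^{2/3} \le 3k^2.
\end{aligned}
$$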

I think this is reasonably close to optimal for large inputs, but I paid no attention to optimizing the O(n^(2/3)) terms, so it is likely possible to optimize them and lower the crossover point at which this becomes faster than the current median code.

This generalizes well to computing k quantiles of n elements for small k, with a runtime of O(n * k) and a low constant factor. The calls to partialsort! could also be replaced with more efficient recursive calls to quantile.
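A minimal sketch of the single-quantile case, written for illustration and not taken from the original comment: `my_quantile` is a hypothetical name, the interpolation is intended to match the default (type 7) definition used by `Statistics.quantile`, finite values are assumed (no NaN or missing), and it keeps `partialsort!` rather than the recursive calls suggested above.

```julia
using Statistics

# Hypothetical sketch: the sample-and-bracket idea applied to one quantile p.
# Assumes finite values (no NaN/missing). Interpolation is meant to follow the
# default Statistics.quantile definition: h = n*p + (1 - p), between ranks j and j+1.
function my_quantile(v::AbstractVector{<:Real}, p::Real)
    0 <= p <= 1 || throw(ArgumentError("p must be in [0, 1]"))
    n = length(v)
    n < 2^12 && return quantile(v, p)           # small inputs: standard code path
    k = round(Int, n^(1/3))
    m = k^2                                     # sample size, about n^(2/3)
    h = n * p + (1 - p)                         # fractional target rank in sorted v
    j = clamp(trunc(Int, h), 1, n - 1)
    frac = clamp(h - j, 0, 1)
    c = m * p + (1 - p)                         # expected rank of the target in the sample
    lo_i = clamp(floor(Int, c - 1.3k), 1, m)
    hi_i = clamp(ceil(Int, c + 1.3k), 1, m)
    for _ in 1:3
        sample = rand(v, m)                     # sampling with replacement
        bracket = partialsort!(sample, lo_i:hi_i)
        lo_x, hi_x = first(bracket), last(bracket)
        number_below = 0
        middle_of_v = eltype(v)[]
        sizehint!(middle_of_v, 3m)
        for x in v
            below_lo = x < lo_x
            below_hi = x < hi_x
            number_below += below_lo
            below_lo != below_hi && push!(middle_of_v, x)   # keeps lo_x <= x < hi_x
        end
        t = j - number_below                    # rank of the target inside middle_of_v
        if 1 <= t && t + 1 <= length(middle_of_v)
            a, b = partialsort!(middle_of_v, t:t+1)
            return a + frac * (b - a)
        end
    end
    return quantile(v, p)                       # three unlucky samples: exact fallback
end
```

The correctness argument is the same as for the median: every value below `lo_x` is smaller than every stored value, and every value at or above `hi_x` is at least as large, so rank r in v maps to rank r - number_below in `middle_of_v` whenever both target ranks land inside it.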

Benchmarks

Runtimes measured in clock cycles per element (@ 3.49 GHz)

length | median | my_median
-------|--------|----------
10^1 | 16.01 | 30.84
10^2 | 15.74 | 40.28
10^3 | 14.52 | 17.47
10^4 | 9.87 | 8.67
10^5 | 8.77 | 5.29
10^6 | 11.15 | 3.67
10^7 | 14.53 | 3.11
10^8 | 13.06 | 2.71

At length 10^9, the benchmark runs out of memory (OOM).
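Read as speedups (ratios of the two columns in the table, computed here as an editorial aid), the large-input rows are consistent with the "about 4x" claim, and the crossover in these measurements lies between 10^3 and 10^4 elements:

$$
\frac{9.87}{8.67} \approx 1.1 \;(n = 10^4), \qquad
\frac{14.53}{3.11} \approx 4.7 \;(n = 10^7), \qquad
\frac{13.06}{2.71} \approx 4.8 \;(n = 10^8)
$$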

Benchmark code
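```julia
using BenchmarkTools, Statistics   # @belapsed comes from BenchmarkTools

# Assumes my_median (above) is already defined.
println("length | median | my_median")
println("-------|--------|----------")
for i in 1:8
    n = 10^i
    print("10^", rpad(i, 2), " | ")
    x = rand(n)
    t0 = @belapsed median($x)
    t0 *= 3.49e9/n                 # seconds -> clock cycles per element at 3.49 GHz
    print(rpad(round(t0, digits=2), 4, '0'), " | ")
    t1 = @belapsed my_median($x)
    t1 *= 3.49e9/n
    println(rpad(round(t1, digits=2), 4, '0'))
end
```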

For these benchmarks, I removed the `length(v) < 2^12` fast path to get accurate results for smaller inputs, and replaced the `@assert` with `1 <= lo_i || return median(v)`.

nalimilan (Member)

Interesting. Why not use this, at least for large vectors?

Regarding the performance of quantile, see also #91.
