additional higher-level rle api? #897

Open
ericphanson opened this issue Oct 2, 2023 · 0 comments

I have been using the following abstraction over rle in some private code:

"""
    Run{T}

An object which represents a "run" of repeated values in a vector, produced by [`get_runs`](@ref).

See also: [`run_length`](@ref), [`run_value`](@ref), and [`run_indices`](@ref).
"""
struct Run{T}
    val::T
    indices::UnitRange{Int}
end

"""
    run_length(r::Run)

Returns the length of a given run.
"""
run_length(r::Run) = length(r.indices)

"""
    run_value(r::Run)

Returns the value of a given run.
"""
run_value(r::Run) = r.val

"""
    run_indices(r::Run)

Returns the indices (into the original vector) associated with the given run.
"""
run_indices(r::Run) = r.indices

"""
    get_runs(v) -> Vector{Run}

Performs a run-length encoding of `v`, returning a vector of [`Run`](@ref) objects ordered by position in `v`, each of which supports [`run_length`](@ref), [`run_value`](@ref), and [`run_indices`](@ref). Assumes `v` uses 1-based indexing.
"""
function get_runs(v)
    vals, lens = rle(v)
    cs = cumsum(lens)
    run_starts = cs .- lens .+ 1
    run_stops = cs
    return [Run(vals[i], run_starts[i]:run_stops[i])
            for i in eachindex(vals, run_starts, run_stops)]
end
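
To make the proposed API concrete, here is a minimal sketch of what `get_runs` returns for a small vector. It repeats compact versions of the definitions above so that it runs on its own. `simple_rle` is a hypothetical Base-only stand-in for `StatsBase.rle` (assumed to return the same `(vals, lens)` pair), used only so the snippet has no package dependency, and the input vector is made up for illustration.

```julia
# Compact copy of the definitions above, so this snippet runs on its own.
struct Run{T}
    val::T
    indices::UnitRange{Int}
end
run_length(r::Run) = length(r.indices)
run_value(r::Run) = r.val
run_indices(r::Run) = r.indices

# Hypothetical stand-in for `StatsBase.rle` (assumption: same `(vals, lens)`
# return shape), written with Base only.
function simple_rle(v::AbstractVector{T}) where {T}
    vals = T[]
    lens = Int[]
    for x in v
        if !isempty(vals) && isequal(last(vals), x)
            lens[end] += 1      # extend the current run
        else
            push!(vals, x)      # start a new run
            push!(lens, 1)
        end
    end
    return vals, lens
end

function get_runs(v)
    vals, lens = simple_rle(v)
    stops = cumsum(lens)            # last index of each run
    starts = stops .- lens .+ 1     # first index of each run
    return [Run(vals[i], starts[i]:stops[i]) for i in eachindex(vals)]
end

runs = get_runs(['a', 'a', 'b', 'c', 'c', 'c'])
for r in runs
    println(run_value(r), " at ", run_indices(r), " (length ", run_length(r), ")")
end
```

Each `Run` carries both its value and where it sits in the original vector, so downstream code never needs to touch the lengths directly.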

Would this be appropriate to upstream to StatsBase.jl (w/ tests)?

I have found this useful for manipulating boolean masks in place; some examples are below. Code written in terms of this API is much simpler and easier to follow, since the conversion from run lengths to indices is hidden away in `get_runs`, where it can be written and tested once and then reused throughout downstream code.
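
As a small worked example of that conversion, assume `rle` returned the run lengths `[2, 3, 1]` (made-up values, for illustration only). The cumulative sum gives each run's last index, and subtracting the length and adding one gives its first index:

```julia
lens = [2, 3, 1]                 # assumed run lengths, for illustration only
stops = cumsum(lens)             # last index of each run:  [2, 5, 6]
starts = stops .- lens .+ 1      # first index of each run: [1, 3, 6]
ranges = [a:b for (a, b) in zip(starts, stops)]   # [1:2, 3:5, 6:6]
```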

Example usages
"""
    ranges_where_true(v) -> Vector{UnitRange{Int}}

Returns the sorted collection of index ranges over which `v` is `true`, assuming `v` is a boolean mask.
"""
function ranges_where_true(v)
    runs = get_runs(v)
    filter!(run_value, runs) # only runs with value `true`
    return run_indices.(runs)
end

"""
    merge_mask_gaps!(v, N::Integer) -> v

Given a boolean mask `v`, merges "gaps" (i.e. runs of falses surrounded by trues) of length up to `N`, modifying `v` in-place.

## Example
```jldoctest
julia> show(merge_mask_gaps!([true, false, true, false], 1))
Bool[1, 1, 1, 0]
julia> show(merge_mask_gaps!([true, false, false, true, false], 1))
Bool[1, 0, 0, 1, 0]
julia> show(merge_mask_gaps!([true, false, false, true, false], 2))
Bool[1, 1, 1, 1, 0]

```
"""
function merge_mask_gaps!(v, N::Integer)
    @argcheck N >= 0  # `@argcheck` comes from ArgCheck.jl (`using ArgCheck`)
    N == 0 && return v
    runs = get_runs(v)

    # Flip short runs of `false`, but only those lying after the first `true` run and before the last `true` run.
    f = findfirst(run_value, runs)
    l = findlast(run_value, runs)
    if !isnothing(f) && !isnothing(l)
        for i in f:l
            run = runs[i]
            if !run_value(run) && run_length(run) <= N
                v[run_indices(run)] .= true
            end
        end
    end
    return v
end