Error when negative weights or zero sum are used when sampling #834

Open · wants to merge 5 commits into base: master
Conversation

nalimilan (Member) commented Sep 3, 2022

These can give misleading results. Checking the sum is cheap, but checking for negative weights is relatively costly. Therefore, compute this information the first time it is requested, and store it in the weights vector alongside the sum.

`efraimidis_ares_wsample_norep!` and `efraimidis_aexpj_wsample_norep!` already checked these, but threw different exception types. Harmonize exceptions across algorithms, as they can all be called via `sample`.

I may well have been too defensive. I have no idea what the exact requirements of each algorithm are, but when in doubt it's probably better to throw errors. Are we even sure that zero weights are handled correctly everywhere? :-/

The current state of the PR doesn't make much sense on Julia 1.0, as `Fix2` doesn't exist there; I'll clean that up later if we agree on the design (though the fast path won't work on 1.0).
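The caching idea in the description could be sketched roughly like this. This is a hypothetical illustration only: `CachedWeights` and `hasnegative!` are invented names, not the PR's actual code, and the real `AbstractWeights` types differ.

```julia
# Hypothetical sketch: lazily cache whether any weight is negative,
# analogous to how the sum is stored in the weights object.
mutable struct CachedWeights{T<:Real}
    values::Vector{T}
    sum::T
    hasnegative::Union{Missing,Bool}  # unknown until first requested
end

CachedWeights(values::Vector{<:Real}) = CachedWeights(values, sum(values), missing)

function hasnegative!(w::CachedWeights)
    if w.hasnegative === missing
        # O(n) scan performed at most once per weights vector
        w.hasnegative = any(x -> x < 0, w.values)
    end
    return w.hasnegative::Bool
end
```

Subsequent calls to `hasnegative!` then return the stored answer without rescanning the vector, which is the point of storing the flag in the weights object rather than recomputing it per `sample` call.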

```julia
    v
end

function Base.all(f::Base.Fix2{typeof(>=)}, wv::AbstractWeights)
```
nalimilan (Member, Author):

This may be overkill, but unfortunately that's the standard way of checking whether all entries in a vector are nonnegative in Julia, so it's the only solution if we want external code to be able to use this feature without exporting a new `ispositive` function. One advantage of defining this is that if we rework the API to take a `weights` keyword argument, sampling functions will be able to accept any `AbstractArray` and code will automatically work.

Member:

Hmm. It's easy to not hit this function, e.g. when using `all(x -> x >= 0, wv)` etc. And, of course, it also covers things like `all(>=(2), wv)`. One might also want to check non-negativity via e.g. `!any(<(0), wv)`.

So I think a dedicated separate function would be cleaner and less ambiguous. If one wants to support `AbstractArray`s, one could also at some point just define a fallback:

```julia
isnonneg(x::AbstractArray{<:Real}) = all(>=(0), x)
```
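To make the concern above concrete, here is a small illustration (pure Base Julia, no StatsBase needed) of how narrow the `Base.Fix2{typeof(>=)}` signature is: only the curried `>=(0)` form produces a `Fix2`, so equivalent spellings bypass the specialized method.

```julia
# The curried comparison produces a Base.Fix2, which is what the
# proposed `all` method dispatches on:
f = >=(0)
@assert f isa Base.Fix2{typeof(>=)}

# An equivalent anonymous function is NOT a Fix2, so it would take
# the generic (slow) element-wise path:
g = x -> x >= 0
@assert !(g isa Base.Fix2)

# Conversely, the signature also matches comparisons that are not
# non-negativity checks at all, e.g. >=(2):
@assert (>=(2)) isa Base.Fix2{typeof(>=)}
```

This is why a dedicated function avoids both missed fast paths and accidental matches.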

nalimilan (Member, Author):

What annoys me is that there's no reason why such a very basic function would live in and be exported by StatsBase. And anyway, if users are not aware of the fast path (be it `all(>=(0), x)` or `isnonneg(x)`), they won't use it and will get the slow one, so I'm not sure choosing one syntax or the other makes a difference.

We could keep this internal for now -- though defining an internal function wouldn't be better than `all(>=(0), wv)` and `!any(<(0), wv)`, as we would be sure users wouldn't be able to use it. ;-)

Member:

Yes, I would suggest only defining an internal function -- that seems sufficient, as it's only used in the internal argument checks. Something like `isnonneg` focuses on what we actually want to check, whereas defining `all` or `any` methods also catches other, in principle undesired, cases and hence requires a more complicated implementation with multiple checks.

```diff
@@ -845,22 +856,22 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
     @inbounds for _s in 1:n
         s = _s
         w = wv.values[s]
-        w < 0 && error("Negative weight found in weight vector at index $s")
+        w < 0 && throw(ArgumentError("Negative weight found in weight vector at index $s"))
```
Member:

I think these efraimidis functions were actually one of my first contributions to Julia packages. Fun times, but I'm not surprised that some things can be improved and made more consistent 😄

bkamins (Contributor) commented Sep 4, 2022

@nalimilan - do we have benchmarks showing the performance impact?

Also, given the changes we're making, I think it would be good to add a recommendation to https://juliastats.org/StatsBase.jl/stable/weights/#Weight-Vectors-1 to create the weights vector once if it is reused later. The point is that code like:

```julia
[sample(Weights(x)) for i in 1:10^6]
```

is inefficient and should be replaced by:

```julia
w = Weights(x)
[sample(w) for i in 1:10^6]
```

I commonly see such patterns, since users are unaware of the cost of `Weights` creation.

nalimilan (Member, Author):

Here's a simple benchmark for the worst case, where the weights vector is used for the first time (so that negative entries have to be checked). I show timings both excluding and including the creation of the weights vector. (This covers the `alias_sample!` method.)

Unfortunately the overhead isn't completely negligible in this situation (up to 40% when sampling a single value from a 10,000-element vector), which isn't surprising. So I've tried another approach: checking for negative weights directly inside the sampling methods, while iterating over them. This gives a much lower overhead (up to 10%); see the last series of benchmarks for an illustration of what this gives for `make_alias_table!`. I haven't checked whether it would be fast for other methods too, but that's probably the case.
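The in-loop approach could look roughly like this. This is a hedged sketch, not the PR's actual `make_alias_table!` code; `normalized_weights_checked` is a hypothetical name.

```julia
# Sketch: instead of a separate O(n) pre-pass over the weights, fold the
# negativity check into a loop that already touches every weight
# (here, a normalization pass), so the extra cost is near zero.
function normalized_weights_checked(w::AbstractVector{<:Real})
    s = sum(w)
    s > 0 || throw(ArgumentError("weights must sum to a positive number"))
    p = similar(w, Float64)
    @inbounds for i in eachindex(w)
        wi = w[i]
        # piggyback the check on a pass we need anyway
        wi < 0 && throw(ArgumentError("negative weight at index $i"))
        p[i] = wi / s
    end
    return p
end
```

The branch is predicted essentially perfectly for valid inputs, which is consistent with the ~10% overhead reported above versus up to 40% for a dedicated scan.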

So we have three options:

  • consider the overhead of the current state of the PR acceptable: probably not
  • first check whether `wv.negative === missing`, and only check that weights are non-negative in that case: doable, though a bit complex
  • always check that weights are non-negative, and drop the new field from the `AbstractWeights` types: simpler, but imposes the (small) overhead even for repeated sampling with the same weights vector

What I don't understand is why the benchmarks that include the cost of creating `weights(y)` show a much larger regression compared with master than those excluding that cost. And it persists with the third approach. This doesn't make any sense to me.

```julia
# master

julia> x = rand(100);

julia> y = rand(100);

julia> @btime sample(x, w) setup=(w=weights(y));
  84.847 ns (1 allocation: 16 bytes)

julia> @btime sample(x, w, 100) setup=(w=weights(y));
  1.855 μs (5 allocations: 4.38 KiB)

julia> @btime sample(x, weights(y));
  105.413 ns (2 allocations: 48 bytes)

julia> @btime sample(x, weights(y), 100);
  1.870 μs (6 allocations: 4.41 KiB)

julia> x = rand(10_000);

julia> y = rand(10_000);

julia> @btime sample(x, w) setup=(w=weights(y));
  1.384 μs (1 allocation: 16 bytes)

julia> @btime sample(x, w, 100) setup=(w=weights(y));
  112.887 μs (9 allocations: 313.56 KiB)

julia> @btime sample(x, weights(y));
  2.530 μs (2 allocations: 48 bytes)

julia> @btime sample(x, weights(y), 100);
  114.043 μs (10 allocations: 313.59 KiB)

# This PR

julia> x = rand(100);

julia> y = rand(100);

julia> @btime sample(x, w) setup=(w=weights(y));
  94.256 ns (1 allocation: 16 bytes)

julia> @btime sample(x, w, 100) setup=(w=weights(y));
  2.249 μs (5 allocations: 4.38 KiB)

julia> @btime sample(x, weights(y));
  157.160 ns (2 allocations: 48 bytes)

julia> @btime sample(x, weights(y), 100);
  2.298 μs (6 allocations: 4.41 KiB)

julia> x = rand(10_000);

julia> y = rand(10_000);

julia> @btime sample(x, w) setup=(w=weights(y));
  1.882 μs (1 allocation: 16 bytes)

julia> @btime sample(x, w, 100) setup=(w=weights(y));
  126.571 μs (9 allocations: 313.56 KiB)

julia> @btime sample(x, weights(y));
  4.723 μs (2 allocations: 48 bytes)

julia> @btime sample(x, weights(y), 100);
  131.106 μs (10 allocations: 313.59 KiB)

# Checking negative weights in make_alias_table! loop

julia> x = rand(100);

julia> y = rand(100);

julia> @btime sample(x, w) setup=(w=weights(y));
  93.747 ns (1 allocation: 16 bytes)

julia> @btime sample(x, w, 100) setup=(w=weights(y));
  1.944 μs (5 allocations: 4.38 KiB)

julia> @btime sample(x, weights(y));
  157.153 ns (2 allocations: 48 bytes)

julia> @btime sample(x, weights(y), 100);
  2.115 μs (6 allocations: 4.41 KiB)

julia> x = rand(10_000);

julia> y = rand(10_000);

julia> @btime sample(x, w) setup=(w=weights(y));
  1.365 μs (1 allocation: 16 bytes)

julia> @btime sample(x, w, 100) setup=(w=weights(y));
  120.666 μs (9 allocations: 313.56 KiB)

julia> @btime sample(x, weights(y));
  4.906 μs (2 allocations: 48 bytes)

julia> @btime sample(x, weights(y), 100);
  122.024 μs (10 allocations: 313.59 KiB)
```

bkamins (Contributor) commented Sep 4, 2022

@nalimilan - I do not understand this difference in timing:

```julia
# master

julia> x = rand(100);

julia> y = rand(100);

julia> @btime sample(x, w) setup=(w=weights(y));
  84.847 ns (1 allocation: 16 bytes)

julia> @btime sample(x, w, 100) setup=(w=weights(y));
  1.855 μs (5 allocations: 4.38 KiB)

# This PR

julia> x = rand(100);

julia> y = rand(100);

julia> @btime sample(x, w) setup=(w=weights(y));
  94.256 ns (1 allocation: 16 bytes)

julia> @btime sample(x, w, 100) setup=(w=weights(y));
  2.249 μs (5 allocations: 4.38 KiB)
```

Why is the overhead ~10 ns in the first case, but ~0.4 μs in the second? I would have thought the difference should be the same.

nalimilan (Member, Author) commented Sep 4, 2022

Good catch. This seems to be due to `direct_sample!` calling `sample(rng, wv)` for each requested value, which means we check the sum and negativity of the weights 100 times. The code was maybe efficient when it was written, but we've progressively added sanity checks and now it's really absurd. It looks like by copying the code I can almost eliminate the overhead. I'll investigate further.
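A sketch of the fix being described, under the assumption that validation can be hoisted out of the drawing loop. Function names here are illustrative, not StatsBase's actual internals.

```julia
# Draw one index by linear search over cumulative weights, with no
# validation: the caller is responsible for having checked the weights.
function sample_one_unchecked(rng, wv::AbstractVector{<:Real}, total::Float64)
    t = rand(rng) * total
    i = 1
    cw = Float64(wv[1])
    while cw < t && i < length(wv)
        i += 1
        cw += wv[i]
    end
    return i
end

# Validate once, then call the unchecked kernel n times, instead of
# invoking the fully-checked single-value path per draw.
function direct_sample_checked(rng, wv::AbstractVector{<:Real}, n::Integer)
    total = Float64(sum(wv))
    total > 0 || throw(ArgumentError("weights must sum to a positive number"))
    any(x -> x < 0, wv) && throw(ArgumentError("negative weights not allowed"))
    return [sample_one_unchecked(rng, wv, total) for _ in 1:n]
end
```

With this structure the checks cost O(n) once per call rather than once per drawn value, which is what makes the 100-draw benchmark regress so much more than the single-draw one.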
