In this repo I keep pieces of code I write as I try to wrap my mind around
Rust and its ownership system, and I take notes about my findings and
experiences here in the README.md file. It is kind of a blog inside GitHub,
which I call a glog, an accurate name not only in the github + blog sense,
but also in the blog + gulag sense.
I am just picking up Rust so the code here will not be the most idiomatic or expressive; though I hope it will be useful for others learning Rust. If you find any improvements to the code here or have suggestions I would be delighted to hear about them (my email is in my profile).
As a first goal, I want to solve an IOI2016 problem named molecules in Rust.
The solution I have in mind involves a generic algorithm as a subroutine. I
will try to implement this subroutine as generally as possible using Rust
generics, and I am curious to see how this interacts with the ownership system.
Here is the problem statement:
We are given nonnegative integers $w_1, \dots, w_n$ and an inclusive range $[l, u]$, with the promise that $u - l \ge w_{\max} - w_{\min}$. Find a subset of the $w_i$'s which sums to a number in the range $[l, u]$ or, if no such set exists, output the empty set.
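Let us see how we can solve it. Assume for a moment that the $w_i$'s are sorted in increasing order: $w_1 \le w_2 \le \dots \le w_n$.
Let $t$ be the largest index such that $w_1 + \dots + w_t < l$. If $t = n$, even all the weights together fall short of $l$ and there is no solution. Otherwise, if the sum of the first $t + 1$ elements is at most $u$, then we have found a solution: a correct answer to the problem is $\{1, \dots, t + 1\}$.
If on the other hand $w_1 + \dots + w_{t+1} > u$, then we know that the smallest $(t+1)$-element subset is too large and the smallest $t$-element subset is too small. Therefore a correct subset, if there is one, has at most $t$ elements, and we will look for one with exactly $t$ elements.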
Now consider the largest $t$-element subset: $\{n - t + 1, \dots, n\}$. If this subset sums to a value smaller than $l$, then there are no solutions by our previous observation, since no subset with at most $t$ elements can sum to more. If this subset sums to a value greater than or equal to $l$ on the other hand, a solution is guaranteed to exist and here is why: Let $I = \{1, \dots, t\}$.
As long as $\sum_{i \in I} w_i$ is smaller than $l$, exchange the smallest index in $I$ with the largest index not in $I$. This exchange never decreases the sum $\sum_{i \in I} w_i$, and no single exchange can push the value from smaller than $l$ to larger than $u$ due to the promise $u - l \ge w_{\max} - w_{\min}$. So we have a value that is initially smaller than $l$, eventually becomes at least $l$, and in no single step jumps from smaller than $l$ to larger than $u$. This means that in at least one step the value must be in $[l, u]$. (Think of it like this. We have to cross a river, with $l$ and $u$ being on the opposite sides; we can't jump over it; so we've got to swim through it 🎵)
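To spell out the "can't jump over the river" step, here is the inequality behind it. The names are introduced only for this illustration: $s$ is the sum before an exchange, $s'$ the sum after it, $a$ the index that leaves the set and $b$ the index that enters it.

```latex
\begin{aligned}
s' - s &= w_b - w_a \;\le\; w_{\max} - w_{\min} \;\le\; u - l, \\
s < l \;&\Longrightarrow\; s' = s + (s' - s) \;<\; l + (u - l) = u .
\end{aligned}
```

So a sum that is below $l$ can never end up above $u$ after a single exchange.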
This leads to a greedy algorithm in a direct way: Given the $w_i$ values, sort them, set $I = \{1, \dots, t\}$, and update $I$ as described above, in a single pass over the values, until the sum is no smaller than $l$; then output $I$. The runtime of this algorithm is $O(n \log n)$, dominated by the sorting subroutine. This is also the best runtime given by the IOI solutions.
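Here is a minimal sketch of that sorting-based greedy, just to make the steps concrete. The function name `find_subset_sorted` and the use of `u64` weights with `usize` indices are my own choices for illustration, and the code assumes the promise on the range holds.

```rust
/// Returns indices into `w` whose weights sum to a value in `[l, u]`,
/// or an empty vector if no such subset exists.
/// Assumes the promise u - l >= max(w) - min(w).
pub fn find_subset_sorted(l: u64, u: u64, w: &[u64]) -> Vec<usize> {
    let n = w.len();
    // Sort the indices by weight instead of sorting the weights themselves.
    let mut idx: Vec<usize> = (0..n).collect();
    idx.sort_by_key(|&i| w[i]);

    // t = largest count such that the t smallest weights sum to less than l.
    let mut sum = 0u64;
    let mut t = 0;
    while t < n && sum + w[idx[t]] < l {
        sum += w[idx[t]];
        t += 1;
    }
    if t == n {
        return vec![]; // all weights together stay below l
    }
    if sum + w[idx[t]] <= u {
        return idx[..=t].to_vec(); // the t + 1 smallest weights fit
    }

    // Keep exactly t elements in idx[..t]: trade the smallest remaining
    // element for the largest one still outside, until the sum reaches l.
    let (mut j, mut k) = (0, n - 1);
    while sum < l && j < t && k >= t {
        sum += w[idx[k]] - w[idx[j]];
        idx.swap(j, k);
        j += 1;
        k -= 1;
    }
    if sum >= l {
        idx.truncate(t);
        idx
    } else {
        vec![]
    }
}
```

Since the indices are sorted once and then scanned once, the sort is the only superlinear step.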
Here we will solve this problem in $O(n)$ time, without sorting, by
implementing the above idea more carefully. First we need a generic subroutine
which I call the "sumth_element", which is a bit like `std::nth_element` in
C++ (a.k.a. quickselect) but is concerned with prefix sums.
Here is the setup:
Suppose we are given an unsorted array $a$ and a number $s$. Find the largest $t$ such that the sum of the smallest $t$ elements of $a$ is at most $s$.
I call this the "sumth_element" for lack of a better name; it is meant as an
analogy with `std::nth_element` (if you know the proper name for this do let
me know).
If the array were sorted and we had oracle access to prefix sums of $a$, we could find this index in $O(\log n)$ time by a binary search:

    let mut l = 0;
    let mut r = a.len();
    while l < r {
        let m = l + (r - l) / 2;
        let tail_sum = a[l..=m].iter().sum(); // Assume an O(1) time oracle
        if tail_sum <= sum {
            l = m + 1;
            sum -= tail_sum;
        } else {
            r = m;
        }
    }
    l

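To make the "oracle" concrete, here is a small runnable sketch in which a precomputed prefix-sum table plays that role. The helper names `prefix_sums` and `sumth_sorted` and the `u64` element type are assumptions made for this example only.

```rust
/// Builds the table p with p[i] = a[0] + ... + a[i - 1].
pub fn prefix_sums(a: &[u64]) -> Vec<u64> {
    let mut p = Vec::with_capacity(a.len() + 1);
    let mut acc = 0u64;
    p.push(acc);
    for &x in a {
        acc += x;
        p.push(acc);
    }
    p
}

/// Given the prefix sums of a sorted array, returns the largest t such that
/// the t smallest elements sum to at most `sum`.
pub fn sumth_sorted(prefix: &[u64], mut sum: u64) -> usize {
    let mut l = 0;
    let mut r = prefix.len().saturating_sub(1); // number of array elements
    while l < r {
        let m = l + (r - l) / 2;
        // Sum of a[l..=m], answered in O(1) from the table.
        let tail_sum = prefix[m + 1] - prefix[l];
        if tail_sum <= sum {
            l = m + 1;
            sum -= tail_sum;
        } else {
            r = m;
        }
    }
    l
}
```

Building the table is of course linear, so this only illustrates the logarithmic search itself.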
Any binary search working on sorted arrays can be translated mechanically to
work on arbitrary arrays in $O(n)$ time by calls to the quickselect algorithm.
It looks like the Rust standard library does not have a quickselect
implementation and I'm not even going to think about attempting to implement
it (I don't think I have ever implemented it, in any language).
It turns out there are several crates implementing this. I will go with the
`order-stat` crate:

    # Cargo.toml
    [dependencies]
    order-stat = "0.1"

Here is the version that does the same as the above binary search but works on arbitrary arrays:

    let mut l = 0;
    let mut r = a.len();
    while l < r {
        let m = l + (r - l) / 2;
        order_stat::kth(&mut a[l..r], m - l); // added line
        let tail_sum = a[l..=m].iter().sum();
        if tail_sum <= sum {
            l = m + 1;
            sum -= tail_sum;
        } else {
            r = m;
        }
    }
    l

That is, wherever the binary search is about to query an index $m$, we do a minimal amount of sorting so as to ensure the property required by the binary search: $a_i \le a_m \le a_j$ for all $l \le i < m < j < r$.
Each invocation of `order_stat::kth` takes time linear in the size of the
slice we pass to it. Since the slice size is halved in each iteration, the total
runtime comes to $O(n)$.
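As a rough sketch of why the total is linear, assume that one selection call plus one range summation on a slice of length $k$ costs at most $c \cdot k$. The constant $c$ is introduced here only for illustration, and for quickselect-style selection the linear bound typically holds on average rather than in the worst case. The slice lengths are then at most $n, n/2, n/4, \dots$, which gives:

```latex
T(n) \;\le\; c\,n + c\,\frac{n}{2} + c\,\frac{n}{4} + \cdots
     \;=\; c\,n \sum_{k \ge 0} 2^{-k}
     \;=\; 2\,c\,n \;=\; O(n).
```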
Now let's do this generically. Here is the signature of the `sumth_element`
function.

    pub fn sumth_element<T, S>(a: &mut [T], mut sum: S) -> (usize, S)
    where
        T: Ord,
        for<'a> S: Sum<&'a T> + SubAssign + Ord,

Here `Ord` is short for `std::cmp::Ord` and stands for 'totally ordered', which
requires that any two values $a$ and $b$ must satisfy either $a < b$ or
$a > b$, unless $a$ and $b$ are equal according to the `std::cmp::PartialEq`
trait.
We mutably borrow a slice of items of type `T` and we assume that the `T`s can
be summed to obtain an `S`.
Since we mutably borrow the input slice and then mutably 'lend' it to
`order_stat::kth` and immutably lend it to `std::iter::Sum`, it turns out we
need to annotate the lifetime in a special way using the so-called Higher-Rank
Trait Bounds, which is the last line above.
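For readers who want something to compile right away, here is a hedged sketch of a body that fits this signature. It is not necessarily what sumth_element.rs contains: to stay dependency-free it swaps `order_stat::kth` for the slice method `select_nth_unstable`, which newer Rust releases (1.49 and later) provide in the standard library.

```rust
use std::iter::Sum;
use std::ops::SubAssign;

/// Returns (t, slack), where t is the largest count such that the t smallest
/// elements of `a` sum to at most `sum`, and slack is what remains of `sum`
/// after subtracting those t elements. Reorders `a` so that the t smallest
/// elements end up in a[..t].
pub fn sumth_element<T, S>(a: &mut [T], mut sum: S) -> (usize, S)
where
    T: Ord,
    for<'a> S: Sum<&'a T> + SubAssign + Ord,
{
    let mut l = 0;
    let mut r = a.len();
    while l < r {
        let m = l + (r - l) / 2;
        // Partition a[l..r] around its (m - l)-th smallest element
        // (standard-library substitute for order_stat::kth).
        a[l..r].select_nth_unstable(m - l);
        // The shared borrow for summing lives only for this statement,
        // which is why the bound on S must hold for every lifetime 'a.
        let tail_sum: S = a[l..=m].iter().sum();
        if tail_sum <= sum {
            l = m + 1;
            sum -= tail_sum;
        } else {
            r = m;
        }
    }
    (l, sum)
}
```

With `T = u64` and `S = u64` the bounds are satisfied by the standard library alone, since `u64` implements `Sum<&u64>`.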
You can find the full code in sumth_element.rs.
If you have cloned this repo, you can run the unit tests from the repo root with:

    cargo test sumth_element --release

Let us go back to `molecules`. Now with `sumth_element` in hand, we can
solve the problem in linear time as follows. First we map the array $w$ into
the array of tuples $(w_i, i)$ since at the end we need to output the indices.
We use `u32` for both the "weights" and the indices so as to make the memory
layout as compact as possible for cache efficiency. However, we need `u64`s
whenever we sum the weights, since for large input instances a `u32` sum will
overflow.

    struct WeightIndex {
        pub w: u32,
        pub i: u32,
    }

    pub fn find_subset(l: u32, u: u32, w: &[u32]) -> Vec<u32> {
        // Pair every weight with its original index.
        let mut wi: Vec<_> = w
            .iter()
            .enumerate()
            .map(|(i, &w)| WeightIndex { w, i: i as u32 })
            .collect();
        // t = largest count such that the t smallest weights sum to less than l.
        let (t, slack) = sumth_element(&mut wi, l as u64 - 1);
        if t == wi.len() {
            return vec![]; // all weights together still sum to less than l
        }
        let sum = l - 1 - slack as u32; // sum of the t smallest weights
        // Move the smallest of the remaining weights to position t.
        order_stat::kth(&mut wi[t..], 0);
        if sum + wi[t].w <= u {
            return wi[..=t].iter().map(|wi| wi.i).collect();
        }
        // Move the t largest weights to the end of the array.
        if t + t + 1 < w.len() && t > 0 {
            order_stat::kth(&mut wi[t + 1..], w.len() - (t + t + 1));
        }
        // Exchange small weights for large ones until the sum reaches l.
        let mut sum = sum;
        let mut j = 0;
        let mut k = wi.len() - 1;
        while sum < l && j < t {
            sum += wi[k].w - wi[j].w;
            wi.swap(j, k);
            j += 1;
            k -= 1;
        }
        if sum >= l {
            wi[..t].iter().map(|wi| wi.i).collect()
        } else {
            vec![]
        }
    }

Rust is really expressive! I can't imagine transforming the weights array to a `WeightIndex` array this concisely with the C++ standard library. Here is the full solution: molecules.rs.
Let us finally try our solution on the official test suite.
I am having a real tough time finding a good way to parse whitespace separated
integers from a text file in a streaming fashion with the Rust standard library
(or any crate I could find).
The test files can be quite large, especially for problems with small
complexity (such as $O(n)$), so a streaming reader is a necessity.
Note that usually all data is in one line, so `BufReader::lines()` is
not going to help with this.
The best thing I could come up with, `bench.rs::test_from_file()`, is not only
super ugly, but is also suboptimal: it still involves an extra string copy per
integer. This extra copying is tolerable in most languages, but feels silly
when you can just parse the integer from the buffer directly and safely thanks
to the borrow checker.
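One possible way to avoid the per-integer copy is to parse digits straight out of the reader's internal buffer with `BufRead::fill_buf` and `BufRead::consume`. This is only a sketch under my own assumptions: `next_u32` is a hypothetical helper, it handles unsigned decimal integers only, and it is not what `bench.rs` does.

```rust
use std::io::{self, BufRead};

/// Reads the next whitespace-separated unsigned integer from `reader`,
/// parsing digits directly from the reader's internal buffer.
/// Returns `Ok(None)` once the input is exhausted.
pub fn next_u32<R: BufRead>(reader: &mut R) -> io::Result<Option<u32>> {
    let mut value: u32 = 0;
    let mut seen_digit = false;
    loop {
        let buf = reader.fill_buf()?;
        if buf.is_empty() {
            // End of input: report a pending number, if there is one.
            return Ok(if seen_digit { Some(value) } else { None });
        }
        let mut used = 0;
        let mut done = false;
        for &b in buf {
            used += 1;
            if b.is_ascii_digit() {
                value = value
                    .checked_mul(10)
                    .and_then(|v| v.checked_add(u32::from(b - b'0')))
                    .ok_or_else(|| {
                        io::Error::new(io::ErrorKind::InvalidData, "integer overflow")
                    })?;
                seen_digit = true;
            } else if b.is_ascii_whitespace() {
                if seen_digit {
                    done = true; // whitespace after digits ends the number
                    break;
                }
            } else {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "unexpected byte",
                ));
            }
        }
        // Tell the reader how many bytes were actually examined.
        reader.consume(used);
        if done {
            return Ok(Some(value));
        }
    }
}
```

Because the number is accumulated across `fill_buf` calls, it does not matter whether an integer straddles a buffer boundary or whether the whole input sits on a single line.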
I implemented the test runner as a cargo bench to get some timing info.
You can run the benchmark from the repo root like so:

    rustup install nightly
    cargo +nightly bench

    Running target/release/deps/bench-a9da5db63e60670a
    running 4 tests
    test huge_tests   ... bench:  14,444,431 ns/iter (+/- 912,981)
    test large_tests  ... bench: 172,175,577 ns/iter (+/- 950,302)
    test medium_tests ... bench: 156,179,196 ns/iter (+/- 1,020,305)
    test small_tests  ... bench:  34,498,891 ns/iter (+/- 409,675)
    test result: ok. 0 passed; 0 failed; 0 ignored; 4 measured; 0 filtered out

The numbers are from my Intel(R) Celeron(R) 2955U @ 1.40GHz (beast of a) workstation.
In our implementation above, to keep track of the indices, we created
a new type `WeightIndex` and described how to sum and compare instances of this
type to the compiler by implementing the relevant traits. Given the input
array of weights, we transformed it into the $(w_i, i)$ array and then again
to the indices array before we returned the answer.
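The trait implementations mentioned here might look roughly like the following. This is a sketch under the assumption that instances are compared by weight only and summed into a `u64`; the actual implementations live in molecules.rs and may differ in detail.

```rust
use std::cmp::Ordering;
use std::iter::Sum;

#[derive(Debug, Clone, Copy)]
struct WeightIndex {
    pub w: u32,
    pub i: u32,
}

// Equality and ordering look at the weight only (an assumption of this
// sketch); the index is just carried along.
impl PartialEq for WeightIndex {
    fn eq(&self, other: &Self) -> bool {
        self.w == other.w
    }
}

impl Eq for WeightIndex {}

impl PartialOrd for WeightIndex {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for WeightIndex {
    fn cmp(&self, other: &Self) -> Ordering {
        self.w.cmp(&other.w)
    }
}

// Summing borrowed WeightIndex values yields a u64, so the total cannot
// overflow the way a u32 sum could.
impl<'a> Sum<&'a WeightIndex> for u64 {
    fn sum<I: Iterator<Item = &'a WeightIndex>>(iter: I) -> u64 {
        iter.map(|wi| u64::from(wi.w)).sum()
    }
}
```

With these in place, `Vec<WeightIndex>` satisfies the `T: Ord` bound and `u64` satisfies the `Sum<&'a T>` bound of `sumth_element`.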
Can we avoid these transformations altogether? Instead of working with
the pairs $(w_i, i)$, let us work with the indices $i$ directly and describe
how to sum and compare these indices via the `sum_fn` and `cmp_fn` closures we
pass around. To do that, first we need to create a version of `sumth_element`
which takes not only a slice and a sum, but also a sum closure and a compare
closure. This I will call the `sumth_element_with` function. Here is the
signature:

    pub fn sumth_element_with<T, S, SumFn, CmpFn>(
        a: &mut [T],
        mut sum: S,
        sum_fn: SumFn,
        cmp_fn: CmpFn,
    ) -> (usize, S)
    where
        S: SubAssign + Ord,
        SumFn: for<'a> Fn(&'a [T]) -> S,
        CmpFn: for<'a> Fn(&'a T, &'a T) -> Ordering,

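As a hedged sketch, a body matching this signature could look like the following. It uses the standard library's `select_nth_unstable_by` (available in Rust 1.49 and later) as a stand-in for `order_stat::kth_by`, so it is not necessarily identical to the code in this repo.

```rust
use std::cmp::Ordering;
use std::ops::SubAssign;

/// Like the trait-based version, but the caller supplies how to sum a slice
/// of items (`sum_fn`) and how to compare two items (`cmp_fn`).
pub fn sumth_element_with<T, S, SumFn, CmpFn>(
    a: &mut [T],
    mut sum: S,
    sum_fn: SumFn,
    cmp_fn: CmpFn,
) -> (usize, S)
where
    S: SubAssign + Ord,
    SumFn: for<'a> Fn(&'a [T]) -> S,
    CmpFn: for<'a> Fn(&'a T, &'a T) -> Ordering,
{
    let mut l = 0;
    let mut r = a.len();
    while l < r {
        let m = l + (r - l) / 2;
        // Forward to cmp_fn through a small closure so the two references
        // can be shortened to one common lifetime.
        a[l..r].select_nth_unstable_by(m - l, |x, y| cmp_fn(x, y));
        let tail_sum = sum_fn(&a[l..=m]);
        if tail_sum <= sum {
            l = m + 1;
            sum -= tail_sum;
        } else {
            r = m;
        }
    }
    (l, sum)
}
```

The closures can capture the weights array by reference, so the slice being reordered only needs to hold the indices.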
With this version of `sumth_element`, the solution turns into the following.

    pub fn find_subset2(l: u32, u: u32, w: &[u32]) -> Vec<u32> {
        let sum_fn = |s: &[u32]| s.iter().fold(0u64, |sum, &i| sum + w[i as usize] as u64);
        let mut ind: Vec<u32> = (0..w.len() as u32).collect();
        let (t, slack) = sumth_element_with(&mut ind, l as u64 - 1, sum_fn, |&i, &j| {
            w[i as usize].cmp(&w[j as usize])
        });
        if t == w.len() {
            return vec![];
        }
        let sum = l - 1 - slack as u32;
        order_stat::kth_by(&mut ind[t..], 0, |&i, &j| w[i as usize].cmp(&w[j as usize]));
        if sum + w[ind[t] as usize] <= u {
            ind.truncate(t + 1);
            return ind;
        }
        if t + t + 1 < w.len() && t > 0 {
            order_stat::kth_by(&mut ind[t + 1..], w.len() - (t + t + 1), |&i, &j| {
                w[i as usize].cmp(&w[j as usize])
            });
        }
        let mut sum = sum;
        let mut j = 0;
        let mut k = w.len() - 1;
        while sum < l && j < t {
            sum += w[ind[k] as usize] - w[ind[j] as usize];
            ind.swap(j, k);
            j += 1;
            k -= 1;
        }
        if sum >= l {
            ind.truncate(t);
            ind
        } else {
            vec![]
        }
    }

Let us compare these two implementations in terms of speed.

    Running target/release/deps/bench-a9da5db63e60670a
    running 8 tests
    test huge_tests_1   ... bench:  13,805,424 ns/iter (+/- 58,706)
    test huge_tests_2   ... bench:  14,423,895 ns/iter (+/- 65,680)
    test large_tests_1  ... bench: 165,270,894 ns/iter (+/- 185,669)
    test large_tests_2  ... bench: 178,570,020 ns/iter (+/- 170,748)
    test medium_tests_1 ... bench: 149,847,087 ns/iter (+/- 172,004)
    test medium_tests_2 ... bench: 166,938,070 ns/iter (+/- 330,868)
    test small_tests_1  ... bench:  33,209,788 ns/iter (+/- 86,485)
    test small_tests_2  ... bench:  36,145,058 ns/iter (+/- 99,747)

The new version is roughly 10% slower. Compared to the first version we lost a ton of memory locality: to make comparisons and summations we need random accesses to the $w$ array, whereas in the first version $w_i$ is always carried around with the indices. This probably explains the 10% slowdown.
I still have a lingering suspicion that the closures are not being inlined, but I don't know enough about Rust tooling to investigate this yet.