saglam/big-rust-preschool

What is this?

In this repo I keep pieces of code I write as I try to wrap my mind around Rust and its ownership system, and I take notes about my findings and experiences here in the README.md file. It is kind of a blog inside GitHub, which I call a glog, which is an accurate name not only in the GitHub + blog sense, but also in the blog + gulag sense.

I am just picking up Rust so the code here will not be the most idiomatic or expressive; though I hope it will be useful for others learning Rust. If you find any improvements to the code here or have suggestions I would be delighted to hear about them (my email is in my profile).

Contents

As a first goal, I want to solve an IOI 2016 problem named molecules in Rust. The solution I have in mind involves a generic algorithm as a subroutine, which I will try to implement as generally as possible using Rust generics; I am curious to see how this interacts with the ownership system.

Molecules problem

Here is the problem statement:

We are given \inline n nonnegative integers \inline w_1, w_2, \ldots, w_n and an inclusive range \inline [l, u], with the promise that \inline u - l \ge \mathrm{max}_i w_i - \mathrm{min}_i w_i. Find a subset of the \inline w_i's which sums to a number in the range \inline [l, u] or, if no such set exists, output the empty set.

Let us see how we can solve it. Assume for a moment that \inline w_i's are sorted in increasing order:

\displaystyle{
0 \le w_1 \le w_2 \le \cdots \le w_n.
}

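Let \inline t be the largest index such that \inline w_1 + w_2 +\cdots + w_t \lt  l. (If \inline t = n, even the sum of all elements is smaller than \inline l and there is no solution.) If the sum of the first \inline t+1 elements is at most \inline u, then we have found a solution: a correct answer to the problem is the index set \inline \{1, \ldots, t+1\}.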

If on the other hand \inline w_1+ \cdots + w_{t+1} \gt  u, then we know that the smallest \inline t+1-element subset is too large and the smallest \inline t-element subset is too small. Therefore if there is a correct subset, it must have exactly \inline t elements.

Now consider the largest \inline t-element subset: \inline w_{n-t+1}, \ldots, w_{n}. If this subset sums to a value smaller than \inline l, then there are no solutions by our previous observation. If this subset sums to a value greater than or equal to \inline l on the other hand, a solution is guaranteed to exist and here is why: Let

\displaystyle{
S = \{1, \ldots, t\} \text{, and } T=\{n-t+1, \ldots, n\}.
}

As long as \inline \sum_{i\in S} w_i is smaller than \inline l, replace the smallest index of \inline S that is not in \inline T with the largest index of \inline T that is not in \inline S. An exchange never decreases the sum \inline \sum_{i\in S} w_i, and no single exchange can push the value from smaller than \inline l to larger than \inline u: one exchange changes the sum by at most \inline \mathrm{max}_i w_i - \mathrm{min}_i w_i, which is at most \inline u - l by the promise. So we have a value that is initially smaller than \inline l, is at least \inline l once \inline S has turned into \inline T, and in no single step jumps from smaller than \inline l to larger than \inline u. This means that after at least one step the value must be in \inline [l,u]. (Think of it like this. We have to cross a river, with \inline u and \inline l being on the opposite sides; we can't jump over it; so we've got to swim through it 🎵)

This leads directly to a greedy algorithm: given the \inline w_i values, sort them, set \inline S = \{1,\ldots, t\}, \inline T=\{n-t+1,\ldots, n\}, and in a single pass over the \inline w_i values update \inline S as described above until the sum is no smaller than \inline l; then output \inline S. The runtime of this algorithm is \inline O(n\log n), dominated by the sorting subroutine. This is also the best runtime given by the IOI solutions.
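
To make the greedy idea concrete, here is a minimal sketch of the sort-based version. It is not the code from this repo: the name `find_subset_sorted`, the `u64` weights and the 0-based indices are assumptions made for illustration. It also uses a sliding window over the sorted order (drop the smallest element of the window, take the next larger one) instead of the exchange between \inline S and \inline T described above; the same river-crossing argument applies, since every step changes the sum by at most \inline \mathrm{max}_i w_i - \mathrm{min}_i w_i.

```rust
/// Returns 0-based indices of a subset of `w` whose sum lies in `[l, u]`,
/// or an empty vector if no such subset exists.
/// Assumes the promise `u - l >= max(w) - min(w)`.
fn find_subset_sorted(l: u64, u: u64, w: &[u64]) -> Vec<usize> {
    let n = w.len();
    // Indices of `w`, ordered by increasing weight.
    let mut idx: Vec<usize> = (0..n).collect();
    idx.sort_by_key(|&i| w[i]);

    // t = largest prefix length whose sum is still below l.
    let mut t = 0;
    let mut sum = 0u64;
    while t < n && sum + w[idx[t]] < l {
        sum += w[idx[t]];
        t += 1;
    }
    if t == n {
        return vec![]; // even the whole array sums to less than l
    }
    if sum + w[idx[t]] <= u {
        return idx[..=t].to_vec(); // the t + 1 smallest elements fit
    }

    // Otherwise a solution must have exactly t elements. Slide the window
    // idx[j..j + t] to the right; each step trades one element for a larger one.
    let mut j = 0;
    while sum < l && j + t < n {
        sum += w[idx[j + t]] - w[idx[j]];
        j += 1;
    }
    if sum >= l {
        idx[j..j + t].to_vec()
    } else {
        vec![]
    }
}
```

For the weights 6, 8, 8, 7 and the range [15, 17], this returns the indices of the elements 7 and 8, which sum to 15.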

Here we will solve this problem in \inline O(n) time—without sorting—by implementing the above idea more carefully. First we need a generic subroutine which I call the "sumth_element", which is a bit like the std::nth_element in C++ (a.k.a. quickselect) but is concerned with tail sums.

sumth_element

Here is the setup:

Suppose we are given an unsorted array \inline a_1, a_2,\ldots, a_n and a number \inline S. Find the largest \inline t such that the sum of the smallest \inline t elements of \inline a is at most \inline S.

I call this the "sumth_element" for lack of a better name; it is meant as an analogy with std::nth_element (if you know the proper name for this, do let me know).

If the array was sorted and we had oracle access to prefix sums of \inline a, we could have found this index in \inline O(\log n) time by a binary search:

let mut l = 0;
let mut r = a.len();

while l < r {
    let m = l + (r - l) / 2;
    let tail_sum = a[l..=m].iter().sum(); // Assume an O(1) time oracle
    if tail_sum <= sum {
        l = m + 1;
        sum -= tail_sum;
    } else {
        r = m;
    }
}
l
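
The snippet above assumes an oracle for range sums. As a rough, self-contained illustration of the sorted case, one could precompute prefix sums and search those instead. The function name `sumth_sorted` and the `u64` element type are assumptions; building the prefix array costs \inline O(n), so only the search itself is \inline O(\log n).

```rust
/// For a sorted slice `a`, returns the largest t such that
/// a[0] + ... + a[t - 1] <= budget.
fn sumth_sorted(a: &[u64], budget: u64) -> usize {
    // prefix[k] = a[0] + ... + a[k - 1], so prefix[0] = 0.
    let mut prefix = Vec::with_capacity(a.len() + 1);
    let mut running = 0u64;
    prefix.push(running);
    for &x in a {
        running += x;
        prefix.push(running);
    }
    // prefix is nondecreasing, so `p <= budget` holds on a prefix of it;
    // partition_point binary-searches for the end of that prefix.
    prefix.partition_point(|&p| p <= budget) - 1
}
```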

Any binary search working on sorted arrays can be translated mechanically to work on arbitrary arrays in \inline O(n) time by calls to the quickselect algorithm. At the time of writing, it looked like the Rust standard library did not have a quickselect implementation (slices have since gained select_nth_unstable, stable as of Rust 1.49), and I'm not even going to think about attempting to implement it (I don't think I have ever implemented it, in any language). It turns out there are several crates implementing this. I will go with the order-stat crate:

# Cargo.toml
[dependencies]
order-stat = "0.1"

Here is the version that does the same as the above binary search but works on arbitrary arrays:

let mut l = 0;
let mut r = a.len();

while l < r {
    let m = l + (r - l) / 2;
    order_stat::kth(&mut a[l..r], m - l); // added line
    let tail_sum = a[l..=m].iter().sum();
    if tail_sum <= sum {
        l = m + 1;
        sum -= tail_sum;
    } else {
        r = m;
    }
}
l

That is, wherever the binary search is about to query an index \inline m, we do a minimal amount of sorting so as to ensure the property required by the binary search:

\displaystyle{ a[i] \le a[m] \text{ for } i\lt m \text{ and } a[m] \le a[i] \text{ for } m\lt i.}

Each invocation of the order_stat::kth takes time linear in the size of the slice we pass to it. Since the slice size is halved in each iteration, the total runtime comes to \inline O(n).
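
For readers who want to run the loop without pulling in a crate, here is a minimal, non-generic sketch. It swaps in the standard library's `select_nth_unstable` (stable since Rust 1.49) for `order_stat::kth`; the function name `sumth_element_u64` and the returned slack value are assumptions made for illustration, not the repo's code.

```rust
/// Returns (t, slack): t is the largest count such that the t smallest
/// elements of `a` sum to at most `budget`, and slack is what is left of
/// `budget` after subtracting them. On return, `a[..t]` holds those
/// t smallest elements in no particular order.
fn sumth_element_u64(a: &mut [u64], mut budget: u64) -> (usize, u64) {
    let mut l = 0;
    let mut r = a.len();
    while l < r {
        let m = l + (r - l) / 2;
        // Partition a[l..r] so that a[l..m] <= a[m] <= a[m + 1..r].
        // This stands in for order_stat::kth(&mut a[l..r], m - l).
        a[l..r].select_nth_unstable(m - l);
        let tail_sum: u64 = a[l..=m].iter().sum();
        if tail_sum <= budget {
            // All of a[l..=m] fits; keep it and continue to the right.
            l = m + 1;
            budget -= tail_sum;
        } else {
            // Too much; the answer lies strictly to the left of m.
            r = m;
        }
    }
    (l, budget)
}
```

With the array 5, 1, 4, 2, 3 and a budget of 7, the three smallest elements (1, 2, 3) fit and one unit of budget is left over.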

Now let's do this generically. Here is the signature of the sumth_element function.

pub fn sumth_element<T, S>(a: &mut [T], mut sum: S) -> (usize, S)
where
    T: Ord,
    for<'a> S: Sum<&'a T> + SubAssign + Ord,

Here Ord is short for std::cmp::Ord and stands for 'totally ordered': any two \inline x,y\in T must satisfy either \inline x \lt  y or \inline x \gt  y, unless \inline x and \inline y are equal according to the std::cmp::PartialEq trait. We mutably borrow a slice of items of type T and we assume that references to Ts can be summed to obtain an S. Since we mutably borrow the input slice and then mutably 'lend' it to order_stat::kth and immutably lend it to std::iter::Sum, it turns out we need to annotate the lifetime in a special way using the so-called Higher-Rank Trait Bounds, which is the last line above. You can find the full code in sumth_element.rs.
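
A small, self-contained sketch of the same bound in isolation may help. The function `sum_of_smallest` is hypothetical and only exists to show the pattern: the slice is first used mutably, then borrowed immutably for a short time inside the function, so the `Sum` bound has to hold for whatever lifetime that inner borrow ends up with.

```rust
use std::iter::Sum;

/// Sorts `a`, then sums its `k` smallest elements into an `S`.
/// Panics if `k > a.len()`.
fn sum_of_smallest<T, S>(a: &mut [T], k: usize) -> S
where
    T: Ord,
    // "For every lifetime 'a, S can be built by summing &'a T items."
    for<'a> S: Sum<&'a T>,
{
    a.sort_unstable(); // uses the mutable borrow
    a[..k].iter().sum() // then a short-lived shared borrow of the same slice
}
```

A call such as `let s: u32 = sum_of_smallest(&mut v, 2);` works for a `Vec<u32>` because the standard library implements `Sum<&u32>` for `u32` for any reference lifetime.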

If you have cloned this repo, you can run the unit tests from repo root by

cargo test sumth_element --release

A linear time solution

Let us go back to molecules. Now with the sumth_element in hand, we can solve the problem in linear time as follows. First we map the \inline w_i array into the array of tuples \inline (w_i, i) since at the end we need to output the indices. We use u32 for both the "weights" \inline w_i and the indices so as to make memory layout as compact as possible for cache efficiency. However, we need u64s whenever we need to sum the weights since for large input instances u32 will overflow.

struct WeightIndex {
    pub w: u32,
    pub i: u32,
}

pub fn find_subset(l: u32, u: u32, w: &[u32]) -> Vec<u32> {
    let mut wi: Vec<_> = w
        .iter()
        .enumerate()
        .map(|(i, &w)| WeightIndex { w, i: i as u32 })
        .collect();

    let (t, slack) = sumth_element(&mut wi, l as u64 - 1);
    if t == wi.len() {
        return vec![];
    }

    let sum = l - 1 - slack as u32;
    order_stat::kth(&mut wi[t..], 0);
    if sum + wi[t].w <= u {
        return wi[..=t].into_iter().map(|wi| wi.i).collect();
    }

    if t + t + 1 < w.len() && t > 0 {
        order_stat::kth(&mut wi[t + 1..], w.len() - (t + t + 1));
    }

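    // If S = wi[..t] overlaps the largest t elements (2t > n), only the
    // n - t elements outside S can be swapped in. Move the n - t smallest
    // elements of S to the front so that exactly those are swapped out.
    let swaps = t.min(wi.len() - t);
    if swaps < t {
        order_stat::kth(&mut wi[..t], swaps - 1);
    }
    let mut sum = sum;
    let mut j = 0;
    let mut k = wi.len() - 1;
    while sum < l && j < swaps {
        sum += wi[k].w - wi[j].w;
        wi.swap(j, k);
        j += 1;
        k -= 1;
    }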
    if sum >= l {
        wi[..t].into_iter().map(|wi| wi.i).collect()
    } else {
        vec![]
    }
}

Rust is really expressive! I can't imagine transforming the weights array to WeightIndex array in C++ standard library. Here is the full solution: molecules.rs.
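
The snippet above relies on trait implementations for WeightIndex that live in the full source and are not shown here. As a rough sketch of what they might look like (an assumption, not a copy of molecules.rs): order by weight, break ties by index so that `Ord` agrees with `Eq`, and let references to WeightIndex sum into a `u64`.

```rust
use std::cmp::Ordering;
use std::iter::Sum;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct WeightIndex {
    w: u32,
    i: u32,
}

// Order by weight first; the index only breaks ties.
impl Ord for WeightIndex {
    fn cmp(&self, other: &Self) -> Ordering {
        self.w.cmp(&other.w).then(self.i.cmp(&other.i))
    }
}

impl PartialOrd for WeightIndex {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

// Summing references to WeightIndex yields the total weight as a u64,
// which is what the bound `for<'a> S: Sum<&'a T>` asks for with S = u64.
impl<'a> Sum<&'a WeightIndex> for u64 {
    fn sum<I: Iterator<Item = &'a WeightIndex>>(iter: I) -> u64 {
        iter.map(|wi| u64::from(wi.w)).sum()
    }
}
```

Widening each weight to `u64` before adding is what keeps the sum from overflowing on large inputs.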

Official test suite

Let us finally try our solution on the official test suite. I am having a real tough time finding a good way to parse whitespace-separated integers from a text file in a streaming fashion with the Rust standard library (or any crate I could find). The test files can be quite large, especially for problems with small complexity (such as \inline O(n\log n)), so a streaming reader is a necessity. Note that usually all data is on one line, so BufReader::lines() is not going to help with this. The best thing I could come up with, bench.rs ::test_from_file(), is not only super ugly, but is also suboptimal: it still involves an extra string copy per integer. This extra copying is tolerable in most languages, but feels silly when you can just parse the integer from the buffer directly and safely thanks to the borrow checker. I implemented the test runner as a cargo bench to get some timing info. You can run the benchmark from the repo root like so:
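
One possible way to avoid the per-integer string copy is to read digits straight out of the reader's internal buffer with `BufRead::fill_buf` and `BufRead::consume`. The sketch below is not what bench.rs does; the helper name `next_u32` is hypothetical, and it assumes the input contains only ASCII digits and whitespace.

```rust
use std::io::{self, BufRead};

/// Reads the next whitespace-separated unsigned integer from `r`, or
/// returns `Ok(None)` at end of input. Digits are taken directly from the
/// reader's buffer; no intermediate String is built.
fn next_u32<R: BufRead>(r: &mut R) -> io::Result<Option<u32>> {
    let mut value: Option<u32> = None;
    loop {
        let buf = r.fill_buf()?;
        if buf.is_empty() {
            return Ok(value); // end of input, possibly mid-number
        }
        let mut used = 0;
        let mut done = false;
        for &b in buf {
            if b.is_ascii_digit() {
                let d = u32::from(b - b'0');
                let next = value
                    .unwrap_or(0)
                    .checked_mul(10)
                    .and_then(|v| v.checked_add(d))
                    .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "overflow"))?;
                value = Some(next);
            } else if b.is_ascii_whitespace() {
                if value.is_some() {
                    done = true; // the number ended; leave the delimiter unread
                    break;
                }
                // Otherwise this is leading whitespace; skip it.
            } else {
                return Err(io::Error::new(io::ErrorKind::InvalidData, "unexpected byte"));
            }
            used += 1;
        }
        // Tell the reader how many bytes were actually processed.
        r.consume(used);
        if done {
            return Ok(value);
        }
    }
}
```

Because the partial value is carried across loop iterations, a number that straddles two buffer refills is still parsed correctly.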

rustup install nightly
cargo +nightly bench

     Running target/release/deps/bench-a9da5db63e60670a

running 4 tests
test huge_tests   ... bench:  14,444,431 ns/iter (+/- 912,981)
test large_tests  ... bench: 172,175,577 ns/iter (+/- 950,302)
test medium_tests ... bench: 156,179,196 ns/iter (+/- 1,020,305)
test small_tests  ... bench:  34,498,891 ns/iter (+/- 409,675)

test result: ok. 0 passed; 0 failed; 0 ignored; 4 measured; 0 filtered out

The numbers are from my Intel(R) Celeron(R) 2955U @ 1.40GHz (beast of a) workstation.

Enter closures

In our implementation above, to keep track of the indices, we created a new type WeightIndex and described how to sum and compare instances of this type to the compiler by implementing the relevant traits. Given the input array of weights, we transformed it into the \inline (w_i, i) array and then again to indices array before we returned the answer.

Can we avoid these transformations altogether? Instead of working with the pairs \inline (w_i, i), let us work with the indices directly and describe how to sum and compare these indices via sum_fn and cmp_fn closures we pass around. To do that, first we need to create a version of sumth_element which not only takes a slice and a sum, but also a sum and compare closure. This I will call the sumth_element_with function. Here is the signature:

pub fn sumth_element_with<T, S, SumFn, CmpFn>(
    a: &mut [T],
    mut sum: S,
    sum_fn: SumFn,
    cmp_fn: CmpFn,
) -> (usize, S)
where
    S: SubAssign + Ord,
    SumFn: for<'a> Fn(&'a [T]) -> S,
    CmpFn: for<'a> Fn(&'a T, &'a T) -> Ordering,

With this version of the sumth_element, the solution turns into the following.

pub fn find_subset2(l: u32, u: u32, w: &[u32]) -> Vec<u32> {
    let sum_fn = |s: &[u32]| s.iter().fold(0u64, |sum, &i| sum + w[i as usize] as u64);

    let mut ind: Vec<u32> = (0..w.len() as u32).collect();
    let (t, slack) = sumth_element_with(&mut ind, l as u64 - 1, sum_fn, |&i, &j| {
        w[i as usize].cmp(&w[j as usize])
    });

    if t == w.len() {
        return vec![];
    }
    let sum = l - 1 - slack as u32;
    order_stat::kth_by(&mut ind[t..], 0, |&i, &j| w[i as usize].cmp(&w[j as usize]));
    if sum + w[ind[t] as usize] <= u {
        ind.truncate(t + 1);
        return ind;
    }

    if t + t + 1 < w.len() && t > 0 {
        order_stat::kth_by(&mut ind[t + 1..], w.len() - (t + t + 1), |&i, &j| {
            w[i as usize].cmp(&w[j as usize])
        });
    }

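    // If S = ind[..t] overlaps the largest t elements (2t > n), only n - t
    // swaps are possible; move the n - t smallest of S to the front first.
    let swaps = t.min(w.len() - t);
    if swaps < t {
        order_stat::kth_by(&mut ind[..t], swaps - 1, |&i, &j| {
            w[i as usize].cmp(&w[j as usize])
        });
    }
    let mut sum = sum;
    let mut j = 0;
    let mut k = w.len() - 1;
    while sum < l && j < swaps {
        sum += w[ind[k] as usize] - w[ind[j] as usize];
        ind.swap(j, k);
        j += 1;
        k -= 1;
    }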
    if sum >= l {
        ind.truncate(t);
        ind
    } else {
        vec![]
    }
}

Let us compare these two implementations in terms of speed.

     Running target/release/deps/bench-a9da5db63e60670a

running 8 tests
test huge_tests_1   ... bench:  13,805,424 ns/iter (+/- 58,706)
test huge_tests_2   ... bench:  14,423,895 ns/iter (+/- 65,680)
test large_tests_1  ... bench: 165,270,894 ns/iter (+/- 185,669)
test large_tests_2  ... bench: 178,570,020 ns/iter (+/- 170,748)
test medium_tests_1 ... bench: 149,847,087 ns/iter (+/- 172,004)
test medium_tests_2 ... bench: 166,938,070 ns/iter (+/- 330,868)
test small_tests_1  ... bench:  33,209,788 ns/iter (+/- 86,485)
test small_tests_2  ... bench:  36,145,058 ns/iter (+/- 99,747)

The new version is roughly 10% slower. Compared to the first version we lost a ton of memory locality: to make comparisons and summation we need random accesses to the \inline w array, whereas in the first version \inline w_i is always carried around with the indices. This probably explains the 10% slowdown.

I still have a lingering suspicion that the closures are not being inlined, but I don't know enough about Rust tooling to investigate this yet.
