Skip to content

Commit

Permalink
fix: One Simple Trick to significantly improve the speed and memory u…
Browse files Browse the repository at this point in the history
…sage of Occ (#448)

* Improve Occ array speed and memory usage

* Always try to include sentinel even when it is not in alphabet
  • Loading branch information
Daniel-Liu-c0deb0t committed Aug 23, 2021
1 parent 00f9846 commit 9aa79cb
Showing 1 changed file with 29 additions and 13 deletions.
42 changes: 29 additions & 13 deletions src/data_structures/bwt.rs
Expand Up @@ -83,8 +83,8 @@ impl Occ {
/// Calculate occ array with sampling from BWT of length n.
/// Time complexity: O(n).
/// Space complexity: O(n / k * A) with A being the alphabet size.
/// Alphabet size is determined on the fly from the BWT.
/// For large texts, it is therefore advisable to transform
/// The specified alphabet must match the alphabet of the text and its BWT.
/// For large texts, it is advisable to transform
/// the text before calculating the BWT (see alphabets::rank_transform).
///
/// # Arguments
Expand All @@ -97,12 +97,27 @@ impl Occ {
.max_symbol()
.expect("Expecting non-empty alphabet.") as usize
+ 1;
let mut occ = Vec::with_capacity(n / k as usize);
let mut curr_occ: Vec<usize> = repeat(0).take(m).collect();
let mut alpha = alphabet.symbols.iter().collect::<Vec<usize>>();
// include sentinel '$'
if (b'$' as usize) < m && !alphabet.is_word(b"$") {
alpha.push(b'$' as usize);
}
let mut occ: Vec<Vec<usize>> = vec![Vec::new(); m];
let mut curr_occ = vec![0usize; m];

// characters not in the alphabet won't take up much space
for &a in &alpha {
occ[a].reserve(n / k as usize);
}

for (i, &c) in bwt.iter().enumerate() {
curr_occ[c as usize] += 1;

if i % k as usize == 0 {
occ.push(curr_occ.clone());
// only visit characters in the alphabet
for &a in &alpha {
occ[a].push(curr_occ[a]);
}
}
}

Expand All @@ -122,7 +137,7 @@ impl Occ {
// .iter()
// .filter(|&&c| c == a)
// .count();
// self.occ[i][a as usize] + count
// self.occ[a as usize][i] + count
// ```
//
// But there are a couple of reasons to do this manually:
Expand All @@ -140,15 +155,13 @@ impl Occ {
// self.k is our sampling rate, so find the checkpoints either side of r.
let lo_checkpoint = r / self.k as usize;
// Get the occurences at the low checkpoint
let lo_occ = self.occ[lo_checkpoint][a as usize];
let lo_occ = self.occ[a as usize][lo_checkpoint];

// If the sampling rate is infrequent it is worth checking if there is a closer
// hi checkpoint.
if self.k > 64 {
let hi_checkpoint = lo_checkpoint + 1;
if let Some(hi_occs) = self.occ.get(hi_checkpoint) {
let hi_occ = hi_occs[a as usize];

if let Some(&hi_occ) = self.occ[a as usize].get(hi_checkpoint) {
// Its possible that there are no occurences between the low and high
// checkpoint in which case we bail early.
if lo_occ == hi_occ {
Expand All @@ -158,7 +171,6 @@ impl Occ {
// If r is closer to the high checkpoint, count backwards from there.
let hi_idx = hi_checkpoint * self.k as usize;
if (hi_idx - r) < (self.k as usize / 2) {
let hi_occ = hi_occs[a as usize];
return hi_occ - bytecount::count(&bwt[r + 1..=hi_idx], a) as usize;
}
}
Expand Down Expand Up @@ -232,15 +244,19 @@ mod tests {
let bwt = vec![1u8, 3u8, 3u8, 1u8, 2u8, 0u8];
let alphabet = Alphabet::new(&[0u8, 1u8, 2u8, 3u8]);
let occ = Occ::new(&bwt, 3, &alphabet);
assert_eq!(occ.occ, [[0, 1, 0, 0], [0, 2, 0, 2]]);
assert_eq!(occ.occ, [[0, 0], [1, 2], [0, 0], [0, 2]]);
assert_eq!(occ.get(&bwt, 4, 2u8), 1);
assert_eq!(occ.get(&bwt, 4, 3u8), 2);
}

#[test]
fn test_occwm() {
let text = b"GCCTTAACATTATTACGCCTA$";
let alphabet = dna::n_alphabet();
let alphabet = {
let mut a = dna::n_alphabet();
a.insert(b'$');
a
};
let sa = suffix_array(text);
let bwt = bwt(text, &sa);
let occ = Occ::new(&bwt, 3, &alphabet);
Expand Down

0 comments on commit 9aa79cb

Please sign in to comment.