diff --git a/src/data_structures/bwt.rs b/src/data_structures/bwt.rs index 477ee3bbf..e67aa0513 100644 --- a/src/data_structures/bwt.rs +++ b/src/data_structures/bwt.rs @@ -83,8 +83,8 @@ impl Occ { /// Calculate occ array with sampling from BWT of length n. /// Time complexity: O(n). /// Space complexity: O(n / k * A) with A being the alphabet size. - /// Alphabet size is determined on the fly from the BWT. - /// For large texts, it is therefore advisable to transform + /// The specified alphabet must match the alphabet of the text and its BWT. + /// For large texts, it is advisable to transform /// the text before calculating the BWT (see alphabets::rank_transform). /// /// # Arguments @@ -97,12 +97,27 @@ impl Occ { .max_symbol() .expect("Expecting non-empty alphabet.") as usize + 1; - let mut occ = Vec::with_capacity(n / k as usize); - let mut curr_occ: Vec = repeat(0).take(m).collect(); + let mut alpha = alphabet.symbols.iter().collect::>(); + // include sentinel '$' + if (b'$' as usize) < m && !alphabet.is_word(b"$") { + alpha.push(b'$' as usize); + } + let mut occ: Vec> = vec![Vec::new(); m]; + let mut curr_occ = vec![0usize; m]; + + // characters not in the alphabet won't take up much space + for &a in &alpha { + occ[a].reserve(n / k as usize); + } + for (i, &c) in bwt.iter().enumerate() { curr_occ[c as usize] += 1; + if i % k as usize == 0 { - occ.push(curr_occ.clone()); + // only visit characters in the alphabet + for &a in &alpha { + occ[a].push(curr_occ[a]); + } } } @@ -122,7 +137,7 @@ impl Occ { // .iter() // .filter(|&&c| c == a) // .count(); - // self.occ[i][a as usize] + count + // self.occ[a as usize][i] + count // ``` // // But there are a couple of reasons to do this manually: @@ -140,15 +155,13 @@ impl Occ { // self.k is our sampling rate, so find the checkpoints either side of r. let lo_checkpoint = r / self.k as usize; // Get the occurences at the low checkpoint - let lo_occ = self.occ[lo_checkpoint][a as usize]; + let lo_occ = self.occ[a as usize][lo_checkpoint]; // If the sampling rate is infrequent it is worth checking if there is a closer // hi checkpoint. if self.k > 64 { let hi_checkpoint = lo_checkpoint + 1; - if let Some(hi_occs) = self.occ.get(hi_checkpoint) { - let hi_occ = hi_occs[a as usize]; - + if let Some(&hi_occ) = self.occ[a as usize].get(hi_checkpoint) { // Its possible that there are no occurences between the low and high // checkpoint in which case we bail early. if lo_occ == hi_occ { @@ -158,7 +171,6 @@ impl Occ { // If r is closer to the high checkpoint, count backwards from there. let hi_idx = hi_checkpoint * self.k as usize; if (hi_idx - r) < (self.k as usize / 2) { - let hi_occ = hi_occs[a as usize]; return hi_occ - bytecount::count(&bwt[r + 1..=hi_idx], a) as usize; } } @@ -232,7 +244,7 @@ mod tests { let bwt = vec![1u8, 3u8, 3u8, 1u8, 2u8, 0u8]; let alphabet = Alphabet::new(&[0u8, 1u8, 2u8, 3u8]); let occ = Occ::new(&bwt, 3, &alphabet); - assert_eq!(occ.occ, [[0, 1, 0, 0], [0, 2, 0, 2]]); + assert_eq!(occ.occ, [[0, 0], [1, 2], [0, 0], [0, 2]]); assert_eq!(occ.get(&bwt, 4, 2u8), 1); assert_eq!(occ.get(&bwt, 4, 3u8), 2); } @@ -240,7 +252,11 @@ mod tests { #[test] fn test_occwm() { let text = b"GCCTTAACATTATTACGCCTA$"; - let alphabet = dna::n_alphabet(); + let alphabet = { + let mut a = dna::n_alphabet(); + a.insert(b'$'); + a + }; let sa = suffix_array(text); let bwt = bwt(text, &sa); let occ = Occ::new(&bwt, 3, &alphabet);