From 00f9846ba5cb717b3e5392301029f9c46ecea527 Mon Sep 17 00:00:00 2001 From: Daniel Liu Date: Mon, 23 Aug 2021 07:11:40 -0700 Subject: [PATCH] fix: sampled suffix array (#447) * Fixed sampled suffix array sentinels bug and added more tests * Serde derive for sampled suffix array * Cargo fmt * Clarified comments * Allow Occ to be clonable * Make fm index accessible in sampled suffix array * Added unsafe fmd index contruction useful for special cases * Added random test cases --- src/data_structures/bwt.rs | 2 +- src/data_structures/fmindex.rs | 15 ++ src/data_structures/suffix_array.rs | 374 +++++++++++++++++----------- 3 files changed, 250 insertions(+), 141 deletions(-) diff --git a/src/data_structures/bwt.rs b/src/data_structures/bwt.rs index 6147d1b53..477ee3bbf 100644 --- a/src/data_structures/bwt.rs +++ b/src/data_structures/bwt.rs @@ -73,7 +73,7 @@ pub fn invert_bwt(bwt: &BWTSlice) -> Vec { } /// An occurrence array implementation. -#[derive(Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] pub struct Occ { occ: Vec>, k: u32, diff --git a/src/data_structures/fmindex.rs b/src/data_structures/fmindex.rs index d51b883f8..124f5a898 100644 --- a/src/data_structures/fmindex.rs +++ b/src/data_structures/fmindex.rs @@ -480,6 +480,21 @@ impl, DLess: Borrow, DOcc: Borrow> FMDIndex, + ) -> FMDIndex { + FMDIndex { fmindex } + } } #[cfg(test)] diff --git a/src/data_structures/suffix_array.rs b/src/data_structures/suffix_array.rs index b7aaabe26..886529d24 100644 --- a/src/data_structures/suffix_array.rs +++ b/src/data_structures/suffix_array.rs @@ -21,8 +21,11 @@ //! ); //! ``` +use std::borrow::Borrow; use std::cmp; +use std::collections::HashMap; use std::fmt::Debug; +use std::hash::BuildHasherDefault; use std::iter; use std::ops::Deref; @@ -32,76 +35,103 @@ use num_traits::{cast, NumCast, Unsigned}; use bv::{BitVec, Bits, BitsMut}; use vec_map::VecMap; +use fxhash::FxHasher; + +use serde::{Deserialize, Serialize}; + use crate::alphabets::{Alphabet, RankTransform}; +use crate::data_structures::bwt::{Less, Occ, BWT}; use crate::data_structures::smallints::SmallInts; pub type LCPArray = SmallInts; pub type RawSuffixArray = Vec; pub type RawSuffixArraySlice<'a> = &'a [usize]; +type HashMapFx = HashMap>; + /// A trait exposing general functionality of suffix arrays. pub trait SuffixArray { fn get(&self, index: usize) -> Option; fn len(&self) -> usize; fn is_empty(&self) -> bool; - // /// Sample the suffix array with the given sample rate. - // /// - // /// # Arguments - // /// - // /// * `bwt` - the corresponding BWT - // /// * `less` - the corresponding less array - // /// * `occ` - the corresponding occ table - // /// * `sampling_rate` - if sampling rate is k, every k-th entry will be kept - // /// - // /// # Example - // /// - // /// ``` - // /// use bio::data_structures::suffix_array::{suffix_array, SuffixArray}; - // /// use bio::data_structures::bwt::{bwt, less, Occ}; - // /// use bio::alphabets::dna; - // /// - // /// let text = b"ACGCGAT$"; - // /// let alphabet = dna::n_alphabet(); - // /// let sa = suffix_array(text); - // /// let bwt = bwt(text, &sa); - // /// let less = less(&bwt, &alphabet); - // /// let occ = Occ::new(&bwt, 3, &alphabet); - // /// let sampled = sa.sample(&bwt, &less, &occ, 1); - // /// - // /// for i in 0..sa.len() { - // /// assert_eq!(sa.get(i), sampled.get(i)); - // /// } - // /// ``` - // fn sample - // (&self, bwt: DBWT, less: DLess, occ: DOcc, sampling_rate: usize) -> - // SampledSuffixArray { - // - // let mut sample = Vec::with_capacity((self.len() as f32 / sampling_rate as f32).ceil() as usize); - // for i in 0..self.len() { - // if (i % sampling_rate) == 0 { - // sample.push(self.get(i).unwrap()); - // } - // } - // - // SampledSuffixArray { - // bwt: bwt, - // less: less, - // occ: occ, - // sample: sample, - // s: sampling_rate, - // } - // } + /// Sample the suffix array with the given sample rate. + /// + /// # Arguments + /// + /// * `text` - text that the suffix array is built on + /// * `bwt` - the corresponding BWT + /// * `less` - the corresponding less array + /// * `occ` - the corresponding occ table + /// * `sampling_rate` - if sampling rate is k, every k-th entry will be kept + /// + /// # Example + /// + /// ``` + /// use bio::alphabets::dna; + /// use bio::data_structures::bwt::{bwt, less, Occ}; + /// use bio::data_structures::suffix_array::{suffix_array, SuffixArray}; + /// + /// let text = b"ACGCGAT$"; + /// let alphabet = dna::n_alphabet(); + /// let sa = suffix_array(text); + /// let bwt = bwt(text, &sa); + /// let less = less(&bwt, &alphabet); + /// let occ = Occ::new(&bwt, 3, &alphabet); + /// let sampled = sa.sample(text, &bwt, &less, &occ, 2); + /// + /// for i in 0..sa.len() { + /// assert_eq!(sa.get(i), sampled.get(i)); + /// } + /// ``` + fn sample, DLess: Borrow, DOcc: Borrow>( + &self, + text: &[u8], + bwt: DBWT, + less: DLess, + occ: DOcc, + sampling_rate: usize, + ) -> SampledSuffixArray { + let mut sample = + Vec::with_capacity((self.len() as f32 / sampling_rate as f32).ceil() as usize); + let mut extra_rows = HashMapFx::default(); + let sentinel = sentinel(text); + + for i in 0..self.len() { + let idx = self.get(i).unwrap(); + if (i % sampling_rate) == 0 { + sample.push(idx); + } else if bwt.borrow()[i] == sentinel { + // If bwt lookup will return a sentinel + // Text suffixes that begin right after a sentinel are always saved as extra rows + // to help deal with FM index last to front inaccuracy when there are many sentinels + extra_rows.insert(i, idx); + } + } + + SampledSuffixArray { + bwt, + less, + occ, + sample, + s: sampling_rate, + extra_rows, + sentinel, + } + } } -// /// A sampled suffix array. -// pub struct SampledSuffixArray { -// bwt: DBWT, -// less: DLess, -// occ: DOcc, -// sample: Vec, -// s: usize, // Rate of sampling -// } +/// A sampled suffix array. +#[derive(Clone, Serialize, Deserialize)] +pub struct SampledSuffixArray, DLess: Borrow, DOcc: Borrow> { + bwt: DBWT, + less: DLess, + occ: DOcc, + sample: Vec, + s: usize, // Rate of sampling + extra_rows: HashMapFx, + sentinel: u8, +} impl SuffixArray for RawSuffixArray { fn get(&self, index: usize) -> Option { @@ -120,58 +150,72 @@ impl SuffixArray for RawSuffixArray { fn is_empty(&self) -> bool { Vec::is_empty(self) } +} - // fn sample - // (&self, bwt: DBWT, less: DLess, occ: DOcc, sampling_rate: usize) -> - // SampledSuffixArray { - // // Provide a specialized, faster implementation using iterators. - // - // let sample = self.iter().cloned().step(sampling_rate).collect(); - // - // SampledSuffixArray { - // bwt: bwt, - // less: less, - // occ: occ, - // sample: sample, - // s: sampling_rate, - // } - // } +impl, DLess: Borrow, DOcc: Borrow> SuffixArray + for SampledSuffixArray +{ + fn get(&self, index: usize) -> Option { + if index < self.len() { + let mut pos = index; + let mut offset = 0; + loop { + if pos % self.s == 0 { + return Some(self.sample[pos / self.s] + offset); + } + + let c = self.bwt.borrow()[pos]; + + if c == self.sentinel { + // Check if next character in the bwt is the sentinel + // If so, there must be a cached result to workaround FM index last to front + // mapping inaccuracy when there are multiple sentinels + // This branch should rarely be triggered so the performance impact + // of hashmap lookups would be low + return Some(self.extra_rows[&pos] + offset); + } + + pos = self.less.borrow()[c as usize] + + self.occ.borrow().get(self.bwt.borrow(), pos - 1, c); + offset += 1; + } + } else { + None + } + } + + fn len(&self) -> usize { + self.bwt.borrow().len() + } + + fn is_empty(&self) -> bool { + self.bwt.borrow().is_empty() + } } -// impl SuffixArray for SampledSuffixArray { -// fn get(&self, index: usize) -> Option { -// if index < self.len() { -// let mut pos = index; -// let mut offset = 0; -// loop { -// if pos % self.s == 0 { -// return Some(self.sample[pos / self.s] + offset); -// } -// -// let c = self.bwt[pos]; -// pos = self.less[c as usize] + self.occ.get(&self.bwt, pos - 1, c); -// offset += 1; -// } -// } else { -// None -// } -// } -// -// fn len(&self) -> usize { -// self.bwt.len() -// } - -// fn is_empty(&self) -> bool { -// self.bwt.is_empty() -// } -// } -// -// -// impl SampledSuffixArray { -// pub fn sampling_rate(&self) -> usize { -// self.s -// } -// } +impl, DLess: Borrow, DOcc: Borrow> + SampledSuffixArray +{ + /// Get the sampling rate of the suffix array. + pub fn sampling_rate(&self) -> usize { + self.s + } + + /// Get a reference to the internal BWT. + pub fn bwt(&self) -> &BWT { + self.bwt.borrow() + } + + /// Get a reference to the internal Less. + pub fn less(&self) -> &Less { + self.less.borrow() + } + + /// Get a reference to the internal Occ. + pub fn occ(&self) -> &Occ { + self.occ.borrow() + } +} /// Construct suffix array for given text of length n. /// Complexity: O(n). @@ -679,13 +723,13 @@ impl PosTypes { #[cfg(test)] mod tests { - // Commented-out imports waiting on re-enabling of sampled suffix array - // See issue #70 use super::*; use super::{transform_text, PosTypes, SAIS}; - use crate::alphabets::Alphabet; + use crate::alphabets::{dna, Alphabet}; + use crate::data_structures::bwt::{bwt, less}; use bv::{BitVec, BitsPush}; - //use data_structures::bwt::{bwt, less, Occ}; + use rand; + use rand::prelude::*; use std::str; #[test] @@ -778,9 +822,27 @@ mod tests { ) + "$" } + fn rand_seqs(num_seqs: usize, seq_len: usize) -> Vec { + let mut rng = rand::thread_rng(); + let alpha = [b'A', b'T', b'C', b'G', b'N']; + let seqs = (0..num_seqs) + .into_iter() + .map(|_| { + let len = rng.gen_range((seq_len / 2)..=seq_len); + (0..len) + .into_iter() + .map(|_| *alpha.choose(&mut rng).unwrap()) + .collect::>() + }) + .collect::>(); + let mut res = seqs.join(&b'$'); + res.push(b'$'); + res + } + #[test] fn test_sorts_lexically() { - let test_cases = [(&b"A$C$G$T$"[..], "simple"), + let mut test_cases = vec![(&b"A$C$G$T$"[..], "simple"), (&b"A$A$T$T$"[..], "duplicates"), (&b"AA$GA$CA$TA$TC$TG$GT$GC$"[..], "two letter"), (&b"AGCCAT$\ @@ -795,6 +857,14 @@ mod tests { TATTGGAATAGCCATTCGACGCTGACCTCTTGAGGTTCCATTACCCGGCTACTGATGCTAAAATCCTGGCAGCCCGAGCAATACGAAATGTCCGCTGATT$"[..], "complex with revcomps"), ]; + let num_rand = 100; + let rand_cases = (0..num_rand) + .into_iter() + .map(|i| rand_seqs(10, i * 10)) + .collect::>(); + for i in 0..num_rand { + test_cases.push((&rand_cases[i], "rand test case")); + } for &(text, test_name) in test_cases.iter() { let pos = suffix_array(text); @@ -817,35 +887,59 @@ mod tests { } } - // #[test] - // fn test_sampled_matches() { - // let test_cases = [(&b"A$C$G$T$"[..], "simple"), - // (&b"A$A$T$T$"[..], "duplicates"), - // (&b"AA$GA$CA$TA$TC$TG$GT$GC$"[..], "two letter"), - // (&b"AGCCAT$\ - // CAGCC$"[..], - // "substring"), - // (&b"GTAGGCCTAATTATAATCAGCGGACATTTCGTATTGCTCGGGCTGCCAGGATTTTAGCATCAGTAGCCGGGTAATGGAACCTCAAGAGGTCAGCGTCGAA$\ - // AATCAGCGGACATTTCGTATTGCTCGGGCTGCCAGGATTTTAGCATCAGTAGCCGGGTAATGGAACCTCAAGAGGTCAGCGTCGAATGGCTATTCCAATA$"[..], - // "complex"), - // (&b"GTAGGCCTAATTATAATCAGCGGACATTTCGTATTGCTCGGGCTGCCAGGATTTTAGCATCAGTAGCCGGGTAATGGAACCTCAAGAGGTCAGCGTCGAA$\ - // TTCGACGCTGACCTCTTGAGGTTCCATTACCCGGCTACTGATGCTAAAATCCTGGCAGCCCGAGCAATACGAAATGTCCGCTGATTATAATTAGGCCTAC$\ - // AATCAGCGGACATTTCGTATTGCTCGGGCTGCCAGGATTTTAGCATCAGTAGCCGGGTAATGGAACCTCAAGAGGTCAGCGTCGAATGGCTATTCCAATA$\ - // TATTGGAATAGCCATTCGACGCTGACCTCTTGAGGTTCCATTACCCGGCTACTGATGCTAAAATCCTGGCAGCCCGAGCAATACGAAATGTCCGCTGATT$"[..], - // "complex with revcomps"), - // ]; - // - // for &(text, _) in test_cases.into_iter() { - // let alphabet = dna::n_alphabet(); - // let sa = suffix_array(text); - // let bwt = bwt(text, &sa); - // let less = less(&bwt, &alphabet); - // let occ = Occ::new(&bwt, 3, &alphabet); - // let sampled = sa.sample(&bwt, &less, &occ, 2); - // - // for i in 0..sa.len() { - // assert_eq!(sa.get(i), sampled.get(i)); - // } - // } - // } + #[test] + fn test_sampled_matches() { + let mut test_cases = vec![(&b"A$C$G$T$"[..], "simple"), + (&b"A$A$T$T$"[..], "duplicates"), + (&b"AA$GA$CA$TA$TC$TG$GT$GC$"[..], "two letter"), + (&b"AGCCAT$\ + CAGCC$"[..], + "substring"), + (&b"GTAGGCCTAATTATAATCAGCGGACATTTCGTATTGCTCGGGCTGCCAGGATTTTAGCATCAGTAGCCGGGTAATGGAACCTCAAGAGGTCAGCGTCGAA$\ + AATCAGCGGACATTTCGTATTGCTCGGGCTGCCAGGATTTTAGCATCAGTAGCCGGGTAATGGAACCTCAAGAGGTCAGCGTCGAATGGCTATTCCAATA$"[..], + "complex"), + (&b"GTAGGCCTAATTATAATCAGCGGACATTTCGTATTGCTCGGGCTGCCAGGATTTTAGCATCAGTAGCCGGGTAATGGAACCTCAAGAGGTCAGCGTCGAA$\ + TTCGACGCTGACCTCTTGAGGTTCCATTACCCGGCTACTGATGCTAAAATCCTGGCAGCCCGAGCAATACGAAATGTCCGCTGATTATAATTAGGCCTAC$\ + AATCAGCGGACATTTCGTATTGCTCGGGCTGCCAGGATTTTAGCATCAGTAGCCGGGTAATGGAACCTCAAGAGGTCAGCGTCGAATGGCTATTCCAATA$\ + TATTGGAATAGCCATTCGACGCTGACCTCTTGAGGTTCCATTACCCGGCTACTGATGCTAAAATCCTGGCAGCCCGAGCAATACGAAATGTCCGCTGATT$"[..], + "complex with revcomps"), + (&b"GTAG$GCCTAAT$TATAATCAG$"[..], "issue70"), + (&b"TGTGTGTGTG$"[..], "repeating"), + (&b"TACTCCGCTAGGGACACCTAAATAGATACTCGCAAAGGCGACTGATATATCCTTAGGTCGAAGAGATACCAGAGAAATAGTAGGTCTTAGGCTAGTCCTT$AAGGACTAGCCTAAGACCTACTATTTCTCTGGTATCTCTTCGACCTAAGGATATATCAGTCGCCTTTGCGAGTATCTATTTAGGTGTCCCTAGCGGAGTA$TAGGGACACCTAAATAGATACTCGCAAAGGCGACTGATATATCCTTAGGTCGAAGAGATACCAGAGAAATAGTAGGTCTTAGGCTAGTCCTTGTCCAGTA$TACTGGACAAGGACTAGCCTAAGACCTACTATTTCTCTGGTATCTCTTCGACCTAAGGATATATCAGTCGCCTTTGCGAGTATCTATTTAGGTGTCCCTA$ACGCACCCCGGCATTCGTCGACTCTACACTTAGTGGAACATACAAATTCGCTCGCAGGAGCGCCTCATACATTCTAACGCAGTGATCTTCGGCTGAGACT$AGTCTCAGCCGAAGATCACTGCGTTAGAATGTATGAGGCGCTCCTGCGAGCGAATTTGTATGTTCCACTAAGTGTAGAGTCGACGAATGCCGGGGTGCGT$"[..], "complex sentinels"), + ]; + let num_rand = 100; + let rand_cases = (0..num_rand) + .into_iter() + .map(|i| rand_seqs(10, i * 10)) + .collect::>(); + for i in 0..num_rand { + test_cases.push((&rand_cases[i], "rand test case")); + } + + for &(text, test_name) in test_cases.iter() { + for &sample_rate in &[2, 3, 5, 16] { + let alphabet = dna::n_alphabet(); + let sa = suffix_array(text); + let bwt = bwt(text, &sa); + let less = less(&bwt, &alphabet); + let occ = Occ::new(&bwt, 3, &alphabet); + let sampled = sa.sample(text, &bwt, &less, &occ, sample_rate); + + for i in 0..sa.len() { + let sa_idx = sa.get(i).unwrap(); + let sampled_idx = sampled.get(i).unwrap(); + assert_eq!( + sa_idx, + sampled_idx, + "Failed:\n{}\n{}\nat index {} do not match in test: {} (sample rate: {})", + str::from_utf8(&text[sa_idx..]).unwrap(), + str::from_utf8(&text[sampled_idx..]).unwrap(), + i, + test_name, + sample_rate + ); + } + } + } + } }