diff --git a/src/stats/probs/adaptive_integration.rs b/src/stats/probs/adaptive_integration.rs
new file mode 100644
index 000000000..4555d8617
--- /dev/null
+++ b/src/stats/probs/adaptive_integration.rs
@@ -0,0 +1,156 @@
+// Copyright 2021-2022 Johannes Köster.
+// Licensed under the MIT license (http://opensource.org/licenses/MIT)
+// This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use std::cmp;
+use std::collections::HashMap;
+use std::convert::Into;
+use std::hash::Hash;
+use std::{
+    fmt::Debug,
+    ops::{Add, Div, Mul, Sub},
+};
+
+use crate::stats::probs::LogProb;
+use itertools::Itertools;
+use itertools_num::linspace;
+use ordered_float::NotNan;
+
+/// Integrate over an interval of type T with a given density function while trying to minimize
+/// the number of grid points evaluated and still hit the maximum likelihood point.
+/// This is achieved via a binary search over the grid points.
+/// The assumption is that the density is unimodal. If that is not the case,
+/// the binary search will not find the maximum and the integral can miss features.
+///
+/// The density is evaluated at `min_point` and `max_point`, at every midpoint visited by the
+/// binary search (which stops once the search window is narrower than `max_resolution`), at one
+/// additional point in the half of the interval abandoned by the first search step, and at up to
+/// three points within `3 * max_resolution` on either side of the final midpoint (clamped to the
+/// interval). Every distinct grid point is evaluated only once. The integral over the visited
+/// points is then computed with `LogProb::ln_trapezoidal_integrate_grid_exp`.
+///
+/// # Example
+///
+/// ```rust
+/// use approx::abs_diff_eq;
+/// use bio::stats::probs::adaptive_integration::ln_integrate_exp;
+/// use bio::stats::probs::{LogProb, Prob};
+/// use ordered_float::NotNan;
+/// use statrs::distribution::{Continuous, Normal};
+///
+/// let ndist = Normal::new(0.0, 1.0).unwrap();
+///
+/// let integral = ln_integrate_exp(
+///     |x| LogProb::from(Prob(ndist.pdf(*x))),
+///     NotNan::new(-1.0).unwrap(),
+///     NotNan::new(1.0).unwrap(),
+///     NotNan::new(0.01).unwrap(),
+/// );
+/// assert!(abs_diff_eq!(integral.exp(), 0.682, epsilon = 0.01));
+/// ```
+pub fn ln_integrate_exp<T, F>(
+    mut density: F,
+    min_point: T,
+    max_point: T,
+    max_resolution: T,
+) -> LogProb
+where
+    T: Copy
+        + Add<Output = T>
+        + Sub<Output = T>
+        + Div<Output = T>
+        + Div<NotNan<f64>, Output = T>
+        + Mul<Output = T>
+        + Into<f64>
+        + From<f64>
+        + Ord
+        + Debug
+        + Hash,
+    F: FnMut(T) -> LogProb,
+    f64: From<T>,
+{
+    let mut probs = HashMap::new();
+
+    // Evaluate the density at `point` unless it has been visited before, and remember the
+    // result for the final integration step.
+    let mut grid_point = |point: T, probs: &mut HashMap<T, LogProb>| {
+        probs.entry(point).or_insert_with(|| density(point));
+        point
+    };
+    let middle_grid_point = |left: T, right: T| (right + left) / NotNan::new(2.0).unwrap();
+    // METHOD:
+    // Step 1: perform binary search for maximum likelihood point
+    // Remember all points.
+    let mut left = grid_point(min_point, &mut probs);
+    let mut right = grid_point(max_point, &mut probs);
+    let mut first_middle = None;
+    let mut middle = None;
+
+    // `middle.is_none()` forces at least one iteration, so `middle` and `first_middle` are
+    // always `Some` once the loop has finished.
+    while (((right - left) >= max_resolution) && left < right) || middle.is_none() {
+        middle = Some(grid_point(middle_grid_point(left, right), &mut probs));
+
+        if first_middle.is_none() {
+            first_middle = middle;
+        }
+
+        let left_prob = probs.get(&left).unwrap();
+        let right_prob = probs.get(&right).unwrap();
+
+        if left_prob > right_prob {
+            // investigate left window more closely
+            right = middle.unwrap();
+        } else {
+            // investigate right window more closely
+            left = middle.unwrap();
+        }
+    }
+    // METHOD: add additional grid point in the initially abandoned arm
+    if middle < first_middle {
+        grid_point(
+            middle_grid_point(first_middle.unwrap(), max_point),
+            &mut probs,
+        );
+    } else {
+        grid_point(
+            middle_grid_point(min_point, first_middle.unwrap()),
+            &mut probs,
+        );
+    }
+    // METHOD: additionally investigate small interval around the optimum
+    for point in linspace(
+        cmp::max(
+            middle.unwrap() - (max_resolution.into() * 3.0).into(),
+            min_point,
+        )
+        .into(),
+        middle.unwrap().into(),
+        4,
+    )
+    .take(3)
+    .chain(
+        linspace(
+            middle.unwrap().into(),
+            cmp::min(
+                middle.unwrap() + (max_resolution.into() * 3.0).into(),
+                max_point,
+            )
+            .into(),
+            4,
+        )
+        .skip(1),
+    ) {
+        grid_point(point.into(), &mut probs);
+    }
+
+    let sorted_grid_points: Vec<f64> = probs.keys().sorted().map(|point| (*point).into()).collect();
+
+    // METHOD:
+    // Step 2: integrate over grid points visited during the binary search.
+    LogProb::ln_trapezoidal_integrate_grid_exp::<f64, _>(
+        |_, g| *probs.get(&T::from(g)).unwrap(),
+        &sorted_grid_points,
+    )
+}
diff --git a/src/stats/probs/mod.rs b/src/stats/probs/mod.rs
index 851d4f8fe..1895c6efc 100644
--- a/src/stats/probs/mod.rs
+++ b/src/stats/probs/mod.rs
@@ -6,6 +6,7 @@
 //! Handling log-probabilities. Log probabilities are an important tool to deal with probabilities
 //! in a numerically stable way, in particular when having probabilities close to zero.
 
+pub mod adaptive_integration;
 pub mod cdf;
 pub mod errors;
 