Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adaptive integration of density functions using a binary search approach that tries to achieve good resolution around the maximum #486

Merged
merged 4 commits into from Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
146 changes: 146 additions & 0 deletions src/stats/probs/adaptive_integration.rs
@@ -0,0 +1,146 @@
// Copyright 2021-2022 Johannes Köster.
// Licensed under the MIT license (http://opensource.org/licenses/MIT)
// This file may not be copied, modified, or distributed
// except according to those terms.

use std::cmp;
use std::collections::HashMap;
use std::convert::Into;
use std::hash::Hash;
use std::{
fmt::Debug,
ops::{Add, Div, Mul, Sub},
};

use crate::stats::probs::LogProb;
use itertools::Itertools;
use itertools_num::linspace;
use ordered_float::NotNan;

/// Integrate over an interval of type T with a given density function while trying to minimize
/// the number of grid points evaluated and still hit the maximum likelihood point.
/// This is achieved via a binary search over the grid points.
/// The assumption is that the density is unimodal. If that is not the case,
/// the binary search will not find the maximum and the integral can miss features.
///
/// # Example
///
/// ```rust
/// use bio::stats::probs::adaptive_integration::ln_integrate_exp;
/// use bio::stats::probs::{Prob, LogProb};
/// use statrs::distribution::{Normal, Continuous};
/// use statrs::statistics::Distribution;
/// use ordered_float::NotNan;
/// use approx::abs_diff_eq;
///
/// let ndist = Normal::new(0.0, 1.0).unwrap();
///
/// let integral = ln_integrate_exp(
/// |x| LogProb::from(Prob(ndist.pdf(*x))),
/// NotNan::new(-1.0).unwrap(),
/// NotNan::new(1.0).unwrap(),
/// NotNan::new(0.01).unwrap()
/// );
/// abs_diff_eq!(integral.exp(), 0.682, epsilon=0.01);
/// ```
pub fn ln_integrate_exp<T, F>(
mut density: F,
min_point: T,
max_point: T,
max_resolution: T,
) -> LogProb
where
T: Copy
+ Add<Output = T>
+ Sub<Output = T>
+ Div<Output = T>
+ Div<NotNan<f64>, Output = T>
+ Mul<Output = T>
+ Into<f64>
+ From<f64>
+ Ord
+ Debug
+ Hash,
F: FnMut(T) -> LogProb,
f64: From<T>,
{
let mut probs = HashMap::new();

let mut grid_point = |point, probs: &mut HashMap<_, _>| {
probs.insert(point, density(point));
point
};
let middle_grid_point = |left: T, right: T| (right + left) / NotNan::new(2.0).unwrap();
// METHOD:
// Step 1: perform binary search for maximum likelihood point
// Remember all points.
let mut left = grid_point(min_point, &mut probs);
let mut right = grid_point(max_point, &mut probs);
let mut first_middle = None;
let mut middle = None;

while (((right - left) >= max_resolution) && left < right) || middle.is_none() {
middle = Some(grid_point(middle_grid_point(left, right), &mut probs));

if first_middle.is_none() {
first_middle = middle;
}

let left_prob = probs.get(&left).unwrap();
let right_prob = probs.get(&right).unwrap();

if left_prob > right_prob {
// investigate left window more closely
right = middle.unwrap();
} else {
// investigate right window more closely
left = middle.unwrap();
}
}
// METHOD: add additional grid point in the initially abandoned arm
if middle < first_middle {
grid_point(
middle_grid_point(first_middle.unwrap(), max_point),
&mut probs,
);
} else {
grid_point(
middle_grid_point(min_point, first_middle.unwrap()),
&mut probs,
);
}
// METHOD additionally investigate small interval around the optimum
for point in linspace(
cmp::max(
middle.unwrap() - (max_resolution.into() * 3.0).into(),
min_point,
)
.into(),
middle.unwrap().into(),
4,
)
.take(3)
.chain(
linspace(
middle.unwrap().into(),
cmp::min(
middle.unwrap() + (max_resolution.into() * 3.0).into(),
max_point,
)
.into(),
4,
)
.skip(1),
) {
grid_point(point.into(), &mut probs);
}

let sorted_grid_points: Vec<f64> = probs.keys().sorted().map(|point| (*point).into()).collect();

// METHOD:
// Step 2: integrate over grid points visited during the binary search.
LogProb::ln_trapezoidal_integrate_grid_exp::<f64, _>(
|_, g| *probs.get(&T::from(g)).unwrap(),
&sorted_grid_points,
)
}
1 change: 1 addition & 0 deletions src/stats/probs/mod.rs
Expand Up @@ -6,6 +6,7 @@
//! Handling log-probabilities. Log probabilities are an important tool to deal with probabilities
//! in a numerically stable way, in particular when having probabilities close to zero.

pub mod adaptive_integration;
pub mod cdf;
pub mod errors;

Expand Down