Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Poisson quantile #1791

Open
Joshuaalbert opened this issue Feb 22, 2024 · 0 comments
Open

Add Poisson quantile #1791

Joshuaalbert opened this issue Feb 22, 2024 · 0 comments

Comments

@Joshuaalbert
Copy link
Contributor

Adding the Poisson quantile function would be useful. As a use case, JAXNS uses quantile functions to reparametrise its priors. Below is a bisection-based approach that is accurate (error < 1) for rates up to 1e4.

from functools import partial

import jax
import numpy as np
import pytest
from jax import numpy as jnp, vmap, lax

from jaxns.internals.types import int_type


def _poisson_quantile_bisection(U, rate, max_iter=15, unroll: bool = True):
    """
    Compute the quantile of the Poisson distribution using bisection.

    The Poisson CDF is evaluated at real-valued arguments through
    ``lax.igammac(x + 1, rate)``. Starting from the bracket ``[0, rate]``, each
    iteration either doubles the upper end (while the CDF there is still below
    ``U``) or bisects the bracket. The number of iterations is fixed, so the
    routine can be traced and vmapped.

    Args:
        U: the base measure. May be a scalar, an array of any shape (including
            size-1 arrays), or a nested Python sequence; non-scalar inputs are
            handled element-wise.
        rate: the rate of the Poisson distribution. Must contain exactly one
            element; it is clamped from below at 1e-5.
        max_iter: the number of iterations. By default accurate up to rate=1e4.
        unroll: whether to unroll the loop

    Returns:
        A tuple ``(x, x_results)`` where ``x`` is the final bracket midpoint
        (real-valued, same shape as ``U``) and ``x_results`` holds the bracket
        midpoint after every iteration, with shape ``U.shape + (max_iter,)``.

    Raises:
        ValueError: if ``rate`` does not contain exactly one element.
    """
    # max_iter is set so that error < 1 up to rate of 1e4
    rate = jnp.maximum(jnp.asarray(rate), 1e-5)
    if np.size(rate) != 1:
        raise ValueError("Rate must be a scalar")
    # Collapse size-1 arrays such as shape (1,) to a true scalar so that the
    # scan carry below keeps a fixed shape of ().
    rate = jnp.reshape(rate, ())

    # Dispatch on rank rather than size: a size-1 array with ndim > 0 must also
    # go through vmap, otherwise broadcasting against U would change the carry
    # shape inside the scan.
    U = jnp.asarray(U)
    if np.ndim(U) > 0:
        U_flat = U.ravel()
        x_final, x_results = vmap(
            lambda u: _poisson_quantile_bisection(u, rate, max_iter, unroll)
        )(U_flat)
        return x_final.reshape(U.shape), x_results.reshape(U.shape + (max_iter,))

    def smooth_cdf(x, rate):
        # Poisson CDF evaluated at real-valued x.
        return lax.igammac(x + 1., rate)

    def fixed_point_update(x, _):
        (a, b, f_a, f_b) = x

        # Candidate 1: bisect the current bracket.
        c = 0.5 * (a + b)
        f_c = smooth_cdf(c, rate)

        left = f_c > U
        a1 = jnp.where(left, a, c)
        f_a1 = jnp.where(left, f_a, f_c)
        b1 = jnp.where(left, c, b)
        f_b1 = jnp.where(left, f_c, f_b)

        # Candidate 2: the upper end does not yet bound U, so double it.
        a2 = a
        f_a2 = f_a
        b2 = b * 2.
        f_b2 = smooth_cdf(b2, rate)

        bounded = f_b >= U  # a already bounds.

        a = jnp.where(bounded, a1, a2)
        b = jnp.where(bounded, b1, b2)
        f_a = jnp.where(bounded, f_a1, f_a2)
        f_b = jnp.where(bounded, f_b1, f_b2)

        new_x = (a, b, f_a, f_b)

        return new_x, 0.5 * (a + b)

    # Build the initial bracket with the same dtype as `rate` so the carry
    # types entering and leaving the scan agree.
    b = rate
    a = jnp.zeros_like(b)
    f_b = smooth_cdf(b, rate)
    f_a = jnp.zeros_like(f_b)
    init = (a, b, f_a, f_b)

    # The body ignores the scanned input, so scan over `length` directly.
    (a, b, f_a, f_b), x_results = lax.scan(
        fixed_point_update,
        init,
        xs=None,
        length=max_iter,
        unroll=max_iter if unroll else 1
    )

    c = 0.5 * (a + b)

    return c, x_results


@partial(jax.jit, static_argnames=("unroll",))
def _poisson_quantile(U, rate, unroll: bool = False):
    """
    Jitted Poisson quantile built on the bisection solver.

    Runs ``_poisson_quantile_bisection`` with its default iteration count,
    keeps only the final estimate and casts it to ``int_type``.

    Args:
        U: the base measure
        rate: the rate of the Poisson distribution
        unroll: whether to unroll the loop (static under jit)

    Returns:
        the quantile, cast to ``int_type``
    """
    # Index 0 of the solver output is the final estimate; index 1 (the
    # per-iteration history) is not needed here.
    solution = _poisson_quantile_bisection(U, rate, unroll=unroll)
    return solution[0].astype(int_type)


@pytest.mark.parametrize(
    "rate, error",
    [
        (2.0, 1.),
        (10., 1.),
        (100., 1.),
        (1000., 1.),
        (10000., 1.),
    ],
)
def test_poisson_quantile_bisection(rate, error):
    """Check that the last two bisection estimates differ by at most ``error``."""
    just_below_one = 1. - np.spacing(1.)
    base_measure = jnp.linspace(0., just_below_one, 1000)
    _, estimates = _poisson_quantile_bisection(base_measure, rate, unroll=False)

    # Distance between the final two estimates for every element of the grid.
    final_step = jnp.abs(estimates[..., -1] - estimates[..., -2])

    # Make sure less than 1 apart
    assert jnp.all(final_step <= error)


@pytest.mark.parametrize("rate", [2.0, 10., 100., 1000., 10000.])
def test_poisson_quantile(rate):
    """Check that the jitted quantile yields only finite values on a grid of U."""
    just_below_one = 1. - np.spacing(1.)
    base_measure = jnp.linspace(0., just_below_one, 10000)
    quantiles = _poisson_quantile(base_measure, rate)
    # NOTE(review): the result is cast to int_type; if that is an integer
    # dtype this check cannot fail -- confirm whether a stronger assertion
    # is wanted.
    assert jnp.all(jnp.isfinite(quantiles))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant