You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Adding the Poisson quantile function would be useful. As a use case, JAXNS uses quantiles to reparametrise. Below is a bisection approach that is accurate for rates up to 1e4.
from functools import partial

import jax
import numpy as np
import pytest
from jax import numpy as jnp, vmap, lax

try:
    from jaxns.internals.types import int_type
except ImportError:
    # NOTE(review): fallback so the snippet can be imported without jaxns;
    # presumably equivalent to jaxns' int_type -- confirm before relying on it.
    int_type = jnp.result_type(int)


def _poisson_quantile_bisection(U, rate, max_iter=15, unroll: bool = True):
    """
    Compute the quantile of the Poisson distribution using bisection.

    The CDF is evaluated at real-valued ``x`` as ``lax.igammac(x + 1, rate)``.
    Starting from the bracket ``[0, rate]``, every iteration either doubles the
    upper end (while the CDF there is still below ``U``) or halves the bracket
    around the point where the CDF crosses ``U``.

    Args:
        U: the base measure; a scalar or an array-like of any shape. Every
            array-like with one or more dimensions is processed element-wise.
        rate: the rate of the Poisson distribution. Must hold exactly one
            value; it is clipped from below at 1e-5.
        max_iter: the number of iterations, shared between the doubling and
            the halving steps. By default accurate up to rate=1e4.
        unroll: whether to unroll the loop

    Returns:
        A tuple ``(x, x_results)``: ``x`` is the final bracket midpoint with
        the shape of ``U``; ``x_results`` holds the midpoint after every
        iteration and has shape ``U.shape + (max_iter,)``.

    Raises:
        ValueError: if ``rate`` does not hold exactly one value.
    """
    # max_iter is set so that error < 1 up to rate of 1e4
    rate = jnp.maximum(jnp.asarray(rate), 1e-5)
    if rate.size != 1:
        raise ValueError("Rate must be a scalar")
    if rate.ndim > 0:
        # A one-element array must become 0-d, otherwise the scan carry
        # would change shape between iterations.
        rate = rate.reshape(())

    U = jnp.asarray(U)
    if U.ndim > 0:
        # Dispatch on rank rather than size: a shape-(1,) or empty U must be
        # batched too, since the scan below only supports a 0-d U.
        U_flat = U.ravel()
        x_final, x_results = vmap(
            lambda u: _poisson_quantile_bisection(u, rate, max_iter, unroll)
        )(U_flat)
        return x_final.reshape(U.shape), x_results.reshape(U.shape + (max_iter,))

    def smooth_cdf(x):
        return lax.igammac(x + 1., rate)

    def fixed_point_update(carry, _):
        a, b, f_b = carry

        # Candidate 1: halve the bracket around the crossing point.
        c = 0.5 * (a + b)
        f_c = smooth_cdf(c)
        left = f_c > U
        a_bisect = jnp.where(left, a, c)
        b_bisect = jnp.where(left, c, b)
        f_b_bisect = jnp.where(left, f_c, f_b)

        # Candidate 2: the upper end does not bound U yet, so double it.
        # a already bounds.
        b_expand = b * 2.
        f_b_expand = smooth_cdf(b_expand)

        bounded = f_b >= U
        a = jnp.where(bounded, a_bisect, a)
        b = jnp.where(bounded, b_bisect, b_expand)
        f_b = jnp.where(bounded, f_b_bisect, f_b_expand)
        return (a, b, f_b), 0.5 * (a + b)

    b_init = jnp.asarray(rate)
    init = (jnp.asarray(0.), b_init, smooth_cdf(b_init))

    # Dummy array to facilitate using scan for a fixed number of iterations
    (a, b, _), x_results = lax.scan(
        fixed_point_update,
        init,
        jnp.arange(max_iter),
        unroll=max_iter if unroll else 1
    )
    c = 0.5 * (a + b)
    return c, x_results


@partial(jax.jit, static_argnames=("unroll",))
def _poisson_quantile(U, rate, unroll: bool = False):
    """
    Compute the quantile of the Poisson distribution using bisection.

    Runs ``_poisson_quantile_bisection`` with its default ``max_iter`` and
    casts the final midpoint to ``int_type``.

    Args:
        U: the base measure
        rate: the rate of the Poisson distribution
        unroll: whether to unroll the loop

    Returns:
        the quantile, cast to ``int_type``
    """
    x, _ = _poisson_quantile_bisection(U, rate, unroll=unroll)
    return x.astype(int_type)
@pytest.mark.parametrize(
    "rate, error",
    (
        (2.0, 1.),
        (10., 1.),
        (100., 1.),
        (1000., 1.),
        (10000., 1.),
    ),
)
def test_poisson_quantile_bisection(rate, error):
    """Check that the last two bisection iterates differ by at most ``error``.

    Runs the bisection over 1000 evenly spaced values of ``U`` starting at 0
    and ending at ``1 - np.spacing(1)``, for each parametrised rate.
    """
    upper = 1. - np.spacing(1.)
    U = jnp.linspace(0., upper, 1000)
    _, iterates = _poisson_quantile_bisection(U, rate, unroll=False)
    # Make sure less than 1 apart
    last_step = jnp.abs(iterates[..., -1] - iterates[..., -2])
    assert jnp.all(last_step <= error)
@pytest.mark.parametrize("rate", [2.0, 10., 100., 1000., 10000.])
def test_poisson_quantile(rate):
    """Check that the jitted quantile returns only finite values.

    Evaluates 10000 evenly spaced values of ``U`` starting at 0 and ending at
    ``1 - np.spacing(1)``, for each parametrised rate.
    """
    upper = 1. - np.spacing(1.)
    U = jnp.linspace(0., upper, 10000)
    quantiles = _poisson_quantile(U, rate)
    all_finite = jnp.all(jnp.isfinite(quantiles))
    assert all_finite
The text was updated successfully, but these errors were encountered:
Adding the Poisson quantile function would be useful. As a use case, JAXNS uses quantiles to reparametrise. Below is a bisection approach that is accurate for rates up to 1e4.
The text was updated successfully, but these errors were encountered: