Add Entropy/Mode methods to discrete distributions #1706
Conversation
@fehiepsi Hello!
```python
# Could make this into a function if we need it elsewhere.
qlogq = jnp.where(probs0 == 0.0, 0.0, probs0 * log_probs0)
plogp = jnp.where(probs1 == 0.0, 0.0, probs1 * log_probs1)
```
Why are those lines needed? When `probs0 = 0`, we also have `probs0 * log_probs0 = 0`. If we need to avoid inf/nan in those `log_probs0` and `log_probs1`, I would suggest using the formula in BernoulliLogits:

```python
probs0 = 1 - self.probs
logits = self.logits
# Bernoulli entropy in terms of logits: H = (1 - p) * logits + softplus(-logits)
return probs0 * logits + softplus(-logits)
```
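For context, a minimal standalone sketch (illustrative values, not from the PR) comparing the naive probability-space formula against the logits-based one at extreme logits:

```python
import jax.numpy as jnp
from jax.nn import sigmoid, softplus

logits = jnp.array([-200.0, -1.0, 0.0, 1.0, 200.0])
p = sigmoid(logits)

# Naive formula: 0 * (-inf) -> nan once p rounds to exactly 0 or 1.
naive = -(p * jnp.log(p) + (1 - p) * jnp.log1p(-p))

# Logits-based formula: finite for all logits.
stable = (1 - p) * logits + softplus(-logits)

print(naive)   # nan at the extreme entries
print(stable)  # ~0 at the extremes, log(2) at logits = 0
```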
OK, will test.
```python
# Could make this into a function if we need it elsewhere.
qlogq = jnp.where(q == 0.0, 0.0, q * logq)
plogp = jnp.where(p == 0.0, 0.0, p * logp)
return (-qlogq - plogp) / p
```
Note that `plogp / p = logp`, so you can return `-(q / p) * logq - logp`. The trickiest part is to find a robust formula for `qlogq / p` when `p` is near 0 and near 1.

When `p` is near 1, `q` is near 0, and `qlogq = 0` as expected. When `p` is near 0, `q` is near 1, and we need to maintain precision such that `logq / p ≈ -1` (since `logq = log(1 - p) ≈ -p`). I'm not sure how to achieve that.

A simpler formula is to rewrite the entropy as `-(1/p - 1) * logq - logp = -logq / p - logits` (using `log(q) - log(p) = -logits`). We arrive at a simpler formula now, but we still have the same issue of approximating `logq / p` when `p` is near 0. Any idea from you?
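To make the precision concern concrete, here is a small sketch (illustrative values, not from the thread) showing how `logq / p` behaves as the logits become very negative:

```python
import jax.numpy as jnp
from jax.nn import softplus
from jax.scipy.special import expit

# For Geometric(p) with p = sigmoid(logits), logq / p should approach -1 as p -> 0.
logits = jnp.array([-5.0, -20.0, -60.0, -100.0], dtype=jnp.float32)
p = expit(logits)
logq = -softplus(logits)  # log(1 - p), computed stably from the logits
# Should print values near -1, but precision degrades once p and logq
# underflow toward the float32 subnormal range.
print(logq / p)
```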
```python
def entropy(self):
    log_probs = jax.nn.log_softmax(self.logits)
    return -jnp.sum(mul_exp(log_probs, log_probs), axis=-1)
```
Why not use `log_probs * self.probs` for simplicity?
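A minimal standalone sketch of that suggestion (illustrative logits, not the PR's code):

```python
import jax
import jax.numpy as jnp

logits = jnp.array([-200.0, 0.0, 3.0])
log_probs = jax.nn.log_softmax(logits)
probs = jnp.exp(log_probs)

# Where probs underflows to 0, log_probs stays finite,
# so the product is 0 rather than nan.
entropy = -jnp.sum(probs * log_probs, axis=-1)
print(entropy)
```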
Hi @stergiosba, it seems that using

```python
import jax
import jax.numpy as jnp

logits = jnp.array([-100.0, -80.0, -60.0, -40.0, -20.0, 0.0, 20.0, 40.0, 60.0])
logq = -jax.nn.softplus(logits)
logp = -jax.nn.softplus(-logits)
p = jax.scipy.special.expit(logits)
# probs = clamp_probs(self.probs)
p_clip = jnp.clip(p, a_min=jnp.finfo(p.dtype).tiny)
-(1 - p) * logq / p_clip - logp
```

gives
while the actual values are

```python
from decimal import Decimal

for i in logits:
    l = Decimal(str(i))
    p = 1 / (1 + (-l).exp())
    print("entropy", -((1 - p) * (1 - p).ln() + p * p.ln()) / p)
```
Hi @stergiosba, we will release numpyro 0.14 in a few days. Do you want to incorporate this feature into it?
Solves #1696

Adds the following:

- entropy method for: Bernoulli/Categorical/DiscreteUniform/Geometric
- mode property for: Bernoulli/Categorical/Binomial/Poisson/Geometric
- name property for all distributions, and an explanation for NotImplementedError

Current status: passes all local tests on CPU and GPU (tested on JAX 0.4.23 and CUDA 12.2). A usage sketch follows below.

Added dependencies: none
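A quick usage sketch (hypothetical example values; assumes the entropy/mode additions described above):

```python
import jax.numpy as jnp
import numpyro.distributions as dist

d = dist.Bernoulli(probs=jnp.array([0.1, 0.5, 0.9]))
print(d.entropy())  # elementwise entropy, maximal (log 2) at p = 0.5
print(d.mode)       # elementwise mode: 0 where p < 0.5, 1 where p > 0.5
                    # (behavior at exactly p = 0.5 is implementation-defined)
```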