
Add Entropy/Mode methods to discrete distributions #1706

Open · stergiosba wants to merge 5 commits into master

Conversation

stergiosba (Author):

Solves #1696
Adds the following:

  • entropy method for: Bernoulli/Categorical/DiscreteUniform/Geometric
  • mode property for: Bernoulli/Categorical/Binomial/Poisson/Geometric
  • name property for all distributions, plus an explanatory message for NotImplementedError.

Current status: passes all local tests on CPU and GPU (tested with JAX 0.4.23 and CUDA 12.2).
Added dependencies: none.
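
For orientation, a minimal usage sketch of the new API (my illustration, not code from the diff; per the list above, entropy is a method and mode is a property):

import numpyro.distributions as dist

d = dist.Bernoulli(probs=0.3)
print(d.entropy())  # -0.3*log(0.3) - 0.7*log(0.7) ≈ 0.6109
print(d.mode)       # 0, since probs < 0.5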

stergiosba (Author):

@fehiepsi Hello!
Take a look whenever you are ready.
Thank you!

numpyro/distributions/discrete.py

# Could make this into a function if we need it elsewhere.
qlogq = jnp.where(probs0 == 0.0, 0.0, probs0 * log_probs0)
plogp = jnp.where(probs1 == 0.0, 0.0, probs1 * log_probs1)
fehiepsi (Member):

Why are those lines needed? When probs0 = 0, we also have probs0 * log_probs0 = 0. If we need to avoid inf/nan in log_probs0 and log_probs1, I would suggest using the formula in BernoulliLogits:

probs0 = 1 - self.probs
logits = self.logits
# entropy = -probs0*log(probs0) - probs1*log(probs1), rewritten in terms of logits
return probs0 * logits + softplus(-logits)
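
(A quick numerical check of this identity, written as a standalone sketch for this review rather than code from the PR: with q = 1 - p, the direct entropy -q * log(q) - p * log(p) matches q * logits + softplus(-logits).)

import jax.numpy as jnp
from jax.nn import softplus
from jax.scipy.special import expit

logits = jnp.array([-5.0, -1.0, 0.0, 1.0, 5.0])
p = expit(logits)
naive = -(1 - p) * jnp.log1p(-p) - p * jnp.log(p)  # direct -q*log(q) - p*log(p)
stable = (1 - p) * logits + softplus(-logits)      # logits-based rewrite
print(jnp.allclose(naive, stable))                 # True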

stergiosba (Author):

Ok, will test.

numpyro/distributions/discrete.py
numpyro/distributions/distribution.py
# Could make this into a function if we need it elsewhere.
qlogq = jnp.where(q == 0.0, 0.0, q * logq)
plogp = jnp.where(p == 0.0, 0.0, p * logp)
return (-qlogq - plogp) * 1.0 / p
fehiepsi (Member):

Note that plogp / p = logp, so you can return -(q / p) * logq - logp. The trickiest part is finding a robust formula for qlogq / p when p is near 0 and near 1.

When p is near 1, q is near 0, and qlogq = 0 as expected.

When p is near 0, q is near 1, and we need to maintain enough precision that logq / p ≈ -1 (since logq = log(1 - p) ≈ -p for small p). I'm not sure how to achieve that.

A simpler route is to rewrite the entropy as -(1/p - 1) * logq - logp = -logq / p - logits. We arrive at a simpler formula now, but we still face the same issue of approximating logq / p when p is near 0. Any ideas?
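
(As a standalone illustration of that precision issue, my own sketch rather than code from the thread: computing log(1 - p) naively loses the small-p behavior in float32, while log1p preserves logq / p ≈ -1.)

import jax.numpy as jnp

p = jnp.float32(1e-8)
q = 1.0 - p                # rounds to exactly 1.0 in float32
print(jnp.log(q) / p)      # 0.0, the -1 limit is lost entirely
print(jnp.log1p(-p) / p)   # -1.0, log1p keeps the small-p behavior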


def entropy(self):
    log_probs = jax.nn.log_softmax(self.logits)
    return -jnp.sum(mul_exp(log_probs, log_probs), axis=-1)
fehiepsi (Member):

Why not use log_probs * self.probs for simplicity?
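
(A possible reason, offered here as my own illustration rather than a reply from the thread: with a zero-probability category, log_probs * probs hits 0 * (-inf) = nan, which a guarded product avoids.)

import jax
import jax.numpy as jnp

logits = jnp.array([0.0, -jnp.inf])     # second category has probability 0
log_probs = jax.nn.log_softmax(logits)  # [0., -inf]
probs = jnp.exp(log_probs)              # [1., 0.]
print(log_probs * probs)                # [0., nan], since 0 * (-inf) = nan
print(jnp.where(probs == 0.0, 0.0, log_probs * probs))  # [0., 0.], safe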

fehiepsi (Member) commented on Jan 1, 2024:

Hi @stergiosba, it seems that using -qlogq / p - logp is pretty stable:

import jax
import jax.numpy as jnp

logits = jnp.array([-100., -80., -60., -40., -20., 0., 20., 40., 60.])
logq = -jax.nn.softplus(logits)   # log(1 - p), computed stably from the logits
logp = -jax.nn.softplus(-logits)  # log(p)
p = jax.scipy.special.expit(logits)
# in the class this would be probs = clamp_probs(self.probs)
p_clip = jnp.clip(p, a_min=jnp.finfo(p.dtype).tiny)  # guard against division by zero
-(1 - p) * logq / p_clip - logp

gives

Array([1.0000000e+02, 8.1000000e+01, 6.1000000e+01, 4.1000000e+01,
       2.1000000e+01, 1.3862944e+00, 2.0611537e-09, 4.2483541e-18,
       8.7565109e-27], dtype=float32)

while the actual values, computed in high precision with Decimal, are

from decimal import Decimal
for i in logits:
    l = Decimal(str(i))
    p = 1 / (1 + (-l).exp())
    print("entropy", -((1 - p) * (1 - p).ln() + p * p.ln()) / p)

entropy 1E+2
entropy 80.00000000000000000000000002
entropy 61.00496650303780216962340234
entropy 41.00000000000197983295132689
entropy 21.00000000103057681049624664
entropy 1.386294361119890618834464243
entropy 4.328422607333389151485388142E-8
entropy 1.741825244552915897262840401E-16
entropy 5.487531564015271267712575141E-25

fehiepsi (Member):
Hi @stergiosba, we will release numpyro 0.14 in a few days. Do you want to incorporate this feature into it?
