
Distributions Entropy Method #1696

Open · stergiosba opened this issue Dec 12, 2023 · 15 comments
Labels: enhancement (New feature or request)

@stergiosba

Hello guys, I come from the TensorFlow Distributions world, was looking for a lightweight alternative, and was pleasantly surprised to see that Pyro is available for JAX via your amazing work.

I have implemented the PPO algorithm for some of my DRL problems, and inside the loss function the entropy of a Categorical distribution is needed. I saw that the CategoricalLogits class does not have an entropy method, unlike its counterparts in TFP and Distrax (from DeepMind). Is there a different, and possibly more streamlined, way to get it in numpyro without an external function of the following form:

import jax
import jax.numpy as jnp
import numpyro

def entropy(distr: numpyro.distributions.discrete.CategoricalLogits):
    logits = distr.logits
    return -jnp.sum(jax.nn.softmax(logits) * jax.nn.log_softmax(logits))

Is this a design choice? I have implemented an entropy method in the local numpyro copy I use for my projects, but possibly others would want this little feature added.

Anyway, let me know what you think.

Cheers!

stergiosba changed the title from "Categorical Distro Entropy" to "Categorical Distribution Entropy" on Dec 12, 2023
fehiepsi added the enhancement (New feature or request) label on Dec 13, 2023
@fehiepsi
Member

Yeah, it would be great if we had the entropy method, so you could do d.entropy() where d is a logits categorical distribution.
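
For concreteness, a minimal sketch of what such a method might look like on CategoricalLogits (summing over the last axis is an assumption about the event dimension; this is not the final implementation):

import jax
import jax.numpy as jnp

def entropy(self):
    # -sum_k p_k * log p_k, with p = softmax(logits) over the category axis
    probs = jax.nn.softmax(self.logits, axis=-1)
    log_probs = jax.nn.log_softmax(self.logits, axis=-1)
    return -jnp.sum(probs * log_probs, axis=-1)

With that, d.entropy() would return one entropy value per batch element.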

stergiosba changed the title from "Categorical Distribution Entropy" to "Distributions Entropy Method" on Dec 13, 2023
@stergiosba
Author

Changed the topic title since all (or most) distributions do not have an entropy method.

@yayami3
Contributor

yayami3 commented Dec 18, 2023

@stergiosba
Hi. Are you working on this issue? If not, I would like to take it.

@stergiosba
Author

stergiosba commented Dec 18, 2023

I am actively working on it, yes. Let's collaborate if you want, @yayami3.

@yayami3
Contributor

yayami3 commented Dec 18, 2023

@stergiosba
Thanks for the offer. Which distributions are you targeting?
I wrote a draft of the foundational classes and tests here

@stergiosba
Author

I am working on the discrete ones now. I added entropy as a method rather than a property so it matches other Python modules like Distrax and TFP.

I have done Categorical and Bernoulli, and I double-checked against Distrax and TFP to make sure I get the same results they do.
Only one comment after looking at your test cases:

@pytest.mark.parametrize(
    "jax_dist, sp_dist, params",
    [
        T(dist.BernoulliProbs, 0.2),
        T(dist.BernoulliProbs, np.array([0.2, 0.7])),
        T(dist.BernoulliLogits, np.array([-1.0, 3.0])),
    ],
)

Make sure you cover edge cases like exploding logits.
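
For instance, a saturating-logits case in the spirit of the parametrization above might look like this (the values are illustrative, and it assumes the entropy method under discussion already exists on the distribution):

import numpy as np
import jax.numpy as jnp
import numpyro.distributions as dist

extreme = dist.BernoulliLogits(logits=np.array([-40.0, 40.0]))  # probs saturate to 0.0 / 1.0 in float32
assert jnp.all(jnp.isfinite(extreme.entropy()))                  # should be ~0, never nan or inf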

For the Bernoulli distribution you used xlogy, which automatically handles the aforementioned problem, so the following works:

def entropy(self):
    return -xlogy(self.probs, self.probs) - xlog1py(1 - self.probs, -self.probs)

But I wanted to make the check explicit, so I did this, for example:

def entropy(self):
    """Calculates the entropy of the Bernoulli distribution with probability p.
    H(p,q)=-qlog(q)-plog(p) where q=1-p.
    With extra care for p=0 and p=1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    probs0 = _to_probs_bernoulli(-1.0 * self.logits)
    probs1 = self.probs
    log_probs0 = -jax.nn.softplus(self.logits)
    log_probs1 = -jax.nn.softplus(-1.0 * self.logits)
    
    # Could make this into a function if we need it elsewhere.
    qlogq = jnp.where(probs0 == 0.0, 0.0, probs0 * log_probs0)
    plogp = jnp.where(probs1 == 0.0, 0.0, probs1 * log_probs1)
    return -qlogq - plogp

I compared the performance of both solutions and it is the same. Also, for some reason xlogy cuts off at the 8th decimal point, but that is minor.

I don't know which style is better. Maybe @fehiepsi can give his take on this.
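
In case a side-by-side helps, a throwaway standalone script (logits-parametrized Bernoulli only, purely illustrative) that exercises both styles on moderate and extreme logits:

import jax
import jax.numpy as jnp
from jax.scipy.special import xlogy, xlog1py

def entropy_xlogy(logits):
    p = jax.nn.sigmoid(logits)
    return -xlogy(p, p) - xlog1py(1.0 - p, -p)

def entropy_where(logits):
    p = jax.nn.sigmoid(logits)
    q = jax.nn.sigmoid(-logits)
    log_p = -jax.nn.softplus(-logits)
    log_q = -jax.nn.softplus(logits)
    return -jnp.where(q == 0.0, 0.0, q * log_q) - jnp.where(p == 0.0, 0.0, p * log_p)

logits = jnp.array([-40.0, -3.0, 0.0, 3.0, 40.0])
print(entropy_xlogy(logits))
print(entropy_where(logits))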

@stergiosba
Author

I am also adding a mode property for the distributions.

@yayami3
Contributor

yayami3 commented Dec 19, 2023

@stergiosba
Thanks for your comment!
I think it's a good idea, but other modules use xlogy as well, and it feels like there would be a lack of consistency.
I felt it was more important to make people aware of the purpose and effects of xlogy.
Let's just wait for @fehiepsi anyway.

@fehiepsi
Member

I think you can clip y and use xlogy. I remember that grad needs to be computed correctly at the extreme points. I don't have a strong opinion on the style, though.

@stergiosba
Author

stergiosba commented Dec 19, 2023

Great catch there, @fehiepsi.

There is an issue with the gradients when using xlogy. I set up a toy problem to test gradients, and the xlogy implementation failed at the extreme case where p = 1 by returning nan.
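
A toy check along these lines (a reconstruction for illustration, not the original script) should reproduce the nan: differentiate an xlogy-based Bernoulli entropy with respect to the logit at a value where sigmoid saturates to 1.0 in float32.

import jax
import jax.numpy as jnp
from jax.scipy.special import xlogy, xlog1py

def entropy_from_logit(logit):
    # p becomes exactly 1.0 in float32 for large logits
    p = jax.nn.sigmoid(logit)
    return -xlogy(p, p) - xlog1py(1.0 - p, -p)

print(jax.grad(entropy_from_logit)(40.0))  # nan at the extreme point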

I was able to fix the nan by adding lax.stop_gradient for the case where p=1 as:

def entropy(self):
    """Calculates the entropy of the Bernoulli distribution with probability p.
    H(p,q)=-qlog(q)-plog(p) where q=1-p.
    With extra care for p=0 and p=1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    probs = lax.stop_gradient(self.probs)
    return -xlogy(probs, probs) - xlog1py(1 - probs, -probs)

The same fix works when masking with jnp.where and using xlogy:

def entropy(self):
    """Calculates the entropy of the Bernoulli distribution with probability p.
    H(p,q)=-qlog(q)-plog(p) where q=1-p.
    With extra care for p=0 and p=1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    probs = lax.stop_gradient(self.probs)
    
    # Could make this into a function if we need it elsewhere.
    qlogq = jnp.where(probs == 0.0, 0.0, xlog1py(1.0-probs, -probs))
    plogp = jnp.where(probs == 1.0, 0.0, xlogy(probs, probs))
    return -qlogq - plogp

Just for the record, the first entropy calculation I provided was based on Distrax's code, and it had no problems with gradients "out of the box".

But we can go forward with the xlogy function, as it works with the addition of stop_gradient.

@fehiepsi
Member

fehiepsi commented Dec 19, 2023

I think it is better to do probs_positive = clip(probs, a_min=tiny) and compute xlogy(probs, probs_positive), similar to probs_less_than_one. We need the grad of the entropy.
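
Read that way, only the second argument of xlogy gets clipped, so the primal value is untouched while the gradient stays finite at the extreme points. A minimal sketch of that reading (the names and the tiny constant are illustrative, not the merged code):

import jax.numpy as jnp
from jax.scipy.special import xlogy

def entropy(self):
    # Clip only the argument that goes through log; the weights stay exact,
    # so the value is unchanged and the gradient at p == 0 or p == 1 is finite.
    tiny = jnp.finfo(jnp.result_type(float)).tiny
    probs_positive = jnp.clip(self.probs, tiny)
    probs_less_than_one = jnp.clip(1.0 - self.probs, tiny)
    return -xlogy(self.probs, probs_positive) - xlogy(1.0 - self.probs, probs_less_than_one)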

@stergiosba
Author

Yeah, I was blind; I see the issue now. This is the clipping version:

def entropy(self, eps=1e-9):
    """Calculates the entropy of the Bernoulli distribution with probability p.
    H(p,q)=-qlog(q)-plog(p) where q=1-p.
    With extra care for p=0 and p=1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    probs = jnp.clip(self.probs, eps, 1.0 - eps)
    return -xlogy(probs, probs) - xlog1py(1.0 - probs, -probs)

Clipping works for the gradients but inherently introduces errors. For example, we fail to pass the test case with a large negative logit.
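
To put a rough number on that error (the logit and eps values here are illustrative):

import jax
import jax.numpy as jnp
from jax.scipy.special import xlogy, xlog1py

logit = -40.0
p = jax.nn.sigmoid(logit)              # ~4e-18, so the true entropy is on the order of 1e-16
p_c = jnp.clip(p, 1e-9, 1.0 - 1e-9)    # clipping inflates p to 1e-9
print(-xlogy(p, p) - xlog1py(1.0 - p, -p))            # close to the true value
print(-xlogy(p_c, p_c) - xlog1py(1.0 - p_c, -p_c))    # off by several orders of magnitude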

The following, although not the most beautiful, works for everything, so I vote to go with it.

def entropy(self):
    """Calculates the entropy of the Bernoulli distribution with probability p.
    H(p,q)=-qlog(q)-plog(p) where q=1-p.
    With extra care for p=0 and p=1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    q = _to_probs_bernoulli(-1.0 * self.logits)
    p = self.probs
    logq = -jax.nn.softplus(self.logits)
    logp = -jax.nn.softplus(-1.0 * self.logits)
    
    # Could make this into a function if we need it elsewhere.
    qlogq = jnp.where(q == 0.0, 0.0, q * logq)
    plogp = jnp.where(p == 0.0, 0.0, p * logp)
    return -qlogq - plogp

@stergiosba
Author

stergiosba commented Dec 20, 2023

Ideas about what to return when the logits are very negative in a Geometric distribution.

As you can see from the code below, we need to divide by p, and when the logits are very negative p = sigmoid(logit) = 0.

TFP and PyTorch return nan in this case, and Distrax does not have a Geometric distribution.

def entropy(self):
    """Calculates the entropy of the Geometric distribution with probability p.
    H(p,q)=[-qlog(q)-plog(p)]*1/p where q=1-p.
    With extra care for p=0 and p=1.

    Returns:
        entropy: The entropy of the Geometric distribution.
    """
    q = _to_probs_bernoulli(-1.0 * self.logits)
    p = self.probs
    logq = -jax.nn.softplus(self.logits)
    logp = -jax.nn.softplus(-1.0 * self.logits)

    # Could make this into a function if we need it elsewhere.
    qlogq = jnp.where(q == 0.0, 0.0, q * logq)
    plogp = jnp.where(p == 0.0, 0.0, p * logp)
    return (-qlogq - plogp) * 1.0 / p
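
One option, if inf is preferred over nan when p underflows to 0 (which convention to pick is exactly the open question above), is to mask the division. A self-contained, purely illustrative variant of the computation above (using jax.nn.sigmoid in place of _to_probs_bernoulli):

import jax
import jax.numpy as jnp

def geometric_entropy_from_logits(logits):
    # Same computation as above, but returning inf instead of nan when
    # p = sigmoid(logits) underflows to 0 (one possible convention).
    p = jax.nn.sigmoid(logits)
    q = jax.nn.sigmoid(-logits)
    logp = -jax.nn.softplus(-logits)
    logq = -jax.nn.softplus(logits)
    qlogq = jnp.where(q == 0.0, 0.0, q * logq)
    plogp = jnp.where(p == 0.0, 0.0, p * logp)
    safe_p = jnp.where(p == 0.0, 1.0, p)
    return jnp.where(p == 0.0, jnp.inf, (-qlogq - plogp) / safe_p)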

@fehiepsi
Member

fehiepsi commented Dec 20, 2023

You can divide implicitly (rather than directly). E.g., I think you can use (I have not checked yet):

  • for bernoulli: (1 - probs) * logits - log_sigmoid(logits)
  • for geometric: -(1 + jnp.exp(-logits)) * log_sigmoid(-logits) - log_sigmoid(logits)

Edit: ignore me, exp(-logits) can be very large
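
For what it's worth, the Bernoulli identity can be checked numerically for moderate logits (the Geometric one is the retracted part); an illustrative throwaway check:

import jax
import jax.numpy as jnp

logits = jnp.array([-3.0, -0.5, 0.0, 2.0])
p = jax.nn.sigmoid(logits)
direct = -(1.0 - p) * jnp.log1p(-p) - p * jnp.log(p)        # -q*log(q) - p*log(p)
implicit = (1.0 - p) * logits - jax.nn.log_sigmoid(logits)  # proposed closed form
print(jnp.allclose(direct, implicit))                       # expected: True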

@stergiosba
Author

stergiosba commented Dec 20, 2023

OK, I will add some tests for the Probs versions of the distributions and submit a PR for the discrete distributions, and you can review it there. Thanks for the help!
