
Implementation of mean in categorical #1718

Open

wants to merge 8 commits into main
Conversation


@fotisdr fotisdr commented May 8, 2023

NotImplementedError: mean is not implemented: Categorical

Defined a _mean method for implementing the mean of the Categorical distribution, following a previous PR (#1411).

Implementation of the mean method in the Categorical distribution.

google-cla bot commented May 8, 2023

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@csuter csuter (Member) left a comment

Thanks for this; a few comments at the change site.

@@ -333,6 +333,10 @@ def _entropy(self):
        _mul_exp(log_probs, log_probs),
        axis=-1)

  def _mean(self):
    probs = self.probs_parameter()
    return tf.reduce_sum(
        tf.range(self._num_categories(probs), dtype=probs.dtype) * probs,
        axis=-1) / tf.reduce_sum(probs, axis=-1)
Member

  1. we should not divide by sum of probs here. if they don't sum to one, and validate_args is true, this is an error. validate_args is false by default to prevent spending unnecessary compute.

  2. a more numerically stable implementation would use logits along with tfp.math.reduce_logmeanexp with the weights arg. the current one is ok, but suboptimal.

  3. some unit tests should be added.
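
The stability point in (2) can be illustrated outside TFP with a small NumPy sketch. The `weighted_logsumexp` helper below is a stand-in for `tfp.math.reduce_weighted_logsumexp` (names and shapes here are illustrative, not TFP's actual code): computing E[X] = Σₖ k·pₖ in log space avoids overflow when the logits are large.

```python
import numpy as np

def weighted_logsumexp(logits, w):
    """Stably compute log(sum_k w_k * exp(logits_k)) by factoring out the max."""
    m = np.max(logits)
    return m + np.log(np.sum(w * np.exp(logits - m)))

def categorical_mean(logits):
    """E[X] = sum_k k * softmax(logits)_k, computed in log space."""
    # Normalize logits into log-probabilities (a no-op if already normalized).
    log_probs = logits - weighted_logsumexp(logits, np.ones_like(logits))
    k = np.arange(len(logits), dtype=float)
    # exp(log(sum_k k * p_k)) recovers the mean.
    return np.exp(weighted_logsumexp(log_probs, k))

# Logits this large would overflow a naive exp(logits) softmax;
# the shifted log-sum-exp path handles them without issue.
big_logits = np.array([1000.0, 1001.0, 1002.0])
mean = categorical_mean(big_logits)  # equals softmax([0, 1, 2]) dotted with [0, 1, 2]
```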

Member

sorry, meant weighted_logsumexp

Author

Should we always expect that the provided probs sum up to 1 for this method, and thus add an assertion that this is the case before computing the mean? Because other methods in the Categorical work without the sum of the probs necessarily being 1 (with validate_args=False).

Author

Also, the previous implementation is definitely possible with logits and the reduce_weighted_logsumexp function:

logits = self.logits_parameter()
return tf.math.exp(reduce_weighted_logsumexp(logits,w=tf.range(self._num_categories(logits),dtype=logits.dtype),axis=-1))

I tested it and it produces the same results, so I can replace my code with this (more stable) implementation. The only issue is that I'm still waiting for approval from a maintainer to run the workflow tests.

Member

> Should we always expect that the provided probs sum up to 1 for this method, and thus add an assertion that this is the case before computing the mean? Because other methods in the Categorical work without the sum of the probs necessarily being 1 (with validate_args=False).

There are already such assertions (when validate_args=True) in the execution path of all these methods. Look at _parameter_control_dependencies in this file, as well as most other Distribution subclasses in TFP, to see which ones there are. These are triggered by the base Distribution class when any public API point is invoked (e.g., dist.log_prob, dist.sample, dist.mean, etc.; again, only if validate_args is True 🙂)
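
The pattern described here (parameter checks wired in only when `validate_args=True`) can be sketched in plain Python. This is a toy illustration of the idea only; the class and method names below imitate TFP's structure but are not its actual machinery.

```python
import numpy as np

class ToyCategorical:
    """Toy sketch of the validate_args pattern; not TFP's real API."""

    def __init__(self, probs, validate_args=False):
        self.probs = np.asarray(probs, dtype=float)
        self.validate_args = validate_args

    def _parameter_control_dependencies(self):
        # Analogue of TFP's parameter assertions: they cost compute,
        # so they run only when the user opts in.
        if self.validate_args and not np.allclose(self.probs.sum(axis=-1), 1.0):
            raise ValueError('probs must sum to 1')

    def mean(self):
        # Public API points run the checks first, as in the base class.
        self._parameter_control_dependencies()
        k = np.arange(self.probs.shape[-1], dtype=float)
        return np.sum(k * self.probs, axis=-1)
```

With `validate_args=False` (the default), unnormalized `probs` pass through unchecked, which mirrors the reviewer's point that `_mean` itself should not renormalize.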

The computation of the mean is done with logits (instead of probs) to make the implementation more stable numerically.
@@ -30,6 +30,7 @@
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow_probability.python.math import reduce_weighted_logsumexp
Member

Please import as

from tensorflow_probability.python.math.generic import reduce_weighted_logsumexp

and add this to the "categorical" deps list in the adjacent BUILD file:

"//tensorflow_probability/python/math:generic"

@@ -333,6 +334,13 @@ def _entropy(self):
_mul_exp(log_probs, log_probs),
axis=-1)

def _mean(self):
#probs = self.probs_parameter()
Member

new impl looks good, thanks! please

  1. remove the commented out lines
  2. add a test in categorical_test.py. you can look at other categorical tests and maybe normal_test.py for a hint of how these should look. think of edge cases, like some zero prob categories, etc.

#return tf.reduce_sum(tf.range(self._num_categories(probs),dtype=probs.dtype) * probs, axis=-1) / tf.reduce_sum(probs, axis=-1)
# Implement with logits to improve numerical stability
logits = self.logits_parameter()
return tf.math.exp(reduce_weighted_logsumexp(logits,w=tf.range(self._num_categories(logits),dtype=logits.dtype),axis=-1))
Member

please format as

return tf.math.exp(
    reduce_weighted_logsumexp(
        logits,
        w=tf.range(self._num_categories(logits), dtype=logits.dtype),
        axis=-1))

Author

Thanks for these, my last commit now includes all the requested changes.
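
For the requested unit test, the reviewer suggested covering edge cases such as zero-probability categories. The NumPy sketch below gives hand-computed reference means such a test could assert against (illustrative only; the real test would build a `tfd.Categorical` and compare with `self.assertAllClose`):

```python
import numpy as np

def reference_mean(probs):
    """Hand-rolled E[X] = sum_k k * p_k to check dist.mean() against."""
    probs = np.asarray(probs, dtype=float)
    k = np.arange(probs.shape[-1], dtype=float)
    return np.sum(k * probs, axis=-1)

# Edge cases worth covering, with expected means computed by hand:
cases = [
    ([0.0, 0.25, 0.75], 1.75),          # zero-probability category
    ([1.0, 0.0, 0.0], 0.0),             # degenerate: all mass on index 0
    ([0.25, 0.25, 0.25, 0.25], 1.5),    # uniform over four categories
]
```

The same helper also handles batched probs, so a test can check that `mean()` preserves batch shape.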

self.assertAllEqual((1,), dist.mean().shape)
# Expected mean will be the same as in a Multinomial with n = 1
expected_means = stats.multinomial.mean(n=1, p=p).argmax(axis=-1)
self.assertAllClose(expected_means, self.evaluate(binom.mean()))
Member

change binom to categorical :)

Author

Sorry, corrected it...

@fotisdr fotisdr requested a review from csuter September 15, 2023 07:55