Implementation of mean in categorical #1718
base: main
Conversation
Implementation of the mean method in the Categorical distribution.
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Thanks for this; a few comments at the change site.
@@ -333,6 +333,10 @@ def _entropy(self):
        _mul_exp(log_probs, log_probs),
        axis=-1)

  def _mean(self):
    probs = self.probs_parameter()
    return tf.reduce_sum(tf.range(self._num_categories(probs),dtype=probs.dtype) * probs, axis=-1) / tf.reduce_sum(probs, axis=-1)
- we should not divide by the sum of probs here: if they don't sum to one and validate_args is true, this is an error; validate_args is false by default to prevent spending unnecessary compute.
- a more numerically stable implementation would use logits along with tfp.math.reduce_logmeanexp with the weights arg. the current one is ok, but suboptimal.
- some unit tests should be added.
sorry, meant weighted_logsumexp
Should we always expect that the provided probs sum up to 1 for this method, and thus add an assertion that this is the case before computing the mean? Because other methods in the Categorical work without the sum of the probs necessarily being 1 (with validate_args=False).
Also, the previous implementation is definitely possible with logits and the reduce_weighted_logsumexp function:
logits = self.logits_parameter()
return tf.math.exp(reduce_weighted_logsumexp(logits, w=tf.range(self._num_categories(logits), dtype=logits.dtype), axis=-1))
I tested it and it produces the same results, so I can replace my code with this (more stable) implementation. The only issue is that I'm still waiting for approval from a maintainer to run the workflow tests.
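As a hedged, standalone sanity check of the equivalence mentioned above (the tensor values are illustrative, and it assumes the logits are normalized log-probabilities, as logits_parameter() returns when the distribution is built from probs):

import tensorflow as tf
from tensorflow_probability.python.math.generic import reduce_weighted_logsumexp

probs = tf.constant([0.1, 0.6, 0.3])
logits = tf.math.log(probs)          # normalized log-probabilities (assumption)
w = tf.range(3, dtype=probs.dtype)   # category values 0, 1, 2

direct = tf.reduce_sum(w * probs, axis=-1)                                 # 0*0.1 + 1*0.6 + 2*0.3 = 1.2
via_logits = tf.math.exp(reduce_weighted_logsumexp(logits, w=w, axis=-1))  # exp(log(sum_i w_i * p_i)) = 1.2

tf.debugging.assert_near(direct, via_logits)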
Should we always expect that the provided probs sum up to 1 for this method, and thus add an assertion that this is the case before computing the mean? Because other methods in the Categorical work without the sum of the probs necessarily being 1 (with validate_args=False).

There are already such assertions (when validate_args=True) in the execution path of all these methods. Look at _parameter_control_dependencies in this file, as well as most other Distribution subclasses in TFP, to see which ones there are. These are triggered by the base Distribution class when any public API point is invoked (e.g., dist.log_prob, dist.sample, dist.mean, etc.; again, only if validate_args is True 🙂).
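A hedged illustration of that behavior (the values are made up; a tf.Variable is used so the check is deferred to the public method call rather than run at construction):

import tensorflow as tf
import tensorflow_probability as tfp

probs = tf.Variable([0.2, 0.3, 0.1])  # deliberately does not sum to 1
dist = tfp.distributions.Categorical(probs=probs, validate_args=True)
dist.mean()  # raises InvalidArgumentError (mean() assumes this PR); dist.log_prob(0) or dist.sample() trigger the same check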
The computation of the mean is done with logits (instead of probs) to make the implementation more stable numerically.
@@ -30,6 +30,7 @@
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow_probability.python.math import reduce_weighted_logsumexp
Please import as
from tensorflow_probability.python.math.generic import reduce_weighted_logsumexp
and add this to the "categorical" deps list in the adjacent BUILD file:
"//tensorflow_probability/python/math:generic"
@@ -333,6 +334,13 @@ def _entropy(self):
        _mul_exp(log_probs, log_probs),
        axis=-1)

  def _mean(self):
    #probs = self.probs_parameter()
new impl looks good, thanks! please
- remove the commented out lines
- add a test in categorical_test.py. you can look at other categorical tests and maybe normal_test.py for a hint of how these should look. think of edge cases, like some zero prob categories, etc.
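Picking up the last point, a hedged sketch (not part of the PR) of what an edge-case test with a zero-probability category might look like in categorical_test.py; the tfd alias and test method name are assumptions:

def testMeanWithZeroProbCategory(self):
  # Hypothetical test: the middle category has zero probability,
  # so E[X] = 0*0.5 + 1*0.0 + 2*0.5 = 1.0.
  dist = tfd.Categorical(probs=[0.5, 0.0, 0.5])
  self.assertAllClose(1.0, self.evaluate(dist.mean()))

A zero probability becomes a -inf logit, which is exactly the kind of corner the logit-based implementation needs to handle gracefully.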
    #return tf.reduce_sum(tf.range(self._num_categories(probs),dtype=probs.dtype) * probs, axis=-1) / tf.reduce_sum(probs, axis=-1)
    # Implement with logits to improve numerical stability
    logits = self.logits_parameter()
    return tf.math.exp(reduce_weighted_logsumexp(logits,w=tf.range(self._num_categories(logits),dtype=logits.dtype),axis=-1))
please format as
return tf.math.exp(
reduce_weighted_logsumexp(
logits,
w=tf.range(self._num_categories(logits), dtype=logits.dtype),
axis=-1))
Thanks for these, my last commit now includes all the requested changes.
Removed commented lines and reformatted.
self.assertAllEqual((1,), dist.mean().shape)
# Expected mean will be the same as in a Multinomial with n = 1
expected_means = stats.multinomial.mean(n=1, p=p).argmax(axis=-1)
self.assertAllClose(expected_means, self.evaluate(binom.mean()))
change binom to categorical :)
Sorry, corrected it...
Defined a _mean method for implementing the mean of the Categorical distribution, following a previous PR (#1411).
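For reference, a minimal usage sketch once the _mean above is in place (the probabilities are illustrative):

import tensorflow_probability as tfp

dist = tfp.distributions.Categorical(probs=[0.1, 0.6, 0.3])
dist.mean()  # 0*0.1 + 1*0.6 + 2*0.3 = 1.2, assuming this PR's _mean implementation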