CategoricalProbs Sample Discrepancies #1702
Comments
Thanks for the detailed explanation, @stergiosba! It is expected that samples from different libraries will be generated in different orders. We will need to think more about whether to make…
Through the prism of controlled randomness in JAX, it sounds more pleasing to have samples consistent with other libraries. I do understand what you mean, but that does not change the fact that the sampling function used is a bit off the mark. Essentially we asked for the following distribution: `probs = jnp.array([0.21, 0.25, 0.54])`. But looking at the output histograms, what we effectively got is `probs = jnp.array([0.19, 0.24, 0.57])`. Let me know what you think.
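For reference, this is the scale of histogram wobble one should expect at n = 1000 even from an exact sampler. The sketch below uses a plain NumPy sampler as a stand-in (it is not NumPyro's code) with the probabilities quoted above:

```python
import numpy as np

# A NumPy stand-in sampler (not NumPyro's code) to show how much a
# 1000-draw histogram can wander around the requested probabilities.
probs = np.array([0.21, 0.25, 0.54])
rng = np.random.default_rng(0)
samples = rng.choice(3, size=1000, p=probs)
freqs = np.bincount(samples, minlength=3) / len(samples)
# Per-category standard error at n=1000 is sqrt(p*(1-p)/1000), roughly
# 0.013-0.016 here, so frequencies like 0.19/0.24/0.57 are ordinary noise.
```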
1000 samples is far too few to make claims about the incorrectness of the sampler. The sampler provides i.i.d. samples; it does not, e.g., provide stratified samples that semi-exactly reproduce marginal statistics. Trying to ensure that NumPyro samplers exactly produce the same samples as some other library is a pointless exercise. |
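To put a number on "too few samples": a chi-square goodness-of-fit test on the frequencies reported above (the exact counts here are my reading of the histogram, so treat them as approximate) shows the n = 1000 deviation is not statistically significant:

```python
import numpy as np

# Counts implied by the reported histogram (0.19 / 0.24 / 0.57 at n = 1000),
# tested against the requested probabilities (0.21 / 0.25 / 0.54).
observed = np.array([190.0, 240.0, 570.0])
expected = np.array([0.21, 0.25, 0.54]) * 1000
chi2 = np.sum((observed - expected) ** 2 / expected)
# With 2 degrees of freedom the 95% critical value is about 5.99;
# chi2 here is roughly 3.97, i.e. consistent with a correct sampler.
```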
I understand what you mean. I am only skeptical about the situation where a small number of samples (100~1000) is needed. In that regime the discrepancy is at its biggest, and for 100 samples the deviation is clearly visible in the histogram: when the number of samples is small, NumPyro is way off. With 1 million samples things are more ironed out. Finally, I have no idea where the "is a pointless exercise" comment comes from. NumPyro is a tool that is bound to be compared with other similar tools; if you believe an effort toward correctness is not necessary...well, we just don't agree (and that is okay). |
As these are pseudorandom numbers, the question of correctness is whether they produce (approximately) i.i.d. samples from the target distribution, not whether they generate the same samples. As a simple example, consider an i.i.d. sample of size… NumPyro should certainly be compared with other tools, but, provided I understood @martinjankowiak correctly, the exercise is not fruitful here because looking at the order of samples is not the right question to ask. As an aside, even the C++ standard library random number implementation may give different results for the same seed on different operating systems. |
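To make the order-vs-distribution point concrete: two samplers can consume the identical uniform stream, disagree element-wise, and both be exactly correct. A NumPy sketch (illustrative only, not any library's implementation):

```python
import numpy as np

probs = np.array([0.21, 0.25, 0.54])
u = np.random.default_rng(0).random(100_000)  # one shared uniform stream

# Sampler A: invert the CDF with categories in the given order.
a = np.searchsorted(np.cumsum(probs), u)
# Sampler B: invert the CDF with categories reversed, then relabel.
b = 2 - np.searchsorted(np.cumsum(probs[::-1]), u)

# The two sequences disagree element-wise, yet each one's empirical
# frequencies converge to the same target probabilities.
freq_a = np.bincount(a, minlength=3) / len(a)
freq_b = np.bincount(b, minlength=3) / len(b)
```

An x = y scatter of `a` against `b` would look just as "inconsistent" as the TFP-vs-NumPyro plots below, even though both samplers are exact.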
The following explains the issue.
And here is a histogram of the samples.
I can see this being kind of fast, as without `jit` I get the following. But with `jit` all of the advantage is practically gone. Most importantly, what the histograms do not reveal is the order in which the samples are drawn.
I include the following figures, which show that TFP and Distrax produce the same samples (x = y line) every time, while TFP and NumPyro do not consistently produce the same samples.
This does not happen with the `Logits` versions of the distributions, so the culprit is definitely this function:
numpyro/numpyro/distributions/util.py
Line 197 in fb7a029
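For context, a common way to sample from a probability vector, and my paraphrase of what a CDF-inversion helper like the linked one does (this is a sketch, not the actual NumPyro source), is:

```python
import numpy as np

def categorical_via_cdf(u, probs):
    # Count how many cumulative-probability thresholds lie below the
    # uniform draw u; that count is the sampled category index.
    return int(np.sum(np.cumsum(probs) < u))

# e.g. with probs = [0.21, 0.25, 0.54]:
#   u = 0.10 -> category 0, u = 0.30 -> category 1, u = 0.90 -> category 2
```

A logits-based sampler such as `jax.random.categorical` instead uses the Gumbel-max trick, so even with the same PRNG key the two schemes consume randomness differently and emit different (but equally valid) sequences, which would explain the x = y mismatch above.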
Proposed solution is to change the sample function of the `CategoricalProbs` class and use:
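The proposed snippet appears to be cut off above. As a hypothetical candidate in the same spirit (my assumption, not the author's actual patch), one could route the `probs` parameterization through the same Gumbel-max scheme the logits path uses, e.g. `jax.random.categorical(key, jnp.log(probs))` in JAX terms. A NumPy sketch of that scheme:

```python
import numpy as np

def gumbel_max_categorical(rng, probs, size):
    # Gumbel-max trick: argmax of log(p) + Gumbel noise is an exact
    # categorical draw; logits-based samplers use this same scheme,
    # so probs- and logits-parameterized paths would then agree.
    g = rng.gumbel(size=(size, len(probs)))
    return np.argmax(np.log(probs) + g, axis=-1)

probs = np.array([0.21, 0.25, 0.54])
draws = gumbel_max_categorical(np.random.default_rng(0), probs, 100_000)
freqs = np.bincount(draws, minlength=3) / len(draws)
```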