CategoricalProbs Sample Discrepancies #1702

Open

stergiosba opened this issue Dec 21, 2023 · 5 comments

@stergiosba

The following snippet reproduces the issue.

import distrax
import numpyro as pyro
import jax.numpy as jnp
import jax.random as jrandom
import tensorflow_probability.substrates.jax as tfp
import matplotlib.pyplot as plt

key = jrandom.PRNGKey(0)
probs = jnp.array([0.21, 0.25, 0.54])

d_tfp = tfp.distributions.Categorical(probs=probs)
d_pyro = pyro.distributions.Categorical(probs=probs)
d_distx = distrax.Categorical(probs=probs)

PYRO = []
TFP = []
DISTX = []

n_s = 1000
# Draw one sample per library per iteration, reusing the same subkey so all
# three libraries consume identical PRNG state.
for _ in range(n_s):
    key, subkey = jrandom.split(key)
    PYRO.append(d_pyro.sample(subkey))
    TFP.append(d_tfp.sample(seed=subkey))
    DISTX.append(d_distx.sample(seed=subkey))

PYRO = jnp.array(PYRO)
TFP = jnp.array(TFP)
DISTX = jnp.array(DISTX)
x = jnp.arange(len(probs) + 1)

plt.hist([PYRO, TFP, DISTX], bins=x)
plt.grid()
plt.title("Samples of X")
plt.xlabel("Sample value of X")
plt.ylabel("Frequency")
plt.legend(["NumPyro", "TFP", "Distrax"])
plt.xticks(x - 0.5, x)
plt.show()

And here is a histogram of the samples.

(figure: histogram of the samples from NumPyro, TFP, and Distrax)
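
As an aside, the Python loop is not needed to see the histogram; here is a minimal sketch (not from the original report) that draws all samples in one call, assuming the usual vectorized sample signatures of the three libraries:

# Sketch: draw all n_s samples in a single vectorized call per library.
# Note that each library consumes the one key differently here, so
# per-index agreement is not expected in this variant.
key, subkey = jrandom.split(key)
pyro_vec = d_pyro.sample(subkey, sample_shape=(n_s,))
tfp_vec = d_tfp.sample(n_s, seed=subkey)
distx_vec = d_distx.sample(seed=subkey, sample_shape=(n_s,))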

As a side note, the sampling speeds differ noticeably without jit:

Numpyro: 255 µs ± 5.25 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
TFP: 741 µs ± 7.94 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Distrax: 219 µs ± 2.14 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

But with jit all of the advantage is practically gone:

Numpyro: 2.54 µs ± 1.29 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
TFP: 2.26 µs ± 130 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Distrax: 2.2 µs ± 57.9 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
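
(For reference, a minimal sketch of how such a micro-benchmark could be set up; the jit wrappers and warm-up below are my assumption, not the exact code used for the numbers above.)

import jax

# Hypothetical setup: jit-compile a single-draw sampler per library and warm
# each one up so compilation time is excluded from the measurement.
pyro_draw = jax.jit(lambda k: d_pyro.sample(k))
tfp_draw = jax.jit(lambda k: d_tfp.sample(seed=k))
distx_draw = jax.jit(lambda k: d_distx.sample(seed=k))

for draw in (pyro_draw, tfp_draw, distx_draw):
    draw(key).block_until_ready()  # warm-up call triggers compilation

# In IPython: %timeit pyro_draw(key).block_until_ready()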

Most importantly, the histograms do not reveal the order in which samples are drawn.
The following figures show that TFP and Distrax produce identical samples (all points on the x=y line) every time.

(figure: TFP vs. Distrax sample values; every point lies on the x=y line)

TFP and NumPyro, by contrast, do not consistently produce the same samples.

(figure: TFP vs. NumPyro sample values; many points fall off the x=y line)

This does not happen with the logits versions of the distributions, so the culprit is definitely this function:

def _categorical(key, p, shape):

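For context, here is a sketch of the two sampling strategies as I understand them (a paraphrase, not numpyro's actual code): an inverse-CDF sampler and a Gumbel-max sampler draw from the same distribution but consume the PRNG key differently, so they produce different streams even from the same key.

import jax.numpy as jnp
import jax.random as jrandom

def sample_inverse_cdf(key, probs, shape=()):
    # What _categorical appears to do: draw a uniform and count how many
    # cumulative-probability thresholds it exceeds.
    cdf = jnp.cumsum(probs, axis=-1)
    u = jrandom.uniform(key, shape=shape + (1,))
    return jnp.sum(u > cdf, axis=-1)

def sample_gumbel_max(key, probs, shape=None):
    # What jax.random.categorical does: argmax of logits plus Gumbel noise.
    return jrandom.categorical(key, jnp.log(probs), shape=shape)

key = jrandom.PRNGKey(0)
probs = jnp.array([0.21, 0.25, 0.54])
print(sample_inverse_cdf(key, probs), sample_gumbel_max(key, probs))
# Both are valid categorical samplers, but the same key generally yields
# different values from the two strategies.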
My proposed solution is to change the sample method of the CategoricalProbs class to:

def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    return jax.random.categorical(
        key=key, logits=self.logits, axis=-1, shape=sample_shape + self.batch_shape
    )
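
If I read the numpyro source correctly, CategoricalLogits already samples through jax.random.categorical, so routing CategoricalProbs through self.logits as above should make the two classes consume the key identically and produce the same stream.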
@fehiepsi
Member

Thanks for the detailed explanation, @stergiosba! It is expected that samples from different libraries will be generated in different orders. We will need to think more about whether to make CategoricalProbs and CategoricalLogits generate consistent samples.

@stergiosba
Author

Given JAX's emphasis on controlled randomness, producing samples consistent with other libraries would be appealing. I understand what you mean, but that does not change the fact that the sampling function used is a bit off the mark.

Essentially we asked for the following distribution:

probs = jnp.array([0.21, 0.25, 0.54]) 

But looking at the output histograms, we got this (over 1000 samples):

probs = jnp.array([0.19, 0.24, 0.57]) 

Let me know what you think.

@martinjankowiak
Collaborator

1000 samples is far too few to make claims about the incorrectness of the sampler. The sampler provides i.i.d. samples; it does not, e.g., provide stratified samples that semi-exactly reproduce marginal statistics.

Trying to ensure that numpyro samplers exactly produce the same samples as some other library is a pointless exercise.
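
(To put a rough number on the sample-size point, a back-of-the-envelope check of my own, not from the thread: the Monte Carlo standard error of an empirical frequency from n iid draws is sqrt(p(1 - p)/n).)

import math

# For the largest category, p = 0.54 and n = 1000:
se = math.sqrt(0.54 * (1 - 0.54) / 1000)
print(se)  # ~0.016, so an observed frequency of 0.57 is within ~2 standard errors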

@stergiosba
Author

I understand what you mean. I am only skeptical about situations where a small number of samples (100–1000) is needed. In that regime the discrepancy is largest; for 100 samples you get the following:

(figure: histograms of 100 samples from each library)

When the number of samples is small, numpyro is way off.

With 1 million samples, things are more ironed out:

(figure: histograms of 1 million samples from each library)

Finally, I have no idea where the "is a pointless exercise" comment comes from. Numpyro is a tool that is bound to be compared with other similar tools; if you believe an effort toward correctness is unnecessary... well, we just don't agree (and that is ok).

@tillahoffmann
Contributor

> Finally, I have no idea where the "is a pointless exercise" comment comes from. Numpyro is a tool that is bound to be compared with other similar tools; if you believe an effort toward correctness is unnecessary... well, we just don't agree (and that is ok).

As these are pseudorandom numbers, the question of correctness is whether they produce (approximately) iid samples from the target distribution, not whether they generate the same samples. As a simple example, consider an iid sample of size n: shuffling that sample is still a sample from the target distribution.
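
(A tiny illustration of the shuffling point, my own sketch:)

import jax.numpy as jnp
import jax.random as jrandom

key = jrandom.PRNGKey(0)
key_sample, key_perm = jrandom.split(key)
probs = jnp.array([0.21, 0.25, 0.54])
samples = jrandom.categorical(key_sample, jnp.log(probs), shape=(1000,))
shuffled = jrandom.permutation(key_perm, samples)

# The per-index values differ, but the empirical distribution is identical.
assert (jnp.bincount(samples, length=3) == jnp.bincount(shuffled, length=3)).all()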

numpyro should certainly be compared with other tools, but, provided I understood @martinjankowiak correctly, the exercise is not fruitful here because looking at the order of samples is not the right question to ask.

As an aside, even the C++ standard library random number implementations may give different results for the same seed on different operating systems.
