Dirichlet distribution sampling issue when jit_compile=True #1789

Open
LorenzoRimella opened this issue Feb 20, 2024 · 1 comment



LorenzoRimella commented Feb 20, 2024

It seems that some seeds produce NaNs when sampling from a Dirichlet distribution inside a function compiled with `jit_compile=True`. Any idea why? The example script below was tested on Google Colab.

```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

# Concentration vector; the third entry is 0, so the Dirichlet is degenerate.
dirichlet_lambda = tf.convert_to_tensor(
    [2., 5., 0., 10., 10., 12., 10., 10., 1., 1.], dtype=tf.float32)

# Two stateless seeds that differ only in their last element.
seed_s2 = tf.convert_to_tensor([-1012227931, -757448172], dtype=tf.int32)
seed_s3 = tf.convert_to_tensor([-1012227931, -757448170], dtype=tf.int32)

@tf.function(jit_compile=True)
def jitwhat(concentration, seed):
    # Draw a (13, 10) batch of samples, each a length-10 probability vector,
    # so the output has shape (13, 10, 10).
    theta_j_k = tfp.distributions.Dirichlet(
        concentration=concentration).sample((13, 10), seed=seed)
    return theta_j_k

foo = jitwhat(dirichlet_lambda, seed_s2)
# Indices of any NaN entries in the output.
print(np.where(np.isnan(foo)))
```
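For context on where a NaN could enter, a standard way to draw Dirichlet samples is to normalize independent Gamma(alpha_i, 1) draws. The NumPy sketch below illustrates only that textbook construction. It is not TFP's implementation, and the helper name `dirichlet_via_gamma` is hypothetical. Under this construction a zero concentration yields an exact zero in that position. A NaN can only appear if some Gamma draw is non-finite, or if every draw in a row is zero (0 / 0).

```python
import numpy as np


def dirichlet_via_gamma(concentration, sample_shape, rng):
    """Hypothetical helper: draw Dirichlet samples by normalizing Gamma draws.

    Illustrates the textbook construction only; not TFP's implementation.
    """
    concentration = np.asarray(concentration, dtype=np.float64)
    size = tuple(sample_shape) + concentration.shape
    # X_i ~ Gamma(alpha_i, 1); NumPy returns exactly 0.0 when alpha_i == 0.
    gammas = rng.standard_gamma(concentration, size=size)
    # theta_i = X_i / sum_j X_j. A NaN appears only if some X_i is non-finite
    # or every X_i in a row is 0 (0 / 0).
    return gammas / gammas.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
lam = [2., 5., 0., 10., 10., 12., 10., 10., 1., 1.]
samples = dirichlet_via_gamma(lam, (13, 10), rng)
print(samples.shape, np.isnan(samples).any())
```

With the concentration vector from the issue, every row has positive entries elsewhere, so the normalizing sum is strictly positive. Any NaN would therefore have to originate in the Gamma step.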

Note that this Dirichlet distribution is "degenerate" because one of its concentration parameters is zero. Usually the sampler just returns a zero in the corresponding position, but with this particular seed it returns NaN.
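One possible workaround is sketched below. It assumes at least one concentration is positive. The idea is to sample only the strictly positive components and write exact zeros into the rest. Under the usual Gamma-normalization convention, this matches what a zero concentration should produce. The sketch uses NumPy for clarity, and the helper name `dirichlet_masked` is hypothetical. The same masking could in principle be written with TF ops, but whether that avoids the NaN under XLA has not been verified.

```python
import numpy as np


def dirichlet_masked(concentration, sample_shape, rng):
    """Hypothetical helper: Dirichlet sampling that handles zero concentrations.

    Only strictly positive components are sampled (via normalized Gamma
    draws); zero-concentration components are set to exactly 0.
    Assumes at least one concentration is positive.
    """
    concentration = np.asarray(concentration, dtype=np.float64)
    positive = concentration > 0
    sample_shape = tuple(sample_shape)
    out = np.zeros(sample_shape + concentration.shape)
    # Gamma(alpha_i, 1) draws for the positive components only.
    gammas = rng.standard_gamma(
        concentration[positive], size=sample_shape + (int(positive.sum()),))
    # Normalize over the positive components and scatter back into place.
    out[..., positive] = gammas / gammas.sum(axis=-1, keepdims=True)
    return out


rng = np.random.default_rng(0)
lam = [2., 5., 0., 10., 10., 12., 10., 10., 1., 1.]
theta = dirichlet_masked(lam, (13, 10), rng)
print(theta.shape, np.isnan(theta).any())
```

The zero components never pass through the Gamma sampler, so they cannot contribute a non-finite value.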

@chrism0dwk

Verified as a potential bug. Colab here.
