Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Wishart distribution. #1779

Merged
merged 11 commits into from May 13, 2024
Merged

Add Wishart distribution. #1779

merged 11 commits into from May 13, 2024

Conversation

tillahoffmann
Copy link
Contributor

This PR adds the Wishart distribution for covariance matrices. It should probably be rebased and reviewed after #1778 is merged.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @tillahoffmann! It is great to have this distribution. I left initial comments below.

numpyro/distributions/continuous.py Outdated Show resolved Hide resolved
numpyro/distributions/continuous.py Outdated Show resolved Hide resolved
numpyro/distributions/continuous.py Outdated Show resolved Hide resolved
numpyro/distributions/continuous.py Outdated Show resolved Hide resolved
numpyro/distributions/continuous.py Outdated Show resolved Hide resolved
numpyro/distributions/continuous.py Outdated Show resolved Hide resolved
numpyro/distributions/continuous.py Show resolved Hide resolved
numpyro/distributions/util.py Show resolved Hide resolved
test/test_distributions.py Show resolved Hide resolved
test/test_transforms.py Show resolved Hide resolved
numpyro/distributions/continuous.py Show resolved Hide resolved
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the late review, @tillahoffmann!

numpyro/distributions/continuous.py Outdated Show resolved Hide resolved
numpyro/distributions/continuous.py Show resolved Hide resolved
numpyro/distributions/continuous.py Outdated Show resolved Hide resolved
numpyro/distributions/continuous.py Outdated Show resolved Hide resolved
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @tillahoffmann! Just have a comment at the computation of slogdet.

batch_shape, event_shape = loc[:-1], loc[-1:]
for matrix in [covariance_matrix, precision_matrix, scale_tril]:
if matrix is not None:
batch_shape = lax.broadcast_shapes(batch_shape, matrix[:-2])
event_shape = lax.broadcast_shapes(event_shape, matrix[-1:])
return batch_shape, event_shape
return batch_shape, event_shape
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this indent intended?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, assert_one_of ensures that exactly one of the matrices is available, and returning early avoids looping over precision_matrix and scale_tril if covariance_matrix is given, for example. Either with/without the indent should work though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it, thanks!

numpyro/distributions/continuous.py Show resolved Hide resolved
numpyro/distributions/continuous.py Show resolved Hide resolved
numpyro/distributions/continuous.py Outdated Show resolved Hide resolved
numpyro/distributions/continuous.py Outdated Show resolved Hide resolved
@fehiepsi fehiepsi merged commit 5eb134d into pyro-ppl:master May 13, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants