[bug] InverseGamma CDF #3291

Open
treigerm opened this issue Nov 7, 2023 · 0 comments
treigerm commented Nov 7, 2023

Issue Description

Currently, the cdf method of InverseGamma returns obviously wrong values: the result is monotonically decreasing in x, which a CDF can never be.
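For reference, here is the CDF the method should return. Write Y ~ Gamma(α, β) with rate β, so that X = 1/Y ~ InverseGamma(α, β), and let P and Q denote the regularized lower and upper incomplete gamma functions (torch.special.gammainc and torch.special.gammaincc). The symbols Y, α, β, P and Q are notation introduced here. Because y ↦ 1/y is decreasing on (0, ∞):

```latex
F_X(x) = \Pr(Y^{-1} \le x) = \Pr(Y \ge x^{-1}) = 1 - F_Y(x^{-1})
       = 1 - P(\alpha, \beta / x) = Q(\alpha, \beta / x), \qquad x > 0.
```

Since Q(α, z) decreases in z and β/x decreases in x, F_X is nondecreasing. For α = β = 1, as in the snippet, it reduces to exp(-1/x). A decreasing output is what one gets by returning F_Y(1/x) = 1 - F_X(x) instead.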

On closer inspection, cdf(x) actually returns 1 - cdf(x). I believe this is because, in current PyTorch releases, PowerTransform has a hard-coded sign of +1 (offending line: https://github.com/pytorch/pytorch/blob/7bcf7da3a268b435777fe87c7794c382f444e86d/torch/distributions/transforms.py#L567). That is wrong for negative exponents such as the -1 used to build InverseGamma from Gamma, because x ↦ x^(-1) is decreasing.
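To make the sign argument concrete, below is a small sketch of the CDF of X = T(Y) for a single monotone transform T. The helper transformed_cdf is hypothetical. Its `sign * (u - 0.5) + 0.5` correction is modeled on how torch.distributions.TransformedDistribution adjusts its CDF using the transforms' sign attribute, but it is an illustration, not the library's code. With T(y) = 1/y, sign = -1 reproduces the closed-form InverseGamma CDF. A hard-coded sign = +1 returns exactly its complement.

```python
import torch
from torch.distributions import Gamma


def transformed_cdf(base, inv, sign, x):
    """Hypothetical helper: CDF of X = T(Y) for one monotone transform T.

    base: distribution of Y (must provide .cdf)
    inv:  the inverse transform T^{-1}
    sign: +1 if T is increasing, -1 if T is decreasing
    """
    u = base.cdf(inv(x))  # P(Y <= T^{-1}(x))
    # For decreasing T, P(T(Y) <= x) = P(Y >= T^{-1}(x)) = 1 - u.
    # sign * (u - 0.5) + 0.5 gives u for sign=+1 and 1 - u for sign=-1.
    return sign * (u - 0.5) + 0.5


alpha = torch.tensor(1.0, dtype=torch.float64)
beta = torch.tensor(1.0, dtype=torch.float64)
xs = torch.linspace(0.1, 3.0, 50, dtype=torch.float64)
base = Gamma(alpha, beta)

correct = transformed_cdf(base, torch.reciprocal, -1, xs)  # y -> 1/y is decreasing
buggy = transformed_cdf(base, torch.reciprocal, +1, xs)    # sign hard-coded to +1

# The sign=-1 branch matches the closed form Q(alpha, beta / x) ...
print(torch.allclose(correct, torch.special.gammaincc(alpha, beta / xs)))
# ... and the sign=+1 branch is its complement, matching the reported symptom.
print(torch.allclose(buggy, 1 - correct))
```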

A recently merged PyTorch PR should fix this. To confirm, I ran my code on the PyTorch nightly release, and the issue indeed goes away.

Environment

Pyro version: 1.8.6
PyTorch version:

  • contains bug: 2.1.0+cu121
  • nightly version that fixes bug: 2.2.0.dev20231107+cpu

OS: Ubuntu 22.04.2 LTS
Python version: 3.10.12

Code Snippet

import pyro.distributions as dist
import matplotlib.pyplot as plt
import torch

xs = torch.linspace(0, 3, 200)
cdfs = dist.InverseGamma(1.0, 1.0).cdf(xs)
# Visual inspection of CDF. See output below.
plt.plot(xs, cdfs)

# For InverseGamma(concentration, rate) the CDF is gammaincc(concentration, rate / x),
# which gives an analytic reference to compare against.
# This evaluates to True on '2.1.0+cu121' (bug) but False on '2.2.0.dev20231107+cpu' (fixed).
torch.allclose(cdfs, 1 - torch.special.gammaincc(torch.ones(200), torch.ones(200) / xs))
# If the CDF were correct, the following statement would evaluate to True instead.
torch.allclose(cdfs, torch.special.gammaincc(torch.ones(200), torch.ones(200) / xs))

[Figure: plot of cdfs against xs]
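Until a PyTorch release containing the fix is available, one possible workaround is to evaluate the closed-form CDF directly with torch.special.gammaincc instead of going through the transformed distribution. The helper names inverse_gamma_cdf and is_nondecreasing, and the choice to return 0 for non-positive x, are illustrative assumptions, not part of Pyro's API. This is a sketch, not a drop-in replacement for InverseGamma.cdf; for instance, it does no argument validation.

```python
import torch


def inverse_gamma_cdf(x, concentration, rate):
    """Hypothetical workaround: closed-form InverseGamma CDF, Q(concentration, rate / x).

    x should be a floating-point tensor; entries with x <= 0 (outside the support) map to 0.
    """
    x = torch.as_tensor(x)
    concentration = torch.as_tensor(concentration, dtype=x.dtype)
    rate = torch.as_tensor(rate, dtype=x.dtype)
    positive = x > 0
    # Substitute 1 for non-positive x so rate / x stays finite; those entries are masked below.
    safe_x = torch.where(positive, x, torch.ones_like(x))
    cdf = torch.special.gammaincc(concentration, rate / safe_x)
    return torch.where(positive, cdf, torch.zeros_like(cdf))


def is_nondecreasing(values):
    """A valid CDF evaluated on a sorted grid must never decrease."""
    return bool(torch.all(values[1:] >= values[:-1]))


xs = torch.linspace(0, 3, 200, dtype=torch.float64)
cdfs = inverse_gamma_cdf(xs, 1.0, 1.0)
print(is_nondecreasing(cdfs))
# For concentration = rate = 1 the closed form reduces to exp(-1 / x).
print(torch.allclose(cdfs[1:], torch.exp(-1.0 / xs[1:])))
```

On an affected PyTorch version, dist.InverseGamma(1.0, 1.0).cdf(xs) should then be close to 1 - inverse_gamma_cdf(xs, 1.0, 1.0), consistent with the behavior reported above.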

@fritzo added the bug label on Nov 7, 2023