Subsampling in some autoguides produces parameters with wrong shapes #3286

Open · gui11aume opened this issue Oct 24, 2023 · 4 comments

gui11aume commented Oct 24, 2023

Issue Description

Auto guides create parameters behind the scenes, and the shapes of those parameters are determined by the plates in the model. When a plate is subsampled, the parameters should have the size of the full plate, not of the subsampled plate. Some auto guides handle this correctly, but AutoDiscreteParallel creates parameters with the subsampled shape.

Environment

  • Ubuntu Linux 22.04, Python 3.8
  • PyTorch 2.0.1
  • Pyro 1.8.6

Code Snippet

The code below shows the difference in behavior between AutoNormal and AutoDiscreteParallel. In both cases, the model creates a plate of size 20 and subsamples it to size 3. After running the loss once to create the parameters, AutoNormal produces parameters with 20 rows, whereas AutoDiscreteParallel produces parameters with 3 rows.

import pyro
import torch

def model():
    with pyro.plate("dummy", 20, subsample_size=3):
        pyro.sample("x", pyro.distributions.Normal(0, 1))

guide = pyro.infer.autoguide.AutoNormal(model)

elbo = pyro.infer.Trace_ELBO()

# Run the loss once so that the guide creates its parameters.
with pyro.poutine.trace(param_only=True) as param_capture:
    elbo.differentiable_loss(model, guide)

print(pyro.param("AutoNormal.locs.x").shape)
# torch.Size([20])  <- full plate size, as expected


def model():
    with pyro.plate("dummy", 20, subsample_size=3):
        pyro.sample("x", pyro.distributions.Categorical(torch.ones(1)))

guide = pyro.infer.autoguide.AutoDiscreteParallel(model)

elbo = pyro.infer.TraceEnum_ELBO()

# Run the loss once so that the guide creates its parameters.
with pyro.poutine.trace(param_only=True) as param_capture:
    elbo.differentiable_loss(model, guide)

print(pyro.param("AutoDiscreteParallel.x_probs").shape)
# torch.Size([3, 1])  <- subsampled plate size, not 20

I believe the issue is in the _setup_prototype methods in pyro/infer/autoguide/guides.py. Below is the relevant code from AutoNormal (see here).

            # If subsampling, repeat init_value to full size.
            for frame in site["cond_indep_stack"]:
                full_size = getattr(frame, "full_size", frame.size)
                if full_size != frame.size:
                    dim = frame.dim - event_dim
                    init_loc = periodic_repeat(init_loc, full_size, dim).contiguous()
            init_scale = torch.full_like(init_loc, self._init_scale)

There is no equivalent in the _setup_prototype function of AutoDiscreteParallel (see here).
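
For reference, here is a rough, untested sketch of how the same repeat logic might be ported there. The names init_probs and event_dim are placeholders for whatever per-site values AutoDiscreteParallel._setup_prototype actually computes; only the loop structure is taken from AutoNormal above.

from pyro.ops.tensor_utils import periodic_repeat

# Untested sketch: repeat the initial value along each subsampled plate
# dimension so the parameter gets the full plate size, as in AutoNormal.
# `site`, `init_probs`, and `event_dim` are placeholders for the per-site
# values the method would already have in scope.
for frame in site["cond_indep_stack"]:
    full_size = getattr(frame, "full_size", frame.size)
    if full_size != frame.size:
        dim = frame.dim - event_dim
        init_probs = periodic_repeat(init_probs, full_size, dim).contiguous()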

I will work on a pull request to fix this. I would also like to add tests for this and similar cases, but I am not sure where to start. Any help would be appreciated.

fritzo added the bug label on Oct 24, 2023

gui11aume (Author) commented:

Hi @fritzo! I had a closer look at the issue and it's a little more complicated than I thought... Are there already some tests for the creation of parameters in auto guides?

gui11aume (Author) commented:

I think that the same phenomenon happens for AutoLowRankMultivariateNormal.

import pyro
import torch

def model():
    with pyro.plate("dummy", 20, subsample_size=3):
        pyro.sample("x", pyro.distributions.Normal(0, 1))

guide = pyro.infer.autoguide.AutoLowRankMultivariateNormal(model, rank=2)

elbo = pyro.infer.Trace_ELBO()

# Run the loss once so that the guide creates its parameters.
with pyro.poutine.trace(param_only=True) as param_capture:
    elbo.differentiable_loss(model, guide)

print(pyro.param("AutoLowRankMultivariateNormal.loc").shape)
# torch.Size([3])
print(pyro.param("AutoLowRankMultivariateNormal.scale").shape)
# torch.Size([3])
print(pyro.param("AutoLowRankMultivariateNormal.cov_factor").shape)
# torch.Size([3, 2])

The parameters should have 20 rows but they have 3.

Following the suggestion in the docs, we can initialize the parameters with pyro.param(...) before constructing the guide, hoping to get the correct shapes. However, this fails because Pyro expects the number of rows to be 3 (if you initialize the parameters with 3 rows, the code runs fine).

import pyro
import torch

def model():
    with pyro.plate("dummy", 20, subsample_size=3):
        pyro.sample("x", pyro.distributions.Normal(0, 1))

# Initialize the parameters with the full plate size before the guide runs.
pyro.param("AutoLowRankMultivariateNormal.loc", torch.zeros(20))
pyro.param("AutoLowRankMultivariateNormal.scale", torch.ones(20))
pyro.param("AutoLowRankMultivariateNormal.cov_factor", torch.ones(20, 2))

guide = pyro.infer.autoguide.AutoLowRankMultivariateNormal(model, rank=2)

elbo = pyro.infer.Trace_ELBO()

with pyro.poutine.trace(param_only=True) as param_capture:
    elbo.differentiable_loss(model, guide)
# ...
# AssertionError

gui11aume changed the title from "Subsampling in AutoDiscreteParallel produces parameters with wrong shapes" to "Subsampling in some autoguides produces parameters with wrong shapes" on Nov 2, 2023

martinjankowiak (Collaborator) commented:

AutoGuides + data subsampling requires using create_plates; see e.g. this test.

gui11aume (Author) commented:

Thanks @martinjankowiak! I have tried creating the plates manually in different contexts, but I did not have any luck. Have a look at the example below: am I doing it wrong?

import pyro
import torch

def model():
    with pyro.plate("dummy", 20, subsample_size=3):
        pyro.sample("x", pyro.distributions.Categorical(torch.ones(1)))

# Pass the plate to the guide explicitly via create_plates.
def create_plate_x():
    return pyro.plate("dummy", 20, subsample_size=3, dim=-1)

guide = pyro.infer.autoguide.AutoDiscreteParallel(model, create_plates=create_plate_x)

elbo = pyro.infer.TraceEnum_ELBO()

with pyro.poutine.trace(param_only=True) as param_capture:
    elbo.differentiable_loss(model, guide)

print(pyro.param("AutoDiscreteParallel.x_probs").shape)
# torch.Size([3, 1])  <- still the subsampled size

Thanks for the link to the test! It seems to run with AutoDelta and AutoNormal, but I never had problems with those; I think they work fine.
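
For completeness, here is the create_plates pattern as I understand it from the AutoGuide docstring: the subsample indices are passed through the model's arguments, so that the model and the guide subsample the same rows. This is only a minimal sketch (the argument name ind is mine), and for AutoNormal I would expect it to print the full plate size:

import pyro
import torch

def model(ind):
    # The model and the guide share the same minibatch indices via `ind`.
    with pyro.plate("dummy", 20, subsample=ind):
        pyro.sample("x", pyro.distributions.Normal(0, 1))

def create_plates(ind):
    # create_plates receives the same arguments as the model.
    return pyro.plate("dummy", 20, subsample=ind)

guide = pyro.infer.autoguide.AutoNormal(model, create_plates=create_plates)
elbo = pyro.infer.Trace_ELBO()
elbo.differentiable_loss(model, guide, torch.arange(3))
print(pyro.param("AutoNormal.locs.x").shape)  # expected: torch.Size([20])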
