
Improve user experience in inference with partially observed discrete variables #3255

Open
gui11aume opened this issue Aug 6, 2023 · 7 comments


gui11aume commented Aug 6, 2023

Summary

Inference for partially observed discrete variables occasionally produces counter-intuitive results. These are not bugs, but users may waste a lot of time dealing with them or trying to understand them. The behavior has been tested on Pyro 1.8.5 and 1.8.6.

A simple example with coins

The example below is meant to show the kind of context in which the issues appear. It is artificial and has no practical application, but it is inspired by real examples I stumbled upon. In the model, we flip a fair coin and do not show the result; if it lands 'heads' we flip a coin with bias 0.05; if it lands 'tails' we flip a coin with bias 0.95. We always observe the result of the biased coin (but not which coin was flipped). In the guide, we simply sample the unbiased coin.

import pyro
import pyro.distributions as dist
import torch

def model(X):
   with pyro.plate("tosses", X.shape[0]):
      c = pyro.sample("unbiased", dist.Categorical(torch.ones(1, 2)))
      pyro.sample("obs", dist.Bernoulli(torch.tensor([.05,.95])[c]), obs=X)

def guide(X):
   post_p = pyro.param("post_p", torch.ones(X.shape[0], 2),
      constraint = torch.distributions.constraints.simplex)
   with pyro.plate("tosses", X.shape[0]):
      pyro.sample("unbiased", dist.Categorical(post_p))

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
for step in range(2000):
   svi.step(torch.tensor([1.])) # X = 1. stands for "tails".
print(pyro.param("post_p"))

# tensor([[0.0552, 0.9448]], grad_fn=<DivBackward0>)

The result is correct: the second coin landed 'tails', so the posterior probability that the unbiased coin landed 'tails' is 0.95.
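
For reference, the exact posterior follows from Bayes' rule:

$$P(\text{unbiased} = \text{tails} \mid \text{obs} = \text{tails}) = \frac{0.5 \times 0.95}{0.5 \times 0.95 + 0.5 \times 0.05} = 0.95.$$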

Issue 1: Code failure when masking

If the second coin is only sometimes observed, we can introduce an observation mask for the obs sample. Let us modify the code and run the same example, i.e., we specify that the second coin landed 'tails' and that this is observed.

import pyro
import pyro.distributions as dist
import torch

# Add `mask` argument, together with `obs_mask=mask` in sample "obs".
def model(X, mask):
   with pyro.plate("tosses", X.shape[0]):
      c = pyro.sample("unbiased", dist.Categorical(torch.ones(1, 2)))
      pyro.sample("obs", dist.Bernoulli(torch.tensor([.05,.95])[c]), obs=X, obs_mask=mask)

def guide(X, mask):
   post_p = pyro.param("post_p", torch.ones(X.shape[0], 2),
      constraint = torch.distributions.constraints.simplex)
   with pyro.plate("tosses", X.shape[0]):
      pyro.sample("unbiased", dist.Categorical(post_p))

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
for step in range(2000):
   svi.step(torch.tensor([1.]), torch.tensor([True])) # True means observed.
print(pyro.param("post_p"))

The code fails with the error below.

(...)/pyro/util.py:303: UserWarning: Found vars in model but not guide: {'obs_unobserved'}
  warnings.warn(f"Found vars in model but not guide: {bad_sites}")
Traceback (most recent call last):
  File "(...)", line 18, in <module>
    svi.step(torch.tensor([1.]), torch.tensor([True]))
  File "(...)/pyro/infer/svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "(...)/pyro/infer/trace_elbo.py", line 141, in loss_and_grads
    loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
  File "(...)/pyro/infer/trace_elbo.py", line 106, in _differentiable_loss_particle
    log_r = _compute_log_r(model_trace, guide_trace)
  File "(...)/pyro/infer/trace_elbo.py", line 27, in _compute_log_r
    log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]
KeyError: 'obs_unobserved'

First there is a warning about a missing site in the guide, and then a KeyError for the same reason. This is counter-intuitive: either guides with missing sites should be allowed (warn only), or they should not (raise an error for every guide with missing sites).

The error comes from the part of the code that evaluates the loss using the REINFORCE estimator (i.e., when the reparametrization trick cannot be used, as is the case for discrete random variables). Line 27 in trace_elbo.py assumes that every unobserved site in the model also exists in the guide. The user may not be aware that the site obs_unobserved is created in the model (but not in the guide) as soon as the argument obs_mask is not None.

The solution is to define the sample obs_unobserved in the guide (see how below), but this is barely mentioned anywhere, so we cannot assume that users will do it. If guides with missing sites are allowed, line 27 in trace_elbo.py should be replaced with a fail-safe version. Ideally, a message could point users in the right direction when Pyro creates an _unobserved site that is not in the guide.
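
For the example above, a guide that avoids the failure could look like the sketch below (this is just one way to do it; the extra parameter obs_p is hypothetical and only gives the obs_unobserved site a variational distribution). Note that issue 2 below still applies to it.

import pyro
import pyro.distributions as dist
import torch

def guide(X, mask):
   post_p = pyro.param("post_p", torch.ones(X.shape[0], 2),
      constraint = torch.distributions.constraints.simplex)
   # Hypothetical variational parameter for the unobserved values of "obs".
   obs_p = pyro.param("obs_p", 0.5 * torch.ones(X.shape[0]),
      constraint = torch.distributions.constraints.unit_interval)
   with pyro.plate("tosses", X.shape[0]):
      pyro.sample("unbiased", dist.Categorical(post_p))
      # Declare the site that Pyro auto-creates in the model when obs_mask is set,
      # and mask it out wherever the observation is actually available.
      with pyro.poutine.mask(mask=~mask):
         pyro.sample("obs_unobserved", dist.Bernoulli(obs_p))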

Issue 2: Counter-intuitive gradient

Now, if the unbiased coin is only sometimes observed, we can introduce an observation mask for the unbiased sample, together with the observations when they are available. As mentioned above, we need to add a site in the guide called unbiased_unobserved that explains what to do when the coin is not observed (i.e., sample it as we were doing until now). We have to sample the whole tensor; Pyro will automatically mix observed and sampled values for us as needed.

Some values in unbiased_unobserved are sampled for nothing: they will be replaced with the observed values when those are available. In that case, the sampled values have no effect on the inference, but just to be sure, we mask them in the guide so that their log_prob terms are set to 0. We do this with poutine.mask, inverting the observation mask with ~mask.

import pyro
import pyro.distributions as dist
import torch

# Add `coin_obs` argument; now `obs_mask` applies to the unbiased coin.
def model(X, coin_obs, mask):
   with pyro.plate("tosses", X.shape[0]):
      c = pyro.sample("unbiased", dist.Categorical(torch.ones(1, 2)), obs=coin_obs, obs_mask=mask)
      pyro.sample("obs", dist.Bernoulli(torch.tensor([.05,.95])[c]), obs=X)

def guide(X, coin_obs, mask):
   post_p = pyro.param("post_p", torch.ones(X.shape[0], 2), 
      constraint = torch.distributions.constraints.simplex)
   with pyro.plate("tosses", X.shape[0]):
      with pyro.poutine.mask(mask=~mask): # Invert the observation mask.
         pyro.sample("unbiased_unobserved", dist.Categorical(post_p))

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
for step in range(2000):
   # Second coin: tails, heads, tails. Unbiased coin: tails, heads, ?
   svi.step(torch.tensor([1., 0., 1.]), torch.tensor([1, 0, 0]), torch.tensor([True, True, False]))
print(pyro.param("post_p"))

# tensor([[0.3058, 0.6942],
#         [0.7498, 0.2502],
#         [0.0502, 0.9498]], grad_fn=<DivBackward0>)

In this example, we observed the first two flips of the unbiased coin, but not the third. We set the third value to 'heads' (0), but this is irrelevant because the value is never used during inference. The inference is correct for the third flip, and there is nothing to infer for the first two flips because their values were observed... So why did the values of post_p change from the initial 0.5, and what do the current values represent?

As far as I understand, the values have no special meaning. There is nothing to infer anyway. So why did they change? Once again, this has to do with the way Pyro evaluates the loss using the REINFORCE estimator. Internally, it keeps track of a log_prob term and a score_function term for the sites of the guide. The log_prob terms are masked but not the score_function terms, so all the values of unbiased_unobserved contribute to the gradient, even those that are overwritten by observed values.

I don't think that this has side effects, so this is not really a bug. The issue here is that Pyro is difficult enough to debug, and erratic behaviors make it harder. It would help if parameters that have no effect on the inference had gradient 0, so that the user gets alerted when there is an error in the model (e.g., when values that should have no effect on the inference do in fact have an effect).
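
A quick way to see this is to compute a single gradient and inspect it directly. Below is a rough sketch, reusing the model and guide defined just above; loss_and_grads leaves the gradients on the parameters in the param store (the unconstrained version of post_p, since it is simplex-constrained).

import pyro
import torch

pyro.clear_param_store()
X = torch.tensor([1., 0., 1.])
coin_obs = torch.tensor([1, 0, 0])
mask = torch.tensor([True, True, False])

elbo = pyro.infer.Trace_ELBO()
# One gradient evaluation; backward() is called internally on the surrogate loss.
elbo.loss_and_grads(model, guide, X, coin_obs, mask)
for name, param in pyro.get_param_store().named_parameters():
   print(name, param.grad)

The rows of post_p corresponding to the observed flips typically get a non-zero gradient, even though they should be inert.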


gui11aume commented Aug 6, 2023

@fritzo It seems that you wrote the following comment in pyro/distributions/score_parts.py

71b7d3da0 (Fritz Obermeyer 2018-08-26 10:51:42 -0700 17)     def scale_and_mask(self, scale=1.0, mask=None):
71b7d3da0 (Fritz Obermeyer 2018-08-26 10:51:42 -0700 19)         Scale and mask appropriate terms of a gradient estimator by a data multiplicity factor.
71b7d3da0 (Fritz Obermeyer 2018-08-26 10:51:42 -0700 20)         Note that the `score_function` term should not be scaled or masked.

I ran the tests in pyro/tests/infer/test_inference.py with masking and/or scaling of score_function, but there was no difference, so I could not understand why score_function should not be masked. Maybe there is a test somewhere else that depends on this? If not, do you remember the case you had in mind when you designed this part?

If you point me in the right direction I can write additional tests and work on a pull request for this issue. Thanks!


fritzo commented Aug 10, 2023

Hi @gui11aume, responding to your last question: the best tests we have of ScoreParts behavior are in tests/infer/test_gradient.py. We have relied heavily on those tests because they are much faster than end-to-end inference-as-optimization tests, since the gradient tests compute only a single gradient update. Because the gradient tests are fast, we can run them over a large grid of model and inference configurations via @pytest.mark.parametrize.

For example, this incorrect change to ScoreParts:

diff --git a/pyro/distributions/score_parts.py b/pyro/distributions/score_parts.py
index 15d39156..bb758d82 100644
--- a/pyro/distributions/score_parts.py
+++ b/pyro/distributions/score_parts.py
@@ -25,6 +25,6 @@ class ScoreParts(
         :type mask: torch.BoolTensor or None
         """
         log_prob = scale_and_mask(self.log_prob, scale, mask)
-        score_function = self.score_function  # not scaled
+        score_function = scale_and_mask(self.score_function, scale, mask)
         entropy_term = scale_and_mask(self.entropy_term, scale, mask)
         return ScoreParts(log_prob, score_function, entropy_term)

results in a test failure in around one second:

% pytest -vsx tests/infer/test_gradient.py
...
______________________________ test_subsample_gradient[Trace_ELBO-False-full-reparam-False-scaled] _______________________________

Elbo = <class 'pyro.infer.trace_elbo.Trace_ELBO'>, reparameterized = True, has_rsample = False, subsample = False
local_samples = False, scale = 2.0

    @pytest.mark.parametrize("scale", [1.0, 2.0], ids=["unscaled", "scaled"])
    @pytest.mark.parametrize(
        "reparameterized,has_rsample",
        [(True, None), (True, False), (True, True), (False, None)],
        ids=["reparam", "reparam-False", "reparam-True", "nonreparam"],
    )
    @pytest.mark.parametrize("subsample", [False, True], ids=["full", "subsample"])
    @pytest.mark.parametrize(
        "Elbo,local_samples",
        [
            (Trace_ELBO, False),
            (DiffTrace_ELBO, False),
            (TraceGraph_ELBO, False),
            (TraceMeanField_ELBO, False),
            (TraceEnum_ELBO, False),
            (TraceEnum_ELBO, True),
        ],
    )
    def test_subsample_gradient(
        Elbo, reparameterized, has_rsample, subsample, local_samples, scale
    ):
        pyro.clear_param_store()
        data = torch.tensor([-0.5, 2.0])
        subsample_size = 1 if subsample else len(data)
        precision = 0.06 * scale
        Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal

        def model(subsample):
            with pyro.plate("data", len(data), subsample_size, subsample) as ind:
                x = data[ind]
                z = pyro.sample("z", Normal(0, 1))
                pyro.sample("x", Normal(z, 1), obs=x)

        def guide(subsample):
            scale = pyro.param("scale", lambda: torch.tensor([1.0]))
            with pyro.plate("data", len(data), subsample_size, subsample):
                loc = pyro.param("loc", lambda: torch.zeros(len(data)), event_dim=0)
                z_dist = Normal(loc, scale)
                if has_rsample is not None:
                    z_dist.has_rsample_(has_rsample)
                pyro.sample("z", z_dist)

        if scale != 1.0:
            model = poutine.scale(model, scale=scale)
            guide = poutine.scale(guide, scale=scale)

        num_particles = 50000
        if local_samples:
            guide = config_enumerate(guide, num_samples=num_particles)
            num_particles = 1

        optim = Adam({"lr": 0.1})
        elbo = Elbo(
            max_plate_nesting=1,  # set this to ensure rng agrees across runs
            num_particles=num_particles,
            vectorize_particles=True,
            strict_enumeration_warning=False,
        )
        inference = SVI(model, guide, optim, loss=elbo)
        with xfail_if_not_implemented():
            if subsample_size == 1:
                inference.loss_and_grads(
                    model, guide, subsample=torch.tensor([0], dtype=torch.long)
                )
                inference.loss_and_grads(
                    model, guide, subsample=torch.tensor([1], dtype=torch.long)
                )
            else:
                inference.loss_and_grads(
                    model, guide, subsample=torch.tensor([0, 1], dtype=torch.long)
                )
        params = dict(pyro.get_param_store().named_parameters())
        normalizer = 2 if subsample else 1
        actual_grads = {
            name: param.grad.detach().cpu().numpy() / normalizer
            for name, param in params.items()
        }

        expected_grads = {
            "loc": scale * np.array([0.5, -2.0]),
            "scale": scale * np.array([2.0]),
        }
        for name in sorted(params):
            logger.info("expected {} = {}".format(name, expected_grads[name]))
            logger.info("actual   {} = {}".format(name, actual_grads[name]))
>       assert_equal(actual_grads, expected_grads, prec=precision)

tests/infer/test_gradient.py:214:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/common.py:263: in assert_equal
    return assert_close(actual, expected, atol=prec, msg=msg)
tests/common.py:243: in assert_close
    assert_close(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

actual = array([7.8726236]), expected = array([4.]), atol = 0.12, rtol = 0, msg = "At key 'scale': [7.8726236] vs [4.]"

    def assert_close(actual, expected, atol=1e-7, rtol=0, msg=""):
        if not msg:
            msg = "{} vs {}".format(actual, expected)
        if isinstance(actual, numbers.Number) and isinstance(expected, numbers.Number):
            assert actual == approx(expected, abs=atol, rel=rtol), msg
        # Placing this as a second check allows for coercing of numeric types above;
        # this can be moved up to harden type checks.
        elif type(actual) != type(expected):
            raise AssertionError(
                "cannot compare {} and {}".format(type(actual), type(expected))
            )
        elif torch.is_tensor(actual) and torch.is_tensor(expected):
            prec = atol + rtol * abs(expected) if rtol > 0 else atol
            assert actual.is_sparse == expected.is_sparse, msg
            if actual.is_sparse:
                x = _safe_coalesce(actual)
                y = _safe_coalesce(expected)
                assert_tensors_equal(x._indices(), y._indices(), prec, msg)
                assert_tensors_equal(x._values(), y._values(), prec, msg)
            else:
                assert_tensors_equal(actual, expected, prec, msg)
        elif type(actual) == np.ndarray and type(expected) == np.ndarray:
>           assert_allclose(
                actual, expected, atol=atol, rtol=rtol, equal_nan=True, err_msg=msg
E               AssertionError:
E               Not equal to tolerance rtol=0, atol=0.12
E               At key 'scale': [7.8726236] vs [4.]
E               Mismatched elements: 1 / 1 (100%)
E               Max absolute difference: 3.8726236
E               Max relative difference: 0.9681559
E                x: array([7.872624])
E                y: array([4.])

tests/common.py:235: AssertionError
------------------------------------------------------- Captured log call --------------------------------------------------------
INFO     tests.infer.test_gradient:test_gradient.py:212 expected loc = [ 1. -4.]
INFO     tests.infer.test_gradient:test_gradient.py:213 actual   loc = [ 1.93714977 -7.87144379]
INFO     tests.infer.test_gradient:test_gradient.py:212 expected scale = [4.]
INFO     tests.infer.test_gradient:test_gradient.py:213 actual   scale = [7.8726236]
==================================================== short test summary info =====================================================
FAILED tests/infer/test_gradient.py::test_subsample_gradient[Trace_ELBO-False-full-reparam-False-scaled] - AssertionError:
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
================================================== 1 failed, 11 passed in 0.25s ==================================================

@gui11aume

Thanks @fritzo! That's so useful for understanding how the code works. I'll start from there and think of ways to address the issues without breaking everything. By the way, I am reading your work from the past few years and I am so impressed. I don't say this very often, but you are an example to look up to.


gui11aume commented Aug 11, 2023

I think I understand why the score_function terms are not scaled. Let me know if this is correct, @fritzo.

What you call log_r in the code must refer to $\log p(x,z) - \log q_\theta(z)$, i.e., the integrand of the ELBO as in this article. With the REINFORCE estimator, the target is to compute $\nabla_\theta \log q_\theta(z) \cdot \log r + \nabla_\theta \log r$, which is done by differentiating the surrogate loss $\log q_\theta(z) \cdot \overline{\log r} + \log r$, where the horizontal bar means that the term is treated as a constant. The score_function terms store the values of $\log q_\theta(z)$, and $\overline{\log r}$ is computed by the function _compute_log_r (in the file infer/trace_elbo.py). The term $\log r$ is computed beforehand by subtracting the $\log q_\theta(z)$ terms (i.e., log_prob in the guide sites) from the $\log p(x,z)$ terms (i.e., log_prob in the model sites).
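
Spelling out the reasoning step behind that target: the surrogate is the single-sample version of the score-function (REINFORCE) identity

$$\nabla_\theta\, \mathbb{E}_{q_\theta}[\log r] = \mathbb{E}_{q_\theta}\big[\nabla_\theta \log q_\theta(z) \cdot \log r + \nabla_\theta \log r\big], \qquad \log r = \log p(x,z) - \log q_\theta(z),$$

and differentiating $\log q_\theta(z) \cdot \overline{\log r} + \log r$ for a single sample $z \sim q_\theta$ reproduces exactly the bracketed term.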

Loosely speaking, if we estimate the gradient of the ELBO with only half of the terms, then we have to multiply $\log r$ by 2 to keep the expected value unchanged. But we should not multiply $\log q_\theta(z)$ by 2, otherwise we would estimate $4 \nabla_\theta \log q_\theta(z) \cdot \log r + 2 \nabla_\theta \log r$. So, internally, you multiply the log_prob terms by 2 (in the model and in the guide), which updates $\log r$ but leaves $\log q_\theta(z)$ as is.
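
Written out with the surrogate loss above and a multiplicity factor of 2:

$$\nabla_\theta\big[\,2\log q_\theta(z) \cdot \overline{2\log r} + 2\log r\,\big] = 4\,\nabla_\theta \log q_\theta(z) \cdot \log r + 2\,\nabla_\theta \log r \quad (\text{wrong}),$$

$$\nabla_\theta\big[\,\log q_\theta(z) \cdot \overline{2\log r} + 2\log r\,\big] = 2\,\big(\nabla_\theta \log q_\theta(z) \cdot \log r + \nabla_\theta \log r\big) \quad (\text{twice the target, as intended}).$$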

As far as I understand, scaling and masking are processed together in scale_and_mask, but I think they obey different rules. I'll clarify this in a separate comment. For now I just want to mention that the tests seem to pass when you mask the score_function terms but do not scale them.

-        score_function = self.score_function  # not scaled
+        score_function = scale_and_mask(self.score_function, 1.0, mask) # not scaled

I believe that this is because there are no tests to check the gradient when masking is active. I'll see if I can write some for this case, along the lines of those in test_subsample_gradient.


fritzo commented Aug 17, 2023

@gui11aume yes, your explanation sounds right. I'm sorry I don't recall why masking behaves differently from scaling; I vaguely recall there was a reason, but I forget whether that reason was due to deep mathematics or merely incidental complexity in our implementation. Additional tests would be great!

@gui11aume

Thanks for confirming, @fritzo. Below is a very long post; I don't expect anyone to read it. It is mostly here for reference, to keep track of my rationale.

Issue 1: Code failure when masking

After some thinking, my opinion is that it should be allowed to have sites in the model but not in the guide (the bad_sites), because some cases are fully legitimate. To build one, we start from a typical example with a global Gaussian parameter and local Gaussian observations (no missing observations or masking for now).

import pyro
import pyro.distributions as dist
import torch

def model(data):
    z = pyro.sample("z", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(z, 1), obs=data)
    return

def guide(data):
    loc = pyro.param("loc", lambda: torch.tensor([0.]))
    scale = pyro.param("scale", lambda: torch.tensor([1.]))
    z_dist = dist.Normal(loc, scale)
    pyro.sample("z", z_dist)

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
for step in range(2000):
   svi.step(torch.tensor([2.]))
print(pyro.param("loc"), pyro.param("scale"))

# tensor([0.9798], requires_grad=True) tensor([0.7382], requires_grad=True)

The inference is correct: in this case the posterior distribution is $N(\mu=x/2, \sigma=1/\sqrt{2})$. Now, say that some observations are missing. We just need to add an obs_mask field to the sample x.
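
(As an aside, this posterior is the standard conjugate-normal update for a $N(0, 1)$ prior and a single observation $x \sim N(z, 1)$:

$$p(z \mid x) \propto N(z; 0, 1)\, N(x; z, 1) \propto N\!\left(z;\ \frac{x}{2},\ \frac{1}{\sqrt{2}}\right),$$

so with $x = 2$ the posterior is $N(1, 0.707)$, matching the fitted loc $\approx 0.98$ and scale $\approx 0.74$.)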

import pyro
import pyro.distributions as dist
import torch

def model(data, mask):
    z = pyro.sample("z", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(z, 1), obs=data, obs_mask=mask)
    return

def guide(data, mask):
    loc = pyro.param("loc", lambda: torch.tensor([0.]))
    scale = pyro.param("scale", lambda: torch.tensor([1.]))
    z_dist = dist.Normal(loc, scale)
    pyro.sample("z", z_dist)

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
for step in range(2000):
   svi.step(torch.tensor([2., 0.]), torch.tensor([True, False]))
print(pyro.param("loc"), pyro.param("scale"))

# (...)/pyro/util.py:303: UserWarning: Found vars in model but not guide: {'x_unobserved'}
# tensor([0.9631], requires_grad=True) tensor([0.7807], requires_grad=True)

We get a warning but the inference is correct. In the first model, we have only one observation (x = 2.0). In the second model, we have two observations, but only the first one is known (still x = 2.0). The two cases are indistinguishable. Notice how, with masking, some log-likelihoods have to be removed, which shifts the expected ELBO. This is an important difference from scaling.

Moving on, if a masked variable has no rsample (it cannot be reparametrized to obtain pathwise derivatives), then Pyro crashes. Below we artificially deactivate has_rsample for the Gaussian to force Pyro to use the REINFORCE estimator.

import pyro
import pyro.distributions as dist
import torch

def model(data, mask):
    z = pyro.sample("z", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(z, 1), obs=data, obs_mask=mask)
    return

def guide(data, mask):
    loc = pyro.param("loc", lambda: torch.tensor([0.]))
    scale = pyro.param("scale", lambda: torch.tensor([1.]))
    z_dist = dist.Normal(loc, scale)
    z_dist.has_rsample = False # <== Pretend we cannot use the reparametrization trick.
    pyro.sample("z", z_dist)

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
for step in range(2000):
   svi.step(torch.tensor([2., 0.]), torch.tensor([True, False]))
print(pyro.param("loc"), pyro.param("scale"))

# (...)/pyro/util.py:303: UserWarning: Found vars in model but not guide: {'x_unobserved'}
#   warnings.warn(f"Found vars in model but not guide: {bad_sites}")
# Traceback (most recent call last):
#   File "tmp3.py", line 20, in <module>
#     svi.step(torch.tensor([2., 0.]), torch.tensor([True, False]))
#   File "(...)/pyro/infer/svi.py", line 145, in step
#     loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
#   File "(...)/pyro/infer/trace_elbo.py", line 141, in loss_and_grads
#     loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
#   File "(...)/pyro/infer/trace_elbo.py", line 106, in _differentiable_loss_particle
#     log_r = _compute_log_r(model_trace, guide_trace)
#   File "(...)/pyro/infer/trace_elbo.py", line 27, in _compute_log_r
#     log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]
# KeyError: 'x_unobserved'

The code should behave the same as before; in theory, masking has nothing to do with rsample. I found three places where the code can fail: pyro/infer/trace_elbo.py:27, pyro/infer/trace_mean_field_elbo.py:112 and pyro/infer/tracegraph_elbo.py:221. In all three cases, I think the solution is to do nothing when a site of the model is not found in the guide (tests are on the way).

# https://github.com/pyro-ppl/pyro/blob/0e82cad30f75b892a07e6c9a5f9e24f2cb5d0d81/pyro/infer/trace_elbo.py#L20C1-L29C17
def _compute_log_r(model_trace, guide_trace):
    log_r = MultiFrameTensor()
    stacks = get_plate_stacks(model_trace)
    for name, model_site in model_trace.nodes.items():
        if model_site["type"] == "sample":
            log_r_term = model_site["log_prob"]
            if not model_site["is_observed"]:
                log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"] # <== This can fail.
            log_r.add((stacks[name], log_r_term.detach()))
    return log_r
# https://github.com/pyro-ppl/pyro/blob/0e82cad30f75b892a07e6c9a5f9e24f2cb5d0d81/pyro/infer/trace_mean_field_elbo.py#L107C1-L114C63
        for name, model_site in model_trace.nodes.items():
            if model_site["type"] == "sample":
                if model_site["is_observed"]:
                    elbo_particle = elbo_particle + model_site["log_prob_sum"]
                else:
                    guide_site = guide_trace.nodes[name] # <== This can fail.
                    if is_validation_enabled():
                        check_fully_reparametrized(guide_site)
# https://github.com/pyro-ppl/pyro/blob/0e82cad30f75b892a07e6c9a5f9e24f2cb5d0d81/pyro/infer/tracegraph_elbo.py#L217C1-L223C66
    # construct all the reinforce-like terms.
    # we include only downstream costs to reduce variance
    # optionally include baselines to further reduce variance
    for node, downstream_cost in downstream_costs.items():
        guide_site = guide_trace.nodes[node] # <== This can fail.
        downstream_cost = downstream_cost.sum_to(guide_site["cond_indep_stack"])
        score_function = guide_site["score_parts"].score_function
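
For trace_elbo.py, a minimal sketch of what I mean (just the idea, not an actual patch): skip the guide correction when the model site has no counterpart in the guide.

def _compute_log_r(model_trace, guide_trace):
    log_r = MultiFrameTensor()
    stacks = get_plate_stacks(model_trace)
    for name, model_site in model_trace.nodes.items():
        if model_site["type"] == "sample":
            log_r_term = model_site["log_prob"]
            # Only subtract the guide term when the site exists in the guide;
            # auto-created `_unobserved` sites may be missing from it.
            if not model_site["is_observed"] and name in guide_trace.nodes:
                log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]
            log_r.add((stacks[name], log_r_term.detach()))
    return log_r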

Issue 2: Counter-intuitive gradient

Let us go back to the second case, where the reparametrization trick is available, and let us try to infer the distribution of the missing value of x. As we have seen, we need to add to the guide a sample called x_unobserved. We also wrap it in a poutine.mask whose mask is the mirror image of the observation mask, because we have to sample values for the observed entries even if we do not need them. This way, the dummy samples have no influence at all.

import pyro
import pyro.distributions as dist
import torch

def model(data, mask):
    z = pyro.sample("z", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(z, 1), obs=data, obs_mask=mask)
    return

def guide(data, mask):
    loc = pyro.param("loc", lambda: torch.tensor([0.]))
    scale = pyro.param("scale", lambda: torch.tensor([1.]))
    z_dist = dist.Normal(loc, scale)
    pyro.sample("z", z_dist)
    with pyro.plate("data", len(data)):
        loc_x = pyro.param("loc_x", lambda: torch.tensor([0., 0.]))
        scale_x = pyro.param("scale_x", lambda: torch.tensor([1., 1.]))
        with pyro.poutine.mask(mask=~mask):
            pyro.sample("x_unobserved", dist.Normal(loc_x, scale_x)) 

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
for step in range(2000):
   svi.step(torch.tensor([2., 0.]), torch.tensor([True, False]))
print(pyro.param("loc_x"), pyro.param("scale_x"))

# tensor([0.0000, 0.9787], requires_grad=True) tensor([1.0000, 1.0632], requires_grad=True)

As expected, the warning is gone. We also see that the parameters for the observed values are exactly as they were initialized, meaning that for these values the gradient was 0 throughout, as expected. This is therefore the behavior that should be achieved when the variables have no rsample.

This can be done with the update mentioned above, where the terms of score_function are masked but not scaled.

-        score_function = self.score_function  # not scaled
+        score_function = scale_and_mask(self.score_function, 1.0, mask) # not scaled

I have written some new tests and I will open a pull request draft shortly.

gui11aume added a commit to gui11aume/pyro that referenced this issue Aug 30, 2023
@gui11aume

I have opened pull request #3265 with the changes discussed above and some new tests involving gradients with poutine.mask.
