Use a distribution as choice map #530

Open
lwang19-ai opened this issue Apr 23, 2024 · 1 comment

Comments

@lwang19-ai
We are using a particle filter in Gen.
We want to use the posterior distribution from one model as a choice map to update another model. If we have 50 particles, we will get 50 observations for the same state at the same time. How should we deal with this in Gen? Is it possible to implement something like stochastic conditioning?
Thanks a lot.

@ztangent
Member

Hi! Great question. Gen.jl doesn't have support for stochastic conditioning, but here's something you could do instead that comes close (but is not mathematically identical). I'm going to assume that your first model is "smaller" than your second model, in the sense that all random variables in the first model also exist in the second model. For example:

Model 1: $y \sim p_1(y), z \sim p_1(z | y)$, $z$ is observed.
Model 2: $x \sim p_2(x), y \sim p_2(y | x), z \sim p_2(z | x, y)$, $z$ is observed.

Now let's say you use a particle filter to approximate the posterior $p_1(y | z)$ as a weighted particle collection $[(y_i, w_i)]_{i=1}^M$. If you sample $y$ from this weighted particle collection by drawing particles according to their weights (using e.g. Gen.sample_unweighted_traces), you can view this as sampling $y$ from a distribution $q(y | z)$ that approximates $p_1(y | z)$.
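Concretely, this "particle filter as a sampler" idea can be sketched with Gen.jl's built-in inference library. This is a minimal sketch, not tested code: `model1`, `model1_args`, the latent address `:y`, the `ChoiceMap` `observations` constraining the $z$ addresses, and the particle count `M` are all illustrative placeholders.

```julia
using Gen

# Run a particle filter on Model 1, constraining the observed z addresses.
state = initialize_particle_filter(model1, model1_args, observations, M)
# (For a multi-step state-space model, interleave calls to
#  particle_filter_step! here, one per time step of observations.)

# Draw a trace in proportion to the particle weights; its :y value can be
# viewed as a sample from q(y | z), an approximation of p1(y | z).
trace = sample_unweighted_traces(state, 1)[1]
y = trace[:y]

# Log marginal likelihood estimate log Ẑ, needed later for density estimates.
log_Z = log_ml_estimate(state)
```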

Furthermore, we can use the log marginal likelihood estimate $\hat Z$ returned by the particle filter (accessible with Gen.log_ml_estimate) to form an unbiased estimate of the density $q(y|z)$ (and also an unbiased estimate of its reciprocal, which is what importance sampling actually requires):

$$q(y | z) = \mathbb{E}\left[\frac{p_1(y, z)}{\hat Z}\right], \qquad \frac{1}{q(y | z)} = \mathbb{E}\left[\frac{\hat Z}{p_1(y, z)}\right]$$

All of this means we can use our particle filter as a proposal distribution for the value of $y$ in Model 2. We can combine this with a different proposal distribution $q(x)$ for the value of $x$ (in the simplest case, we can just use $q(x) = p_2(x)$ ), in order to do importance sampling to approximate $p_2(x, y | z)$:

  1. Propose $x \sim q(x)$
  2. Propose $y \sim q(y | z)$ by running the particle filter given the observations $z$, and then sampling $y$.
  3. Compute the importance weight: $w = \frac{p_2(x, y, z)}{q(x)\, \hat q(y | z)}$, where $\hat q(y | z) = \frac{p_1(y, z)}{\hat Z}$ is the density estimate obtained from that particle filter run.

Do this $N$ times and now you have a weighted collection of $N$ samples that approximates $p_2(x, y | z)$.
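The three steps above can be sketched end-to-end as a modified importance sampler. This is an illustrative, untested sketch: it assumes `model1` and `model2` share the latent address `:y`, that `observations` constrains the observed $z$ addresses of both models, and that `qx` is a generative function proposing `:x` (in the simplest case, Model 2's prior over `:x`). All names are placeholders.

```julia
using Gen

function pf_importance_sampling(model2, model2_args, model1, model1_args,
                                qx, qx_args, observations, M, N)
    traces = Vector{Any}(undef, N)
    log_weights = Vector{Float64}(undef, N)
    for i in 1:N
        # 1. Propose x ~ q(x).
        (x_choices, qx_weight, _) = propose(qx, qx_args)

        # 2. Propose y ~ q(y | z) by running a particle filter on Model 1
        #    and sampling one trace in proportion to the particle weights.
        state = initialize_particle_filter(model1, model1_args, observations, M)
        log_Z = log_ml_estimate(state)
        tr1 = sample_unweighted_traces(state, 1)[1]
        # log q̂(y | z) = log p1(y, z) - log Ẑ  (single-run density estimate).
        log_qy = get_score(tr1) - log_Z
        # Keep only the :y choices from the Model 1 trace.
        y_choices = get_selected(get_choices(tr1), select(:y))

        # 3. Score x, y, z under Model 2 and form the importance weight.
        constraints = merge(observations, merge(x_choices, y_choices))
        (traces[i], model2_weight) = generate(model2, model2_args, constraints)
        # model2_weight = log p2(x, y, z), since all choices are constrained.
        log_weights[i] = model2_weight - qx_weight - log_qy
    end
    log_normalized_weights = log_weights .- logsumexp(log_weights)
    return (traces, log_normalized_weights)
end
```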

I hope this helps! Some of this stuff is easier to implement using a variant of Gen called GenSP.jl (see https://github.com/probcomp/GenSP.jl), but everything I described above is possible to do using the existing functionality provided by Gen.jl. Besides using Gen.log_ml_estimate and Gen.sample_unweighted_traces, you'll also need to use Gen.get_score to evaluate $p_1(y, z)$ for a particular trace from Model 1.

I would recommend checking out the source code of Gen.importance_sampling as a reference:

```julia
function importance_sampling(
        model::GenerativeFunction{T,U},
        model_args::Tuple,
        observations::ChoiceMap,
        proposal::GenerativeFunction,
        proposal_args::Tuple,
        num_samples::Int;
        verbose=false) where {T,U}
    traces = Vector{U}(undef, num_samples)
    log_weights = Vector{Float64}(undef, num_samples)
    for i=1:num_samples
        verbose && println("sample: $i of $num_samples")
        (proposed_choices, proposal_weight, _) = propose(proposal, proposal_args)
        constraints = merge(observations, proposed_choices)
        (traces[i], model_weight) = generate(model, model_args, constraints)
        log_weights[i] = model_weight - proposal_weight
    end
    log_total_weight = logsumexp(log_weights)
    log_ml_estimate = log_total_weight - log(num_samples)
    log_normalized_weights = log_weights .- log_total_weight
    return (traces, log_normalized_weights, log_ml_estimate)
end
```

The main thing you have to modify is how the samples are proposed (by sampling from a particle filter, instead of using Gen.propose), and how both the proposal_weight and the importance log_weight are calculated.
