Got a runtime error when using HMC / MCMC together with sequential enumeration #3343

Open
ljlin opened this issue Mar 19, 2024 · 0 comments

ljlin commented Mar 19, 2024

I am trying to implement the CRBD (constant-rate birth-death) model described in this paper, which contains both continuous and discrete random variables, and to apply HMC to it.

However, I get the following runtime error:

ValueError: Continuous inference cannot handle discrete sample site './isSpeciation'. Consider enumerating that variable as documented in https://pyro.ai/examples/enumeration.html . If you are already enumerating, take care to hide this site when constructing an autoguide, e.g. guide = AutoNormal(poutine.block(model, hide=['./isSpeciation'])).
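As the error message suggests, enumerating a discrete site means summing the model's probability over every value of that site instead of sampling it. For the Bernoulli `isSpeciation` step this is a two-term sum: with probability `la / (la + mu)` the event is a speciation and both daughter lineages must go extinct, and otherwise the lineage dies at that event. The following is a minimal plain-Python sketch of that sum in log space, not Pyro's implementation; the helper names and the daughters' extinction log-probabilities passed in as arguments are hypothetical, and it assumes `la > 0` and `mu > 0`.

```python
import math


def logsumexp(log_terms):
    """Numerically stable log(sum(exp(t) for t in log_terms))."""
    m = max(log_terms)
    if m == -math.inf:
        return -math.inf  # every term has probability 0
    return m + math.log(sum(math.exp(t - m) for t in log_terms))


def log_extinct_after_event(la, mu, log_px, log_py):
    """Log P(lineage goes extinct | an event occurred), with isSpeciation summed out.

    Hypothetical helper for illustration (not part of Pyro): log_px and log_py
    stand in for the daughters' extinction log-probabilities.
    Assumes la > 0 and mu > 0.
    """
    log_p_speciation = math.log(la / (la + mu))
    log_p_extinction = math.log(mu / (la + mu))
    return logsumexp([
        # isSpeciation = 1: extinct only if both daughters go extinct.
        log_p_speciation + log_px + log_py,
        # isSpeciation = 0: the lineage dies here, i.e. extinct with probability 1 (log 1 = 0).
        log_p_extinction + 0.0,
    ])


# la = mu = 1 and each daughter extinct with probability 0.5:
# 0.5 * (0.5 * 0.5) + 0.5 * 1 = 0.625
p = math.exp(log_extinct_after_event(1.0, 1.0, math.log(0.5), math.log(0.5)))
print(round(p, 6))  # 0.625
```

Pyro's enumeration machinery automates this kind of sum for each enumerated site; here it is written out by hand for a single site only.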

Am I using Pyro's HMC incorrectly, or is Pyro's HMC not compatible with {'enumerate': 'sequential'}?

What should I do to apply HMC together with {'enumerate': 'sequential'} to this CRBD model?

Thanks for helping.

Code:

```python
import argparse
import sys

import pyro
import pyro.distributions as dist
import torch
from pyro.infer import MCMC, NUTS

sys.setrecursionlimit(32767)


def gosExtince(prefix, time, la, mu):
    """Simulate one lineage for `time` time units; return True if it goes extinct."""
    # In a constant-rate birth-death process the waiting time to the next
    # event (speciation or extinction) is exponential with total rate la + mu.
    waitingTime = pyro.sample(f"{prefix}/waitingTime", dist.Exponential(la + mu))
    if waitingTime > time:
        # No event before the present: the lineage survives.
        b_waitingTime = False
    else:
        # The event is a speciation with probability la / (la + mu), otherwise an extinction.
        isSpeciation = pyro.sample(
            f"{prefix}/isSpeciation",
            dist.Bernoulli(la / (la + mu)),
            infer={'enumerate': 'sequential'},
        )
        # Raises: ValueError: Continuous inference cannot handle discrete sample site './isSpeciation'.
        if isSpeciation:  # https://pyro.ai/examples/enumeration.html
            # Speciation: the lineage goes extinct only if both daughters do.
            x = gosExtince(f"{prefix}/x", time - waitingTime, la, mu)
            y = gosExtince(f"{prefix}/y", time - waitingTime, la, mu)
            b_isSpeciation = x and y
        else:
            # Extinction event.
            b_isSpeciation = True
        b_waitingTime = b_isSpeciation
    return b_waitingTime


def model(time):
    la = pyro.sample("lamda", dist.Gamma(1.0, 1.0))  # speciation rate
    mu = pyro.sample("mu", dist.Gamma(1.0, 1.0))     # extinction rate
    obs = gosExtince(".", time, la, mu)
    # Hard constraint: log-weight 0 if the lineage went extinct, -inf otherwise.
    pyro.factor("obs", torch.tensor(0.0) if obs else torch.tensor(-float("inf")))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', choices=["HMC", "MAPPL"], required=True)
    parser.add_argument('--time', type=float, required=True)
    parser.add_argument('--progress_bar', action='store_true')
    parser.add_argument('--num_chains', type=int, required=True)
    parser.add_argument('--warmup_steps', type=int, required=True)
    parser.add_argument('--num_samples', type=int, required=True)
    args = parser.parse_args()
    print(args)

    if args.config == "HMC":
        nuts_kernel = NUTS(model)
    else:
        # Only the HMC configuration is implemented in this script.
        parser.error(f"--config {args.config} is not supported by this script")

    mcmc = MCMC(
        nuts_kernel,
        warmup_steps=args.warmup_steps,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        disable_progbar=not args.progress_bar,
    )
    mcmc.run(args.time)
    mcmc.print_summary()


if __name__ == '__main__':
    main()
```
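To check the recursive extinction logic independently of Pyro, the process can also be forward-simulated in plain Python. In a constant-rate birth-death process, the waiting time to the next event is exponential with total rate `la + mu`, and that event is a speciation with probability `la / (la + mu)`. The sketch below uses two hypothetical helpers: `goes_extinct` simulates the recursion and records the name of every sample site it would create (following the `f"{prefix}/..."` naming scheme above), and `extinction_prob` is the standard closed-form extinction probability of a linear birth-death process, used here only as a reference value for the Monte Carlo estimate.

```python
import math
import random


def goes_extinct(prefix, time, la, mu, rng, sites):
    """Forward-simulate one lineage for `time` time units.

    Returns True if the lineage dies out before the present. Each sample site
    the Pyro model would create is appended to `sites` (hypothetical
    bookkeeping, for illustration only).
    """
    sites.append(f"{prefix}/waitingTime")
    waiting_time = rng.expovariate(la + mu)  # total event rate is la + mu
    if waiting_time > time:
        return False  # no event before the present: the lineage survives
    sites.append(f"{prefix}/isSpeciation")
    if rng.random() < la / (la + mu):
        # Speciation: simulate both daughters (as the Pyro model does) and
        # report extinction only if both of them go extinct.
        remaining = time - waiting_time
        x = goes_extinct(f"{prefix}/x", remaining, la, mu, rng, sites)
        y = goes_extinct(f"{prefix}/y", remaining, la, mu, rng, sites)
        return x and y
    return True  # extinction event


def extinction_prob(la, mu, t):
    """Closed-form P(extinct by time t) for a linear birth-death process."""
    if math.isclose(la, mu):
        return la * t / (1.0 + la * t)
    e = math.exp(-(la - mu) * t)
    return mu * (1.0 - e) / (la - mu * e)


rng = random.Random(0)
la, mu, t, n = 1.0, 0.5, 2.0, 20000
extinct = 0
structures = set()
for _ in range(n):
    sites = []
    if goes_extinct(".", t, la, mu, rng, sites):
        extinct += 1
    structures.add(tuple(sites))

print(f"Monte Carlo estimate: {extinct / n:.3f}")
print(f"closed form:          {extinction_prob(la, mu, t):.3f}")
print(f"distinct sample-site structures: {len(structures)}")
```

The last printed number shows that the number of `waitingTime` and `isSpeciation` sites differs from run to run, whereas HMC moves through a continuous latent space of fixed dimension.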