BUG: pymc.sample_smc fails with pymc.CustomDist #7224

Open
EliasRas opened this issue Mar 27, 2024 · 2 comments · May be fixed by #7241
Labels
bug SMC Sequential Monte Carlo

Comments

@EliasRas

EliasRas commented Mar 27, 2024

Describe the issue:

pymc.sample_smc raises a NotImplementedError due to a missing logp method when a pymc.CustomDist defined without the dist argument is used in a model. Passing dist instead, or switching to pm.Potential, avoids the error.

Reproducible code example:

import pymc as pm
import numpy as np


def _logp(value, mu, sigma):
    dist = pm.Normal.dist(mu=mu, sigma=sigma)

    return pm.logp(dist, value)


def _random(mu, sigma, rng, size):
    if rng is None:
        rng = np.random.default_rng()
    sample = rng.normal(loc=mu, scale=sigma, size=size)

    return sample


def _logcdf(value, mu, sigma):
    dist = pm.Normal.dist(mu=mu, sigma=sigma)

    return pm.logcdf(dist, value)


def _dist(mu, sigma, size):
    return pm.Normal.dist(mu, sigma, size=size)


def main():
    data = np.random.default_rng().normal(5, 2, 1000)

    with pm.Model():
        mu = pm.Normal("mu", mu=0, sigma=10)
        sigma = pm.HalfNormal("sigma", sigma=10)
        pm.CustomDist(
            "y",
            mu,
            sigma,
            logp=_logp,
            random=_random,
            logcdf=_logcdf,
            observed=data,
        )
        sample = pm.sample_smc()  # NotImplementedError

    with pm.Model():
        pm.CustomDist(
            "y",
            2,
            10,
            logp=_logp,
            random=_random,
            logcdf=_logcdf,
        )
        sample = pm.sample_smc()  # NotImplementedError

    with pm.Model():
        mu = pm.Normal("mu", mu=0, sigma=10)
        sigma = pm.HalfNormal("sigma", sigma=10)
        pm.CustomDist(
            "y",
            mu,
            sigma,
            dist=_dist,
            observed=data,
        )
        sample = pm.sample_smc()  # Works

    with pm.Model():
        mu = pm.Normal("mu", mu=0, sigma=10)
        sigma = pm.HalfNormal("sigma", sigma=10)
        pm.Potential(
            "y",
            _logp(data, mu, sigma),
        )
        sample = pm.sample_smc()  # Works


if __name__ == "__main__":
    main()

Error message:

multiprocessing.pool.RemoteTraceback: 
"""
Traceback (most recent call last):
  File "\envs\pymc\Lib\multiprocessing\pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
                    ^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\multiprocessing\pool.py", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 421, in _apply_args_and_kwargs
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 344, in _sample_smc_int
    smc._initialize_kernel()
  File "\envs\pymc\Lib\site-packages\pymc\smc\kernels.py", line 239, in _initialize_kernel
    initial_point, [self.model.varlogp], self.variables, shared
                    ^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\model\core.py", line 832, in varlogp
    return self.logp(vars=self.free_RVs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\model\core.py", line 717, in logp
    rv_logps = transformed_conditional_logp(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\logprob\basic.py", line 612, in transformed_conditional_logp
    temp_logp_terms = conditional_logp(
                      ^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\logprob\basic.py", line 542, in conditional_logp
    q_logprob_vars = _logprob(
                     ^^^^^^^^^
  File "\envs\pymc\Lib\functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\logprob\abstract.py", line 63, in _logprob
    raise NotImplementedError(f"Logprob method not implemented for {op}")
NotImplementedError: Logprob method not implemented for CustomDist_y_rv{0, (0, 0), floatX, False}
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "issue.py", line 79, in <module>
    main()
  File "issue.py", line 44, in main
    sample = pm.sample_smc()  # Exception has occurred: NotImplementedError
             ^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 213, in sample_smc
    results = run_chains_parallel(
              ^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 390, in run_chains_parallel
    results = _starmap_with_kwargs(
              ^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 417, in _starmap_with_kwargs
    return pool.starmap(_apply_args_and_kwargs, args_for_starmap)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\multiprocessing\pool.py", line 375, in starmap
    return self._map_async(func, iterable, starmapstar, chunksize).get()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\multiprocessing\pool.py", line 774, in get
    raise self._value
NotImplementedError: Logprob method not implemented for CustomDist_y_rv{0, (0, 0), floatX, False}
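The last frames of the traceback show the mechanism: `_logprob` in `pymc/logprob/abstract.py` is dispatched through `functools` on the class of the Op, and its default implementation raises `NotImplementedError` when nothing is registered for that class. The pure-Python sketch below mimics that pattern. It is not PyMC's code: `Op`, `NormalOp`, `CustomDistOp` and `logprob` are hypothetical stand-ins, and `CustomDistOp` is created with `type()` at runtime only to loosely mirror the dynamically named `CustomDist_y_rv` Op in the error.

```python
from functools import singledispatch


class Op:
    """Hypothetical stand-in for a PyTensor Op."""

    def __repr__(self):
        return type(self).__name__


@singledispatch
def logprob(op, value):
    # Fallback used when no implementation is registered for type(op).
    raise NotImplementedError(f"Logprob method not implemented for {op}")


class NormalOp(Op):
    pass


@logprob.register(NormalOp)
def _normal_logprob(op, value):
    # Placeholder density: standard-normal log-density up to a constant.
    return -0.5 * value**2


# A class created at runtime; nothing has been registered for it.
CustomDistOp = type("CustomDistOp", (Op,), {})

print(logprob(NormalOp(), 2.0))  # -2.0
try:
    logprob(CustomDistOp(), 2.0)
except NotImplementedError as err:
    print(err)  # Logprob method not implemented for CustomDistOp
```

Registering an implementation for `CustomDistOp` (for example with `logprob.register(CustomDistOp, some_function)`) is what makes the same call succeed; the open question in this issue is where that registration is visible.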

PyMC version information:

Python 3.11.7
pymc 5.10.0
pytensor 2.18.6
Win 10
Environment set up via conda, but pymc and pytensor were then updated with pip

Context for the issue:

I'm testing a model that suffers from slow sampling, possibly due to expensive gradient calculations. I tried SMC as a possible solution, as suggested on the forums, but got this error message.

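Using the dist argument would work in most cases, but there are cases where the distributions provided by pymc are not enough. Using pm.Potential could make sampling work, but since a Potential is not a random variable, y could no longer be drawn by forward sampling (e.g. pm.sample_posterior_predictive), which makes predictions less straightforward.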

@EliasRas EliasRas added the bug label Mar 27, 2024

welcome bot commented Mar 27, 2024

🎉 Welcome to PyMC! 🎉 We're really excited to have your input into the project! 💖

If you haven't done so already, please make sure you check out our Contributing Guidelines and Code of Conduct.

@EliasRas
Author

EliasRas commented Apr 5, 2024

I tested pm.sample_smc(cores=1) and got no error, which made me dig a bit deeper. If I understood correctly, the error with multiple processes happens because methods such as logp are registered only in the main process. Would it be possible to add an initializer to the pool used in pm.smc.sampling.run_chains_parallel that ensures the methods are registered properly in every worker?
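A minimal sketch of the hypothesis above, simulated in a single process. On Windows, multiprocessing always uses the spawn start method, so each worker begins from a fresh interpreter and does not inherit registrations the parent made while building the model. The sketch assumes the worker ends up with an equivalent but distinct Op class object, which is an assumption about how the model is deserialized, not a description of PyMC internals. `logprob`, `make_custom_op_type`, `custom_logp` and `init_worker` are all hypothetical names; only `functools.singledispatch` and the `initializer`/`initargs` parameters of `multiprocessing.Pool` are real APIs.

```python
from functools import singledispatch


@singledispatch
def logprob(op, value):
    # Fallback, analogous to the NotImplementedError in the traceback.
    raise NotImplementedError(
        f"Logprob method not implemented for {type(op).__name__}"
    )


def make_custom_op_type(name):
    # Hypothetical stand-in for building a CustomDist Op class at runtime.
    return type(name, (object,), {})


def custom_logp(op, value):
    # Placeholder for the user-supplied logp (standard normal, up to a constant).
    return -0.5 * value**2


# Main process: building the distribution defines a class and registers it.
main_op_type = make_custom_op_type("CustomDist_y_rv")
logprob.register(main_op_type, custom_logp)

# Simulated worker: an equivalent but distinct class the dispatcher never saw.
worker_op_type = make_custom_op_type("CustomDist_y_rv")
assert worker_op_type is not main_op_type

try:
    logprob(worker_op_type(), 2.0)
except NotImplementedError as err:
    print("before init:", err)


def init_worker(op_type, logp_fn):
    # Hypothetical initializer; multiprocessing.Pool(initializer=init_worker,
    # initargs=(...)) would call a function like this once per worker process.
    logprob.register(op_type, logp_fn)


init_worker(worker_op_type, custom_logp)
print("after init:", logprob(worker_op_type(), 2.0))  # after init: -2.0
```

One caveat for a real initializer: it has to register the same class object that the worker later obtains when it deserializes the model, so it would need to be coordinated with how run_chains_parallel serializes its arguments. Running with cores=1 sidesteps the issue because everything stays in the main process.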

@EliasRas EliasRas linked a pull request Apr 6, 2024 that will close this issue
@ricardoV94 ricardoV94 added the SMC Sequential Monte Carlo label Apr 9, 2024