Describe the issue:
pymc.sample_smc raises a NotImplementedError (missing logp method) when a model contains a pymc.CustomDist defined without the dist argument. Passing dist, or replacing the CustomDist with pm.Potential, avoids the error.
Error message:
multiprocessing.pool.RemoteTraceback:
"""
Traceback (most recent call last):
  File "\envs\pymc\Lib\multiprocessing\pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "\envs\pymc\Lib\multiprocessing\pool.py", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 421, in _apply_args_and_kwargs
    return fn(*args, **kwargs)
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 344, in _sample_smc_int
    smc._initialize_kernel()
  File "\envs\pymc\Lib\site-packages\pymc\smc\kernels.py", line 239, in _initialize_kernel
    initial_point, [self.model.varlogp], self.variables, shared
  File "\envs\pymc\Lib\site-packages\pymc\model\core.py", line 832, in varlogp
    return self.logp(vars=self.free_RVs)
  File "\envs\pymc\Lib\site-packages\pymc\model\core.py", line 717, in logp
    rv_logps = transformed_conditional_logp(
  File "\envs\pymc\Lib\site-packages\pymc\logprob\basic.py", line 612, in transformed_conditional_logp
    temp_logp_terms = conditional_logp(
  File "\envs\pymc\Lib\site-packages\pymc\logprob\basic.py", line 542, in conditional_logp
    q_logprob_vars = _logprob(
  File "\envs\pymc\Lib\functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "\envs\pymc\Lib\site-packages\pymc\logprob\abstract.py", line 63, in _logprob
    raise NotImplementedError(f"Logprob method not implemented for {op}")
NotImplementedError: Logprob method not implemented for CustomDist_y_rv{0, (0, 0), floatX, False}
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "issue.py", line 79, in <module>
    main()
  File "issue.py", line 44, in main
    sample = pm.sample_smc() # Exception has occurred: NotImplementedError
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 213, in sample_smc
    results = run_chains_parallel(
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 390, in run_chains_parallel
    results = _starmap_with_kwargs(
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 417, in _starmap_with_kwargs
    return pool.starmap(_apply_args_and_kwargs, args_for_starmap)
  File "\envs\pymc\Lib\multiprocessing\pool.py", line 375, in starmap
    return self._map_async(func, iterable, starmapstar, chunksize).get()
  File "\envs\pymc\Lib\multiprocessing\pool.py", line 774, in get
    raise self._value
NotImplementedError: Logprob method not implemented for CustomDist_y_rv{0, (0, 0), floatX, False}
PyMC version information:
Python 3.11.7
pymc 5.10.0
pytensor 2.18.6
Win 10
Environment set up via conda but updated pymc and pytensor with pip
Context for the issue:
I'm testing a model that suffers from slow sampling, possibly due to expensive gradient calculations. As suggested on the forums, I tried SMC as a possible solution, but got this error message.
Using the dist argument could work in most cases, but there are cases where the distributions provided by pymc are not enough. Using pm.Potential would work for sampling, but it would in turn make forward sampling less straightforward.
I tested pm.sample_smc(cores=1) and got no error, which made me dig a bit deeper. If I understand correctly, the error only occurs with multiple processes because methods such as logp get registered only in the main process, not in the worker processes. Would it be possible to add an initializer to the pool used in pm.smc.sampling.run_chains_parallel that ensures these methods are registered properly in every worker?
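The mechanism suspected in this comment can be reproduced without PyMC. The traceback shows that _logprob is dispatched through functools (singledispatch), so the sketch below assumes a CustomDist without dist behaves like a type whose logprob is registered at runtime in the parent process only. Everything here is a hypothetical stand-in, not PyMC code: logprob, CustomOp, register_custom_logprob, evaluate and run_parallel are illustrative names. It uses the "spawn" start method (the Windows default), under which workers re-import the module and never see the parent's runtime registration unless a Pool initializer repeats it.

```python
import functools
import multiprocessing as mp


# Hypothetical stand-in for a singledispatch-based dispatcher such as the
# _logprob in the traceback: the default implementation raises.
@functools.singledispatch
def logprob(op, value):
    raise NotImplementedError(f"Logprob method not implemented for {op}")


class CustomOp:
    """Hypothetical op type whose logprob is only registered at runtime."""


def register_custom_logprob():
    """Register a logprob for CustomOp in the *current* process only."""

    @logprob.register(CustomOp)
    def _custom_logprob(op, value):
        # Unnormalized standard-normal log density, purely illustrative.
        return -0.5 * value * value


def evaluate(value):
    # Worker task: dispatch happens in whichever process runs this function.
    return logprob(CustomOp(), value)


def run_parallel(values, use_initializer=True):
    # "spawn" starts fresh interpreters that re-import this module, so a
    # registration made at runtime in the parent is missing in the workers
    # unless the Pool initializer runs it again in each worker.
    ctx = mp.get_context("spawn")
    initializer = register_custom_logprob if use_initializer else None
    with ctx.Pool(processes=2, initializer=initializer) as pool:
        return pool.map(evaluate, values)
```

When run_parallel is called from a script's if __name__ == "__main__": block after register_custom_logprob(), use_initializer=False should re-raise the workers' NotImplementedError in the parent, mirroring the traceback above, while use_initializer=True should return the values. In PyMC itself the CustomDist op type is created dynamically, so a real initializer would need enough information to redo that registration in each worker; whether run_chains_parallel is the right place for it is the open question raised here.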