
BUG: Jax-based samplers crash at transformation stage #6744

Open · fonnesbeck opened this issue May 30, 2023 · 9 comments · May be fixed by #7116
fonnesbeck (Member) commented May 30, 2023:

Describe the issue:

The Jax-based samplers crash after sampling, following the "Transforming variables..." message on medium-to-large models (thousands of rows, hundreds of parameters). This occurs both on GPU and CPU systems, and using either the numpyro or blackjax samplers. The failure on GPU returns a backtrace that isolates the issue at the vmap in _postprocess_samples. On a CPU (MacBook Pro M1), the process is simply killed without any error messages. I have tried running the GPU model with the postprocessing_backend="cpu" argument for the numpyro sampler, but this does not seem to make a difference. Should it be using vmap when the postprocessing backend is CPU?
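
For reference, a minimal sketch of the invocation pattern described above (the model is a trivial placeholder, not the actual failing model; argument names assume the pymc.sampling.jax API as of PyMC 5.3):

import pymc as pm
import pymc.sampling.jax as pmjax

with pm.Model() as model:
    # Trivial placeholder; the failing models have thousands of rows
    # and hundreds of parameters.
    x = pm.Normal("x", 0.0, 1.0)

    # Crashes at "Transforming variables..." on large models, with or
    # without moving post-processing to the CPU.
    idata = pmjax.sample_numpyro_nuts(
        draws=1000,
        tune=1000,
        chains=2,
        postprocessing_backend="cpu",
    )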

Reproducible code example:

Will add an example when I can come up with one.

Error message:

CPU machine error:


Compilation time =  0:00:09.225151
Sampling...
Running chain 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [12:00:33<00:00, 21.62s/it]
Running chain 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [12:00:33<00:00, 21.62s/it]
Sampling time =  12:00:35.215191
Transforming variables...
Killed: 9
/Users/cfonnesbeck/mambaforge/envs/pymc/lib/python3.11/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

PyMC version information:

PyMC 5.3.0
PyTensor 2.11.1

Context for the issue:

The numpyro sampler is currently unusable for moderate-sized models due to this issue.

fonnesbeck added the bug label May 30, 2023
fonnesbeck (Member, Author):

Setting postprocessing_chunks to somewhat large values (~10) seems to prevent this, since it appears to be an issue with vmap.
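
For clarity, a hedged sketch of the workaround being described, assuming the postprocessing_chunks keyword as it existed in the numpyro sampler at the time (placeholder model, not the failing one):

import pymc as pm
import pymc.sampling.jax as pmjax

with pm.Model():
    x = pm.Normal("x", 0.0, 1.0)  # placeholder model
    # Chunking the post-sampling transforms appears to avoid the vmap
    # memory blow-up described above.
    idata = pmjax.sample_numpyro_nuts(postprocessing_chunks=10)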

ricardoV94 (Member):

I think this was solved by switching to scan as the default

fonnesbeck (Member, Author):

I'm still getting out-of-memory crashes after sampling, even when using v5.10. Is it still possible to set postprocessing_chunks? It seemed to work previously.

ricardoV94 (Member):

The options are now scan or vmap; scan is the default, which is more memory-conscious:

postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
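
To illustrate the memory trade-off, a hedged sketch in plain JAX (not PyMC's actual _postprocess_samples implementation): vmap materializes the transform's intermediates for every draw at once, while lax.scan walks the draws one at a time.

import jax
import jax.numpy as jnp

def transform(draw):
    # Stand-in for the deterministic/transformed-variable computation
    # applied to a single posterior draw.
    return jnp.exp(draw)

draws = jnp.ones((4000, 100))  # (n_draws, n_params); placeholder shapes

# vmap: vectorized over the draw axis; intermediates are allocated for all
# draws at once, which is what can exhaust memory on large models.
vmapped = jax.vmap(transform)(draws)

# scan: one draw per step, so only a single draw's intermediates are live
# at a time, at the cost of a sequential loop.
_, scanned = jax.lax.scan(lambda carry, d: (carry, transform(d)), None, draws)

assert jnp.allclose(vmapped, scanned)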

fonnesbeck (Member, Author) commented Dec 12, 2023:

Yeah, I saw that. I still get crashes during post-processing on GPU for large models (even with postprocessing_backend="cpu").

fonnesbeck (Member, Author) commented Dec 12, 2023:

This looks like it might help, though it is not implemented in Jax yet. We should probably keep the option for using xmap in the interim.

ricardoV94 (Member) commented Dec 12, 2023:

We are already using scan by default, so I don't think it would help.

JasonTam (Contributor) commented Jan 22, 2024:

I'm running into the same OOM issue in post-processing with the default postprocessing_vectorize="scan".
Is postprocessing_chunks not something that can be brought back as an experimental, use-at-your-own-risk parameter?

ricardoV94 (Member):

IIRC postprocessing_chunks just uses scan under the hood anyway, so it shouldn't help. Can you check whether it actually helps in your case?

We need an example to investigate this issue, but if you see a difference, we can consider temporarily reverting while we figure it out.

JasonTam linked a pull request Jan 24, 2024 that will close this issue.