Multithreading issue with palantir.utils.run_magic_imputation() #133

Open
gighuarhguggg45 opened this issue Feb 12, 2024 · 8 comments

Comments

@gighuarhguggg45

Hi,

Thank you for your work on Palantir. I have been running into issues with the following call (or any n_jobs > 1):
imputed_X = palantir.utils.run_magic_imputation(ad, n_jobs=16)

Shortly after the run starts I get the following warning, and the Python kernel always dies after a while.

`RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock`

I do not get any errors with:
imputed_X = palantir.utils.run_magic_imputation(ad,n_jobs=1)

But n_jobs=1 just runs forever and never produces a result with my large scATAC anndata (101,966 cells and 228,892 features, with 30,000 variable features).
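
For a rough sense of scale, the dense size of a matrix with these dimensions can be estimated with simple arithmetic. The sketch below assumes 8-byte float64 entries and a fully dense result. Whether run_magic_imputation materializes a dense matrix of this shape, and over which feature set, is not stated in this thread. `dense_bytes` is a hypothetical helper name, not part of Palantir.

```python
def dense_bytes(n_rows, n_cols, bytes_per_value=8):
    """Bytes needed for a dense n_rows x n_cols matrix (float64 by default)."""
    return n_rows * n_cols * bytes_per_value


if __name__ == "__main__":
    n_cells = 101_966  # cell count reported above
    # Feature counts reported above; which one applies is an assumption.
    for label, n_features in [("all features", 228_892),
                              ("variable features", 30_000)]:
        gb = dense_bytes(n_cells, n_features) / 1e9
        print(f"{label}: {gb:.1f} GB")
```

Under those assumptions the variable-feature matrix would take about 24.5 GB and the full matrix about 186.7 GB, so it may be worth checking which features actually enter the imputation.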

Of note, when I run the tutorial Palantir analysis notebook, I still get the warning, but I do not get the crash.

I am using 80 vCPUs and 640 GB of memory, with Python 3.9.2.

Thank you.

@katosh
Collaborator

katosh commented Feb 13, 2024

Hi @gighuarhguggg45! Thank you for reporting. This sounds like your ad.X might be a JAX array. Could you try

import numpy as np
ad.X = np.asarray(ad.X)
imputed_X = palantir.utils.run_magic_imputation(ad, n_jobs=16)

and see if this fixes the problem? Otherwise, I would need some example data to reproduce this.

@rishikanthc

I get the same error when I use the PyTorch DataLoader with num_workers > 0. It looks like the DataLoader uses os.fork(), which conflicts with JAX.
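
One general way to avoid the fork warning in plain Python is to start worker processes with the "spawn" start method instead of "fork". The sketch below uses only the standard library. `spawn_context` and `parallel_abs` are hypothetical names, and nothing in this thread confirms that Palantir's n_jobs path can be configured this way.

```python
import multiprocessing as mp


def spawn_context():
    """Return a multiprocessing context that starts workers via 'spawn'."""
    # 'spawn' launches a fresh interpreter instead of calling os.fork(),
    # so threads in the parent process are not copied into the child.
    return mp.get_context("spawn")


def parallel_abs(values, n_procs=2):
    """Apply the builtin abs() to each value in a spawned worker pool."""
    ctx = spawn_context()
    with ctx.Pool(processes=n_procs) as pool:
        return pool.map(abs, values)


if __name__ == "__main__":
    # The guard is required with 'spawn': children may re-import this module.
    print(parallel_abs([-3, 1, -2]))
```

For the PyTorch case, DataLoader accepts a `multiprocessing_context` argument that can be set to "spawn". Whether that resolves the conflict in a given setup would need to be tested.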

@katosh
Collaborator

katosh commented Feb 22, 2024

Just to clarify: we do not explicitly use JAX in the palantir.utils.run_magic_imputation function. However, we do use a dot product function that picks an implementation based on the input arrays. JAX is used only if one of the input arrays (ad.X or ad.obsp["DM_Similarity"]) is a JAX array.

Please let me know if the solution suggested above fixes the problem.
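
To check which array types are actually being passed in, a small inspection helper can report where each object's class is defined. This is a minimal sketch: `array_backend` and `looks_like_jax` are hypothetical helpers, and the check assumes that JAX array classes live in a module whose name starts with "jax" or "jaxlib".

```python
def array_backend(x):
    """Return the top-level package that defines type(x)."""
    return type(x).__module__.split(".")[0]


def looks_like_jax(x):
    """Heuristic check (an assumption): the class comes from jax or jaxlib."""
    return array_backend(x) in {"jax", "jaxlib"}


if __name__ == "__main__":
    # With an AnnData object `ad` loaded, one would inspect, for example:
    #   array_backend(ad.X)
    #   array_backend(ad.obsp["DM_Similarity"])
    print(array_backend([1.0, 2.0]), looks_like_jax([1.0, 2.0]))
```

Printing `type(ad.X)` directly gives the same information. The helper only makes it easier to compare several arrays at once.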

@wbrett87

wbrett87 commented Mar 1, 2024

I get the same problem, and the solution you suggested does not work.

image

@katosh
Collaborator

katosh commented Mar 1, 2024

@wbrett87 can you please inspect adata.X and see whether the shape, format and content are what we expect from a gene expression matrix?

@wbrett87

wbrett87 commented Mar 1, 2024

image

@katosh
Collaborator

katosh commented Mar 2, 2024

@wbrett87, it appears that your ad.X is stored in a sparse format. Consequently, converting ad.X directly to a NumPy array is unnecessary and likely the cause of the error you experienced above. That said, it's unexpected that you're encountering a JAX-related error upon executing palantir.utils.run_magic_imputation(ad, n_jobs=16) directly. This issue isn't something I've been able to replicate on my end. Another possibility is that ad.obsp["DM_Similarity"] is a JAX array, although this would be unusual. For a more targeted investigation, could you provide a simplified code snippet or dataset that replicates the error? This would greatly help in diagnosing and resolving the problem.
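
The point about sparse input can be sketched as a guard that converts to NumPy only when the input does not look sparse. Calling np.asarray on a SciPy sparse matrix does not densify it (it typically yields a 0-dimensional object array), so the conversion is skipped in that case. `to_numpy_unless_sparse` is a hypothetical helper, and the duck-typed check on `toarray` and `nnz` is an assumption standing in for `scipy.sparse.issparse`.

```python
import numpy as np


def to_numpy_unless_sparse(X):
    """Return X unchanged if it looks sparse, else convert it with np.asarray."""
    # Duck-typed check (assumption): SciPy sparse containers expose both
    # .toarray() and .nnz. With SciPy imported, scipy.sparse.issparse(X)
    # is the more direct test.
    if hasattr(X, "toarray") and hasattr(X, "nnz"):
        return X
    return np.asarray(X)


if __name__ == "__main__":
    dense = to_numpy_unless_sparse([[1.0, 0.0], [0.0, 2.0]])
    print(type(dense).__name__, dense.shape)
    # With an AnnData object `ad` loaded, one might write:
    #   ad.X = to_numpy_unless_sparse(ad.X)
```

If a dense array is genuinely required, the sparse matrix's own `toarray()` method is the usual route, at the memory cost of the full dense matrix.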

@wbrett87

wbrett87 commented Mar 4, 2024

ad.obsp["DM_Similarity"] is not a JAX array. I ran with one job as a temporary workaround. I will provide a simplified code snippet at some point soon when I have a bit more time. Thanks for your attention to this!
