Use multi-threading instead of multi-processing #394
Comments
@svank - just out of curiosity, do you have any sense of what could be making the adaptive code not be thread-safe?
My first thought was about how the Cython function calls back into Python-land for the coordinates calculations, but I guess that's just a GIL matter rather than a thread-safety thing (and wouldn't clearly cause the glitch your notebook shows). I'll have to play with it a bit.
I just played with this a little bit and found that if I pass [...]. Here's an expanded notebook: https://gist.github.com/svank/d63ef6bdf4e146577d7a78111ad85855
@svank yes I think I've come to the same conclusion that this could be in [...]. Interestingly https://www.atnf.csiro.au/people/mcalabre/WCS/wcslib/threads.html says that WCSLIB is basically thread-safe, but https://www.gnu.org/software/gnuastro/manual/html_node/World-Coordinate-System.html#:~:text=The%20wcsprm%20structure%20of%20WCSLIB,the%20same%20wcsprm%20structure%20pointer. says:

The wcsprm structure of WCSLIB is not thread-safe: you can’t use the same pointer on multiple threads. For example, if you use gal_wcs_img_to_world simultaneously on multiple threads, you shouldn’t pass the same wcsprm structure pointer. You can use gal_wcs_copy to keep and use separate copies of the main structure within each thread, and later free the copies with gal_wcs_free.

and I think we do use wcsprm, so maybe it is related to that.

cc @manodeep
It is indeed a WCS issue - the third commit in #434 fixes the multi-threaded results (but presumably has a performance penalty)
Yeah - just reading through the references that you posted, it definitely seemed likely that [...]
This commit to casacore adds mutexes for thread-unsafe WCS - however, those are targeting [...]
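For illustration, here is a Python-level sketch of that mutex idea - not the casacore or astropy implementation, just serialising access to a shared (thread-unsafe) WCS with a lock; the WCS set-up and pool size here are arbitrary:

import threading
from multiprocessing.pool import ThreadPool

import numpy as np
from astropy.wcs import WCS

shared_wcs = WCS(naxis=2)   # arbitrary WCS, stands in for the real one
wcs_lock = threading.Lock()

xp = np.random.uniform(-1000, 1000, 1_000_000)
yp = np.random.uniform(-1000, 1000, 1_000_000)

def locked_roundtrip(xp, yp):
    # Only one thread at a time may touch the shared wcsprm structure,
    # trading parallelism for correctness.
    with wcs_lock:
        xw, yw = shared_wcs.all_pix2world(xp, yp, 0)
        return shared_wcs.all_world2pix(xw, yw, 0)

pool = ThreadPool(8)
results = pool.starmap(locked_roundtrip, ((xp, yp),) * 8)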
Small example to reproduce the bug with pixel_to_pixel:

import numpy as np
from astropy.wcs import WCS
from astropy.wcs.utils import pixel_to_pixel
from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename
from multiprocessing.pool import ThreadPool

hdu1 = fits.open(get_pkg_data_filename("galactic_center/gc_2mass_k.fits"))[0]
hdu2 = fits.open(get_pkg_data_filename("galactic_center/gc_msx_e.fits"))[0]
wcs1 = WCS(hdu1.header)
wcs2 = WCS(hdu2.header)

N = 1_000_000
N_iter = 1
xp = np.random.randint(1, 100, N).astype(float).reshape((1000, 1000))
yp = np.random.randint(1, 100, N).astype(float).reshape((1000, 1000))

def repeated_transforms(xp, yp):
    for i in range(N_iter):
        xp, yp = pixel_to_pixel(wcs1, wcs2, xp, yp)
        xp, yp = pixel_to_pixel(wcs2, wcs1, xp, yp)
    return xp, yp

pool = ThreadPool(8)
results = pool.starmap(repeated_transforms, ((xp, yp),) * 8)
for xp2, yp2 in results:
    print(
        f"Mismatching elements: {np.sum(~np.isclose(xp, xp2))} {np.sum(~np.isclose(yp, yp2))}"
    )
The above outputs:
And now even simpler:

import numpy as np
from astropy.wcs import WCS
from astropy.wcs.utils import pixel_to_pixel
from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename
from multiprocessing.pool import ThreadPool

hdu1 = fits.open(get_pkg_data_filename("galactic_center/gc_2mass_k.fits"))[0]
wcs = WCS(hdu1.header)

N = 1_000_000
N_iter = 1
xp = np.random.randint(-1000, 1000, N).astype(float)
yp = np.random.randint(-1000, 1000, N).astype(float)

def repeated_transforms(xp, yp):
    for i in range(N_iter):
        xw, yw = wcs.all_pix2world(xp, yp, 0)
        wcs.wcs.lng  # this access causes issues, without it all works
        xp, yp = wcs.all_world2pix(xw, yw, 0)
    return xp, yp

pool = ThreadPool(8)
results = pool.starmap(repeated_transforms, ((xp, yp),) * 8)
for xp2, yp2 in results:
    print(
        f"Mismatching elements: {np.sum(~np.isclose(xp, xp2))} {np.sum(~np.isclose(yp, yp2))}"
    )
It seems accessing wcs.wcs.lng between the transforms is what triggers the issue - without that access, everything works.
And even simpler, this time using the conversion functions on Wcsprm directly:

import numpy as np
from astropy.wcs import WCS
from multiprocessing.pool import ThreadPool

wcs = WCS(naxis=2)
N = 1_000_000
pixel = np.random.randint(-1000, 1000, N * 2).reshape((N, 2)).astype(float)

def repeated_transforms(pixel):
    world = wcs.wcs.p2s(pixel, 0)["world"]
    wcs.wcs.lat
    pixel = wcs.wcs.s2p(world, 0)["pixcrd"]
    return pixel

for n_threads in [1, 2, 8]:
    print("N_threads:", n_threads)
    pool = ThreadPool(n_threads)
    results = pool.map(repeated_transforms, (pixel,) * n_threads)
    for pixel2 in results:
        print(f"Mismatching: {np.sum(~np.isclose(pixel, pixel2))}")

gives:
Interestingly in this case, because the inner function consists almost exclusively of C function calls, it seems all the data gets corrupted.
Just so I understand - adding that access to wcs.wcs.lng is what makes the difference?
At the very least, this convincingly demonstrates (to me) that wcs is not thread-safe. Does the error go away if you make a copy of wcs within the function and use that copy?
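A minimal sketch of that per-thread copy idea, assuming WCS.deepcopy() duplicates the underlying wcsprm (the Python analogue of the gal_wcs_copy suggestion quoted above); whether this actually removes the corruption is exactly the question:

from multiprocessing.pool import ThreadPool

import numpy as np
from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename
from astropy.wcs import WCS

hdu = fits.open(get_pkg_data_filename("galactic_center/gc_2mass_k.fits"))[0]
shared_wcs = WCS(hdu.header)

xp = np.random.randint(-1000, 1000, 1_000_000).astype(float)
yp = np.random.randint(-1000, 1000, 1_000_000).astype(float)

def repeated_transforms(xp, yp):
    local_wcs = shared_wcs.deepcopy()  # private copy, so no wcsprm is shared between threads
    xw, yw = local_wcs.all_pix2world(xp, yp, 0)
    local_wcs.wcs.lng  # the access that triggers the problem on the shared WCS
    return local_wcs.all_world2pix(xw, yw, 0)

pool = ThreadPool(8)
results = pool.starmap(repeated_transforms, ((xp, yp),) * 8)
for xp2, yp2 in results:
    print(np.sum(~np.isclose(xp, xp2)), np.sum(~np.isclose(yp, yp2)))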
@manodeep - hmm, in this example it doesn't seem to actually matter - I can remove the access to wcs.wcs.lng... Actually, it does seem to matter, and removing it fixes my issues.
Hmm well now I'm puzzled, the following example also shows the issue, and this is even if I make a whole new WCS object inside each thread:

import numpy as np
from astropy.wcs import WCS
from multiprocessing.pool import ThreadPool

N = 1_000_000
pixel = np.random.randint(-1000, 1000, N * 2).reshape((N, 2)).astype(float)

def repeated_transforms(pixel):
    wcs = WCS(naxis=2)
    world = wcs.wcs.p2s(pixel, 0)["world"]
    pixel = wcs.wcs.s2p(world, 0)["pixcrd"]
    return pixel

for n_threads in [0, 1, 2, 8]:
    print("N_threads:", n_threads)
    if n_threads == 0:
        results = [repeated_transforms(pixel)]
    else:
        pool = ThreadPool(n_threads)
        results = pool.map(repeated_transforms, (pixel,) * n_threads)
    for pixel2 in results:
        print(f"Mismatching: {np.sum(~np.isclose(pixel, pixel2))}")

gives:
I wonder if I'm doing something wrong with the calls to p2s and s2p here, as it's a bit suspicious that suddenly all values are different compared to earlier examples, but maybe it's also because a higher fraction of time is spent in C code. It's also weird that the issue persists here even when creating a new WCS object inside each thread.
Having said that, maybe it is indeed highlighting the issue - if I switch to a process-based Pool the issue goes away:

import numpy as np
from astropy.wcs import WCS
from multiprocessing.pool import Pool

def repeated_transforms(pixel):
    wcs = WCS(naxis=2)
    world = wcs.wcs.p2s(pixel, 0)["world"]
    pixel = wcs.wcs.s2p(world, 0)["pixcrd"]
    return pixel

def main():
    N = 1_000_000
    pixel = np.random.randint(-1000, 1000, N * 2).reshape((N, 2)).astype(float)
    for n_proc in [0, 1, 2, 8]:
        print("N_processes:", n_proc)
        if n_proc == 0:
            results = [repeated_transforms(pixel)]
        else:
            pool = Pool(n_proc)
            results = pool.map(repeated_transforms, (pixel,) * n_proc)
        for pixel2 in results:
            print(f"Mismatching: {np.sum(~np.isclose(pixel, pixel2))}")

if __name__ == "__main__":
    main()
Ok well on that basis maybe #394 (comment) is the best example to go with to reproduce the issue? (so the access to [...])
I am so confused by this sample - how can a race condition occur if you are creating a new WCS object within each call?

Is there a typo in this sample - the for loop variable is called n_threads even though this is the process-based version?
FWIW the issue is definitely in WCSLIB - if I edit the [...]
Yes, sorry - I tried editing the code in the comment directly to change thread -> proc but clearly failed (fixed now).
I am very confused by this too, and this would suggest perhaps that there is some kind of global variable in WCSLIB that is being accessed and modified by different threads?
For fun, I tried checking what the actual offset between expected and actual pixel positions is, by doing:
which gives:
So it seems to be off by small integer values, which seem to be related to the number of threads running concurrently.
Ah, changing the second argument of p2s/s2p (the origin) [...]
Hmm I wonder if it's something dumb like that [...]
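If the suspicion is that the origin correction is applied to the shared input array in place and then undone (as later comments discuss), here is a toy illustration - not the actual astropy/WCSLIB code - of how interleaving those in-place steps across threads would leave small integer offsets behind; you may need to run it a few times to see non-zero values:

import time
from multiprocessing.pool import ThreadPool

import numpy as np

pixel = np.zeros(1_000_000)  # shared input array, expected to stay all zeros

def fake_transform(_):
    np.subtract(pixel, 1, out=pixel)  # "fix" the origin in place on the shared array
    world = pixel * 2.0               # stand-in for the real transform
    time.sleep(0.001)                 # make interleaving more likely
    np.add(pixel, 1, out=pixel)       # "unfix" the origin again
    return world / 2.0 + 1            # undo everything; should be all zeros

pool = ThreadPool(8)
results = pool.map(fake_transform, range(8))
for res in results:
    # non-zero small integers indicate how many threads were mid-"fix" at read time
    print("offsets seen:", np.unique(res))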
Ok so I think there must be two separate issues, because now if I go back to one of the earlier examples and change the origin argument in that to be 1, there is still a problem:

import numpy as np
from astropy.wcs import WCS
from astropy.wcs.utils import pixel_to_pixel
from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename
from multiprocessing.pool import ThreadPool

hdu1 = fits.open(get_pkg_data_filename("galactic_center/gc_2mass_k.fits"))[0]
wcs = WCS(hdu1.header)

N = 1_000_000
N_iter = 1
xp = np.random.randint(-1000, 1000, N).astype(float)
yp = np.random.randint(-1000, 1000, N).astype(float)

def repeated_transforms(xp, yp):
    for i in range(N_iter):
        xw, yw = wcs.all_pix2world(xp, yp, 1)
        wcs.wcs.lng  # this access causes issues, without it all works
        xp, yp = wcs.all_world2pix(xw, yw, 1)
    return xp, yp

pool = ThreadPool(8)
results = pool.starmap(repeated_transforms, ((xp, yp),) * 8)
for xp2, yp2 in results:
    print(
        f"Mismatching elements: {np.sum(~np.isclose(xp, xp2))} {np.sum(~np.isclose(yp, yp2))}"
    )
and in this case the offsets between expected and actual positions are not integers. I think the off-by-one issue above might not be the one we are running into in reproject, because it doesn't seem to be triggered when calling [...]
Ok, so to summarize, I think there are two separate issues:

1. an off-by-one issue related to the origin correction being applied to the input array in place and then undone, which can corrupt all the values when the same input array is shared between workers;
2. a genuine thread-safety issue somewhere in WCSLIB, which only shows up when something like wcs.wcs.lng is accessed between the transforms, and which affects only a small fraction of the values.
For the second case, I've now managed to make an example that replicates the issue just by calling p2s and s2p on the Wcsprm directly:

import numpy as np
from astropy.wcs import WCS
from multiprocessing.pool import ThreadPool

wcs = WCS(naxis=2)
wcs.wcs.crpix = [-234.75, 8.3393]
wcs.wcs.cdelt = np.array([-0.066667, 0.066667])
wcs.wcs.crval = [0, -90]
wcs.wcs.ctype = ["RA---AIR", "DEC--AIR"]
wcs.wcs.set()

N = 1_000_000
pixel = np.random.randint(-1000, 1000, N * 2).reshape((N, 2)).astype(float)

def repeated_transforms(pixel):
    world = wcs.wcs.p2s(pixel.copy(), 1)['world']
    wcs.wcs.lng  # this access causes issues, without it all works
    pixel = wcs.wcs.s2p(world, 1)['pixcrd']
    return pixel

pool = ThreadPool(8)
results = pool.map(repeated_transforms, (pixel,) * 8)
for pixel2 in results:
    print(
        f"Mismatching elements: {np.sum(~np.isclose(pixel, pixel2))}"
    )

gives:
so I think this is the example we should use going forward. The key is that it does a copy of the array before calling p2s, so the in-place origin fix/unfix on the shared input array is avoided and only the second issue remains.
There might be multiple issues at play here: i) a thread race condition within wcslib (which somehow manifests when there is an access to a memory location!), and ii) applying the FITS-origin fix to the input array in place and then undoing it afterwards.

I still don't understand why [...]
Plus, why is there such a big discrepancy in the number of wrong pixel values between the (mostly)-C-functions case (all values are wrong) and the more-Python-functions test cases (~100 values are wrong)?

I will also check that both the multi-process and multi-thread versions are actually splitting up the work. For example, if the multi-process version is spawning (I don't know why it would) 8 copies of the same task in serial, then it stands to reason that the 8 identical serial tasks would produce correct results.
@manodeep - yes, as mentioned in this summary I think there are two separate issues, and I think we should focus on the second one I mention (2.), which I think is what you call i).
Just to be clear, in both the threaded and multiprocess cases, the MWEs above should result in the exact same code being run on all threads or processes, i.e. the work of converting the coordinates is not being split up into chunks, for example. This is intentional, just for debugging purposes, to check that all threads find the same results (which they don't). It's not surprising that the multiprocessing version works though - it is just a sanity check. In reproject, each thread or process is actually computing something different, so it should be more efficient than running in serial.
Ahh I see - the origin-fix-unfix caused all the failures in C. But even if the input array is being modified in-place, I don't see why the access to wcs.wcs.lng matters.

It might be that the fact this shows up with threads but not with processes is a side-effect of copy-on-write (c-o-w) for the forked processes - whenever a forked process modifies the wcs struct, the OS creates a copy of the entire wcs struct and then writes to that specific memory location. However, that c-o-w protection is not there for threads and hence there is a race condition. Still does not quite explain why the access is needed.
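A toy illustration of that threads-vs-forked-processes distinction (nothing to do with WCS itself; fork-based pools are POSIX-only): in-place writes to module-level state are visible across threads, but each forked worker only ever writes to its own copy-on-write pages, so the parent never sees them:

import multiprocessing
from multiprocessing.pool import ThreadPool

import numpy as np

shared = np.zeros(4)

def scribble(i):
    shared[i] = i + 1   # in-place write to the module-level array
    return shared.copy()

if __name__ == "__main__":
    ThreadPool(4).map(scribble, range(4))
    print("after threads:  ", shared)   # all writes visible: [1. 2. 3. 4.]

    shared[:] = 0
    ctx = multiprocessing.get_context("fork")   # fork so children inherit `shared` (POSIX only)
    with ctx.Pool(4) as pool:
        pool.map(scribble, range(4))
    print("after processes:", shared)   # parent's copy untouched: [0. 0. 0. 0.]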
It's not required to uncover the origin fix-unfix issue. It's only required for issue 2 mentioned in #394 (comment).
@manodeep - just to be clear, ignore the origin fix/unfix issue for now (issue 1 in my summary comment); I think I know a way forward to fix that. We should just focus on the more complex second issue. To be clear, this is the issue where only 10s or 100s of values are wrong, and which requires the wcs.wcs.lng access to show up.
Yup - I understand the errors and am focussing only on #2. But from an OS/Python level, I don't understand why that attribute access matters.
Would it be helpful if I open two separate issues in the astropy core repo, since we now know the problem is there and not in reproject? Otherwise I think there's too much potential for confusion in this giant thread here.
Sure thing. The issue is clearly in astropy.wcs / WCSLIB and not in reproject.
Yes, completely agree that this is puzzling! And not only that, but if you try and access e.g. [...]
Ok, will do on Wednesday (for me!)
I've opened two issues over at astropy:
Going forward we should probably discuss them there, in the individual issues, so as not to get confused.
The main functions in reproject, such as reproject_interp, accept parallel= and block_size= arguments which, if used, will leverage dask behind the scenes to split up the data into chunks and then use a multi-processing scheduler to distribute the work. Ideally we should be using multi-threading instead of multi-processing, but currently we don't because there appear to be some issues with some output pixels not having the right value when using multi-threading.
Fixing this will provide two main benefits, including being able to call the functions with return_type='dask' and then use the default scheduler to compute the array.

Here's a notebook illustrating the issues: https://gist.github.com/astrofrog/e8808ee3ee8b7b86a979e0cb305d518b - note that while there is no explicit mention of threads anywhere, in the compare function array2 is a dask array, and when it gets passed to Matplotlib, .compute() gets called and the default scheduler uses threads.

At the moment all algorithms seem to have issues, though they all appear to be different. Note that for adaptive I sometimes have to run a few times to see issues.
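For reference, here is a rough usage sketch of the parallel= / block_size= API described in the issue above (the parameter values are arbitrary; check the reproject docs for the exact behaviour in your version):

from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename
from reproject import reproject_interp

hdu1 = fits.open(get_pkg_data_filename("galactic_center/gc_2mass_k.fits"))[0]
hdu2 = fits.open(get_pkg_data_filename("galactic_center/gc_msx_e.fits"))[0]

# Split the output into chunks and distribute the work; as described above,
# this currently uses a multi-processing dask scheduler behind the scenes.
array, footprint = reproject_interp(hdu2, hdu1.header,
                                    block_size=(512, 512), parallel=4)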