
Commit

Started implementing parallelization along broadcasted dimensions
astrofrog committed Jul 14, 2023
1 parent 1b8b7c0 commit f999928
Showing 1 changed file with 78 additions and 46 deletions.
124 changes: 78 additions & 46 deletions reproject/common.py
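For context on the commit message: the change targets inputs whose data have more dimensions than the WCS, for example an image stack sharing a single 2D celestial WCS. Below is a hedged sketch of the kind of user-facing call this work is building toward; it assumes reproject_interp's existing shape_out, block_size and parallel keywords and is not guaranteed to run against this in-progress commit.

import numpy as np
from astropy.wcs import WCS
from reproject import reproject_interp

# Hypothetical example: a (5, 256, 256) stack sharing a single 2D celestial WCS.
wcs_in = WCS(naxis=2)
wcs_in.wcs.ctype = ["RA---TAN", "DEC--TAN"]
wcs_in.wcs.crval = [10.0, 20.0]
wcs_in.wcs.crpix = [128.5, 128.5]
wcs_in.wcs.cdelt = [-0.001, 0.001]

wcs_out = wcs_in.deepcopy()
wcs_out.wcs.crpix = [100.5, 100.5]

cube = np.random.random((5, 256, 256))

# A block size whose trailing entries span the full celestial plane is the case
# this commit starts to parallelize over the leading (broadcast) dimension.
array, footprint = reproject_interp(
    (cube, wcs_in),
    wcs_out,
    shape_out=(5, 200, 200),
    block_size=(1, 200, 200),
    parallel=True,
)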
@@ -97,6 +97,23 @@ def _reproject_dispatcher(
    if reproject_func_kwargs is None:
        reproject_func_kwargs = {}

    # Determine whether any broadcasting is taking place
    broadcasting = wcs_in.low_level_wcs.pixel_n_dim < len(shape_out)

    # Determine whether block size indicates we should parallelize over broadcasted dimension
    broadcasted_parallelization = False
    if broadcasting and block_size:
        if len(block_size) == len(shape_out):
            if (
                block_size[-wcs_in.low_level_wcs.pixel_n_dim :]
                == shape_out[-wcs_in.low_level_wcs.pixel_n_dim :]
            ):
                broadcasted_parallelization = True
                block_size = (
                    block_size[: -wcs_in.low_level_wcs.pixel_n_dim]
                    + (-1,) * wcs_in.low_level_wcs.pixel_n_dim
                )
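As an aside (not part of the diff), here is a minimal, self-contained sketch of the check above using made-up shapes: broadcasting means the output has more dimensions than the WCS has pixel dimensions, and broadcasted parallelization is only triggered when the trailing block dimensions span the whole celestial plane, after which those trailing entries are replaced by -1 so dask treats each plane as a single chunk.

# Standalone sketch of the check above, using hypothetical shapes.
pixel_n_dim = 2                # e.g. a 2D celestial WCS
shape_out = (5, 1024, 1024)    # 5 broadcast planes of 1024 x 1024 pixels
block_size = (1, 1024, 1024)   # each block covers one full plane

broadcasting = pixel_n_dim < len(shape_out)
broadcasted_parallelization = (
    broadcasting
    and len(block_size) == len(shape_out)
    and block_size[-pixel_n_dim:] == shape_out[-pixel_n_dim:]
)
if broadcasted_parallelization:
    # Trailing entries become -1 ("use the whole axis") for the dask rechunk
    block_size = block_size[:-pixel_n_dim] + (-1,) * pixel_n_dim

print(broadcasted_parallelization, block_size)  # True (1, -1, -1)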

    # We set up a global temporary directory since this will be used e.g. to
    # store memory mapped Numpy arrays and zarr arrays.

@@ -154,28 +171,10 @@ def _reproject_dispatcher(

    shape_in = array_in.shape

    # When in parallel mode, we want to make sure we avoid having to copy the
    # input array to all processes for each chunk, so instead we write out
    # the input array to a Numpy memory map and load it in inside each process
    # as a memory-mapped array. We need to be careful how this gets passed to
    # reproject_single_block so we pass a variable that can be either a string
    # or the array itself (for synchronous mode). If the input array is a dask
    # array we should always write it out to a memmap even in synchronous mode
    # otherwise map_blocks gets confused if it gets two dask arrays and tries
    # to iterate over both.

    if isinstance(array_in, da.core.Array) or parallel:
        array_in_or_path = as_delayed_memmap_path(array_in, tmp_dir)
    else:
        # Here we could set array_in_or_path to array_in_path if it
        # has been set previously, but in synchronous mode it is better to
        # simply pass a reference to the memmap array itself to avoid having
        # to load the memmap inside each reproject_single_block call.
        array_in_or_path = array_in
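The comment above describes the memory-map strategy; the sketch below illustrates the general pattern with plain NumPy, without claiming to reproduce the internals of reproject's as_delayed_memmap_path helper: the input is written once to a .npy file and each worker reopens it read-only as a memory map instead of receiving a full pickled copy.

import os
import tempfile

import numpy as np

# Illustrative only: write the input array to disk once...
array_in = np.random.random((256, 256))
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "input.npy")
memmap = np.lib.format.open_memmap(path, mode="w+", dtype=array_in.dtype, shape=array_in.shape)
memmap[:] = array_in
memmap.flush()

# ...then each worker reopens it read-only instead of receiving a copy.
def process_block(path, block_slices):
    data = np.load(path, mmap_mode="r")
    return data[block_slices].mean()

print(process_block(path, np.s_[:128, :128]))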

    def reproject_single_block(a, array_or_path, block_info=None):
        if a.ndim == 0 or block_info is None or block_info == []:
            return np.array([a, a])

        slices = [slice(*x) for x in block_info[None]["array-location"][-wcs_out.pixel_n_dim :]]

        if isinstance(wcs_out, BaseHighLevelWCS):
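For readers unfamiliar with dask's block_info argument used by reproject_single_block above, the toy example below (not part of the diff) shows how map_blocks reports each block's location in the output array, which is what gets turned into slices here:

import dask.array as da
import numpy as np

# Illustrative only: block_info[None] describes the block of the output array,
# and "array-location" gives (start, stop) pairs for each dimension.
def show_block(a, block_info=None):
    print(block_info[None]["array-location"])
    return a

x = da.zeros((4, 6), chunks=(2, 3))
da.map_blocks(show_block, x, dtype=float).compute()
# Prints locations such as [(0, 2), (0, 3)], [(0, 2), (3, 6)], ...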
@@ -209,37 +208,70 @@ def reproject_single_block(a, array_or_path, block_info=None):
    # NOTE: the following array is just used to set up the iteration in map_blocks
    # but isn't actually used otherwise - this is deliberate.

    if block_size:
        if wcs_in.low_level_wcs.pixel_n_dim < len(shape_out):
            if len(block_size) < len(shape_out):
                block_size = [-1] * (len(shape_out) - len(block_size)) + list(block_size)
            else:
                for i in range(len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim):
                    if block_size[i] != -1 and block_size[i] != shape_out[i]:
                        raise ValueError(
                            "block shape for extra broadcasted dimensions should cover entire array along those dimensions"
                        )
        array_out_dask = da.empty(shape_out, chunks=block_size)
    else:
        if wcs_in.low_level_wcs.pixel_n_dim < len(shape_out):
            chunks = (-1,) * (len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim)
            chunks += ("auto",) * wcs_in.low_level_wcs.pixel_n_dim
            rechunk_kwargs = {"chunks": chunks}
        else:
            rechunk_kwargs = {}
        array_out_dask = da.empty(shape_out)
        array_out_dask = array_out_dask.rechunk(
            block_size_limit=8 * 1024**2, **rechunk_kwargs
        )

    result = da.map_blocks(
        reproject_single_block,
        array_out_dask,
        array_in_or_path,
        dtype=float,
        new_axis=0,
        chunks=(2,) + array_out_dask.chunksize,
    )

    if broadcasted_parallelization:
        array_out_dask = da.empty(shape_out, chunks=block_size)
        array_in = array_in.rechunk(block_size)

        result = da.map_blocks(
            reproject_single_block,
            array_out_dask,
            array_in,
            dtype=float,
            new_axis=0,
            chunks=(2,) + array_out_dask.chunksize,
        )

    else:
        # When in parallel mode, we want to make sure we avoid having to copy the
        # input array to all processes for each chunk, so instead we write out
        # the input array to a Numpy memory map and load it in inside each process
        # as a memory-mapped array. We need to be careful how this gets passed to
        # reproject_single_block so we pass a variable that can be either a string
        # or the array itself (for synchronous mode). If the input array is a dask
        # array we should always write it out to a memmap even in synchronous mode
        # otherwise map_blocks gets confused if it gets two dask arrays and tries
        # to iterate over both.

        if isinstance(array_in, da.core.Array) or parallel:
            array_in_or_path = as_delayed_memmap_path(array_in, tmp_dir)
        else:
            # Here we could set array_in_or_path to array_in_path if it
            # has been set previously, but in synchronous mode it is better to
            # simply pass a reference to the memmap array itself to avoid having
            # to load the memmap inside each reproject_single_block call.
            array_in_or_path = array_in

        if block_size:
            if broadcasting:
                if len(block_size) < len(shape_out):
                    block_size = [-1] * (len(shape_out) - len(block_size)) + list(block_size)
                else:
                    for i in range(len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim):
                        if block_size[i] != -1 and block_size[i] != shape_out[i]:
                            raise ValueError(
                                "block shape for extra broadcasted dimensions should cover entire array along those dimensions"
                            )
            array_out_dask = da.empty(shape_out, chunks=block_size)
        else:
            if broadcasting:
                chunks = (-1,) * (len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim)
                chunks += ("auto",) * wcs_in.low_level_wcs.pixel_n_dim
                rechunk_kwargs = {"chunks": chunks}
            else:
                rechunk_kwargs = {}
            array_out_dask = da.empty(shape_out)
            array_out_dask = array_out_dask.rechunk(
                block_size_limit=8 * 1024**2, **rechunk_kwargs
            )

        result = da.map_blocks(
            reproject_single_block,
            array_out_dask,
            array_in_or_path,
            dtype=float,
            new_axis=0,
            chunks=(2,) + array_out_dask.chunksize,
        )
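Both map_blocks calls use new_axis=0 and chunks=(2,) + array_out_dask.chunksize, so each block contributes two stacked planes (compare np.array([a, a]) in reproject_single_block); in reproject's case these correspond to the reprojected data and its footprint. The toy example below (not part of the diff) shows the resulting shape bookkeeping:

import dask.array as da
import numpy as np

# Illustrative only: a block function that returns two stacked planes per block,
# mirroring the new_axis=0 / chunks=(2,) + chunksize pattern above.
def two_planes(a):
    return np.stack([a + 1.0, a - 1.0])

x = da.zeros((4, 6), chunks=(2, 3))
result = da.map_blocks(two_planes, x, dtype=float, new_axis=0, chunks=(2,) + x.chunksize)
print(result.shape)           # (2, 4, 6)
plane0 = result[0].compute()  # first plane, e.g. the reprojected data
plane1 = result[1].compute()  # second plane, e.g. the footprint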

    # Ensure that there are no more references to Numpy memmaps
    array_in = None
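Dropping the reference matters because an open memory map keeps a handle on the backing file, which can block removal of the temporary directory (most visibly on Windows). A small, self-contained illustration, not reproject code:

import os
import tempfile

import numpy as np

tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "scratch.npy")
m = np.lib.format.open_memmap(path, mode="w+", dtype=float, shape=(10,))
m[:] = 0.0

# While `m` exists the file is memory-mapped; releasing the reference lets the
# mapping close so the temporary files can be deleted cleanly.
m = None
os.remove(path)
os.rmdir(tmp_dir)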
