Started implementing parallelization along broadcasted dimensions
astrofrog committed Sep 14, 2023
1 parent eabeb23 commit 31ac87c
Showing 1 changed file with 84 additions and 56 deletions.
140 changes: 84 additions & 56 deletions reproject/common.py
@@ -101,6 +101,23 @@ def _reproject_dispatcher(
    if reproject_func_kwargs is None:
        reproject_func_kwargs = {}

    # Determine whether any broadcasting is taking place
    broadcasting = wcs_in.low_level_wcs.pixel_n_dim < len(shape_out)

    # Determine whether block size indicates we should parallelize over broadcasted dimension
    broadcasted_parallelization = False
    if broadcasting and block_size:
        if len(block_size) == len(shape_out):
            if (
                block_size[-wcs_in.low_level_wcs.pixel_n_dim :]
                == shape_out[-wcs_in.low_level_wcs.pixel_n_dim :]
            ):
                broadcasted_parallelization = True
                block_size = (
                    block_size[: -wcs_in.low_level_wcs.pixel_n_dim]
                    + (-1,) * wcs_in.low_level_wcs.pixel_n_dim
                )

    # We set up a global temporary directory since this will be used e.g. to
    # store memory mapped Numpy arrays and zarr arrays.
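
As a point of reference, here is a minimal standalone sketch (not part of the commit) of what the trailing-dimension check added above does, assuming a 3-D cube whose WCS only describes the two celestial axes; the shapes and variable names are illustrative only:

# Hypothetical example: 3-D output cube, WCS covers only the last 2 axes.
pixel_n_dim = 2                # stands in for wcs_in.low_level_wcs.pixel_n_dim
shape_out = (4, 1024, 1024)    # leading axis is the broadcasted dimension
block_size = (1, 1024, 1024)   # whole celestial planes, chunked along axis 0

broadcasting = pixel_n_dim < len(shape_out)
broadcasted_parallelization = (
    broadcasting
    and len(block_size) == len(shape_out)
    and tuple(block_size[-pixel_n_dim:]) == tuple(shape_out[-pixel_n_dim:])
)
if broadcasted_parallelization:
    # Only the broadcasted (leading) axes stay chunked; each image plane
    # becomes a single chunk (-1 means "whole axis" to dask).
    block_size = tuple(block_size[:-pixel_n_dim]) + (-1,) * pixel_n_dim

print(broadcasted_parallelization, block_size)  # True (1, -1, -1)
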

@@ -158,32 +175,6 @@ def _reproject_dispatcher(

    shape_in = array_in.shape

    # When in parallel mode, we want to make sure we avoid having to copy the
    # input array to all processes for each chunk, so instead we write out
    # the input array to a Numpy memory map and load it in inside each process
    # as a memory-mapped array. We need to be careful how this gets passed to
    # reproject_single_block so we pass a variable that can be either a string
    # or the array itself (for synchronous mode). If the input array is a dask
    # array we should always write it out to a memmap even in synchronous mode
    # otherwise map_blocks gets confused if it gets two dask arrays and tries
    # to iterate over both.

    if isinstance(array_in, da.core.Array) or parallel:
        # If return_type=='dask',
        if return_type == "dask":
            # We should use a temporary directory that will persist beyond
            # the call to the reproject function.
            tmp_dir = tempfile.mkdtemp()
        else:
            tmp_dir = local_tmp_dir
        array_in_or_path = as_delayed_memmap_path(array_in, tmp_dir)
    else:
        # Here we could set array_in_or_path to array_in_path if it
        # has been set previously, but in synchronous mode it is better to
        # simply pass a reference to the memmap array itself to avoid having
        # to load the memmap inside each reproject_single_block call.
        array_in_or_path = array_in

    def reproject_single_block(a, array_or_path, block_info=None):
        if a.ndim == 0 or block_info is None or block_info == []:
            return np.array([a, a])
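
The comment block above describes the memmap hand-off used in parallel mode. A rough standalone illustration of the idea, with hypothetical file names and without the delayed machinery of as_delayed_memmap_path:

import os
import tempfile

import numpy as np

array_in = np.random.random((256, 256))

# Writer side: dump the input once to a memory-mapped file...
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "input.np")
mm = np.memmap(path, dtype=array_in.dtype, shape=array_in.shape, mode="w+")
mm[:] = array_in
mm.flush()

# ...worker side: each process re-opens the file read-only instead of
# receiving a pickled copy of the full array with every chunk.
array_or_path = path  # roughly what reproject_single_block would get in parallel mode
worker_view = np.memmap(path, dtype=array_in.dtype, shape=array_in.shape, mode="r")
assert np.allclose(worker_view, array_in)
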
Expand Down Expand Up @@ -217,40 +208,77 @@ def reproject_single_block(a, array_or_path, block_info=None):

        return np.array([array, footprint])

    # NOTE: the following array is just used to set up the iteration in map_blocks
    # but isn't actually used otherwise - this is deliberate.

    if block_size is not None and block_size != "auto":
        if wcs_in.low_level_wcs.pixel_n_dim < len(shape_out):
            if len(block_size) < len(shape_out):
                block_size = [-1] * (len(shape_out) - len(block_size)) + list(block_size)
            else:
                for i in range(len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim):
                    if block_size[i] != -1 and block_size[i] != shape_out[i]:
                        raise ValueError(
                            "block shape for extra broadcasted dimensions should cover entire array along those dimensions"
                        )
    if broadcasted_parallelization:
        array_out_dask = da.empty(shape_out, chunks=block_size)
        array_in = array_in.rechunk(block_size)

        result = da.map_blocks(
            reproject_single_block,
            array_out_dask,
            array_in,
            dtype=float,
            new_axis=0,
            chunks=(2,) + array_out_dask.chunksize,
        )

    else:
        if wcs_in.low_level_wcs.pixel_n_dim < len(shape_out):
            chunks = (-1,) * (len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim)
            chunks += ("auto",) * wcs_in.low_level_wcs.pixel_n_dim
            rechunk_kwargs = {"chunks": chunks}
        # When in parallel mode, we want to make sure we avoid having to copy the
        # input array to all processes for each chunk, so instead we write out
        # the input array to a Numpy memory map and load it in inside each process
        # as a memory-mapped array. We need to be careful how this gets passed to
        # reproject_single_block so we pass a variable that can be either a string
        # or the array itself (for synchronous mode). If the input array is a dask
        # array we should always write it out to a memmap even in synchronous mode
        # otherwise map_blocks gets confused if it gets two dask arrays and tries
        # to iterate over both.

        if isinstance(array_in, da.core.Array) or parallel:
            # If return_type=='dask',
            if return_type == "dask":
                # We should use a temporary directory that will persist beyond
                # the call to the reproject function.
                tmp_dir = tempfile.mkdtemp()
            else:
                tmp_dir = local_tmp_dir
            array_in_or_path = as_delayed_memmap_path(array_in, tmp_dir)
        else:
            rechunk_kwargs = {}
        array_out_dask = da.empty(shape_out)
        array_out_dask = array_out_dask.rechunk(
            block_size_limit=8 * 1024**2, **rechunk_kwargs
        )
            # Here we could set array_in_or_path to array_in_path if it
            # has been set previously, but in synchronous mode it is better to
            # simply pass a reference to the memmap array itself to avoid having
            # to load the memmap inside each reproject_single_block call.
            array_in_or_path = array_in

        if block_size is not None and block_size != "auto":
            if broadcasting:
                if len(block_size) < len(shape_out):
                    block_size = [-1] * (len(shape_out) - len(block_size)) + list(block_size)
                else:
                    for i in range(len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim):
                        if block_size[i] != -1 and block_size[i] != shape_out[i]:
                            raise ValueError(
                                "block shape for extra broadcasted dimensions should cover entire array along those dimensions"
                            )
            array_out_dask = da.empty(shape_out, chunks=block_size)
        else:
            if broadcasting:
                chunks = (-1,) * (len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim)
                chunks += ("auto",) * wcs_in.low_level_wcs.pixel_n_dim
                rechunk_kwargs = {"chunks": chunks}
            else:
                rechunk_kwargs = {}
            array_out_dask = da.empty(shape_out)
            array_out_dask = array_out_dask.rechunk(
                block_size_limit=8 * 1024**2, **rechunk_kwargs
            )

    result = da.map_blocks(
        reproject_single_block,
        array_out_dask,
        array_in_or_path,
        dtype=float,
        new_axis=0,
        chunks=(2,) + array_out_dask.chunksize,
    )
        result = da.map_blocks(
            reproject_single_block,
            array_out_dask,
            array_in_or_path,
            dtype=float,
            new_axis=0,
            chunks=(2,) + array_out_dask.chunksize,
        )

    # Ensure that there are no more references to Numpy memmaps
    array_in = None
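
For context, the map_blocks pattern used in both branches above relies on a placeholder dask array purely to drive the block iteration (as the NOTE comment in the removed code explains), with each block returning the reprojected data and its footprint stacked along a new leading axis. A self-contained sketch of that pattern follows; the block function here is a dummy stand-in, not reproject_single_block:

import dask.array as da
import numpy as np

shape_out = (512, 512)

def fake_reproject_block(a, block_info=None):
    # Stand-in for reproject_single_block: pretend the reprojected data is
    # all ones and the footprint is 0.5 everywhere for this block.
    data = np.ones(a.shape)
    footprint = np.full(a.shape, 0.5)
    return np.array([data, footprint])

# The placeholder array only defines the output shape and chunking.
array_out_dask = da.empty(shape_out, chunks=(128, 128))
result = da.map_blocks(
    fake_reproject_block,
    array_out_dask,
    dtype=float,
    new_axis=0,
    chunks=(2,) + array_out_dask.chunksize,
)

# Axis 0 separates the two outputs computed in a single pass.
array, footprint = result[0], result[1]
print(array.shape, footprint.shape)  # (512, 512) (512, 512)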
