Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Dask-delayed raster subsample(), reproject() and interp_points() #537

Open
wants to merge 18 commits into
base: main
Choose a base branch
from

Conversation

rhugonnet
Copy link
Contributor

@rhugonnet rhugonnet commented Apr 25, 2024

Summary

(Moved from GlacioHack/xdem#508. Some functions were written with help from @ameliefroessl, after discussions that stemmed in GlacioHack/xdem#501.)

This PR adds three out-of-memory raster operations:

  1. Subsampling (that respects nodata values),
  2. Georeferenced reprojection (inspired by Dask based implemenetation of reproject opendatacube/odc-geo#88) and
  3. Point interpolation, which harnesses the fact that rasters are NOT on a rectilinear grid but on a regular grid, to greatly reduce memory usage of xarray.interp().

Summary of reasoning behind implementations

For both subsampling and point interpolation, we had to deal with the difficulty of ragged outputs (varying 1D length, and different than the 2D input chunks). See this blogpost for a nice example of why this can be tricky: https://blog.dask.org/2021/07/02/ragged-output, and the issue it originated from: dask/dask#7589.
The map_blocks(drop_axis=) solution proposed in the blogpost unfortunately loads a lot of the chunks in memory at once by dropping the axis (rechunking in 1D essentially). Hence why we inspired ourselves from the dask.delayed solution instead.

For subsampling: Using dask.vindex (or vectorized Xarray indexing) allows to do pointwise sampling, but we cannot know the nodata values in advance.
In this PR, we compute the number of valid values per chunk in advance, then create a random subsample of the total number of valid values for the array (no matter which chunk they fall in), then identify which chunks they fall in, and sample them out-of-memory per chunk.
The only downside of the method is that the random subsample is chunksize-dependent (for a fixed random state, it will give a different random subsample for different chunk sizes). If the user has multiple arrays to sample from at the same points, the argument return_indices can solve this. So this seems fine as long as it is clear in the function description.

For interpolation:
Interpolation in the chunked dimensions was recently added to Xarray: pydata/xarray#4155, but the memory usage is really big pydata/xarray#6799 (we checked this ourselves too, see GlacioHack/xdem#501 (reply in thread)), because the grid can be rectilinear.
In this PR, we rely on the fact that a raster grid is of equal length in each dimension, which allows to both map the location of points to interpolate using less memory and, for per-chunk interpolation, only need to pass a couple values (step in X/Y and starting X/Y) instead of a full vector of coordinate.
The memory usage is bigger than a single chunk because we transform the input delayed array into overlapping arrays (with a depth of overlap depending on the interpolation method), so overlapping chunks have to be loaded as well. This is fine as long as the chunking is small enough.

For reprojection:
The only implementation of this to my knowledge is in opendatacube/odc-geo#88 (and maybe also in pyresample and geowombat?), but it depends on a long class of classes in odc-geo (starting with GeoBox and VariableSizedTiles). As it didn't seem like so much work, I made a concise stand-alone implementation with the idea that it could potentially be moved to Rioxarray, or be used in GeoUtils's Xarray accessor without adding any new dependency.

Possible contribution to upstream packages

I will start discussions in Rioxarray/Xarray/Dask to see if it makes sense to integrate some there directly, I was thinking:

  • Rioxarray for reproject(), but would requires shapely as optional dependency with current implementation;
  • Xarray for interp(), on the assumptions that the grid is not rectilinear but equal/regular along the interpolated dimensions, and would have to extend to N-D!
  • Dask for subsample(): to mirror their dask.dataframe.sample() functionality, while dask.array doesn't have one.

Next steps

This is a big step forward for #383.
We should not describe too much these functions in the API just yet, but wait until the Xarray accessor is out (as it will facilitate greatly reading the chunked array properly through Rioxarray).

TODO

  • Furnish tests a bit more,
  • Split tests into two: memory checks (big dataset) and output accuracy (small datasets),
  • Use https://github.com/itamarst/dask-memusage in test_delayed to consistently check out-of-memory operation work as intended,
  • Find out why rio.warp.reproject() fails with an error during multi-threading only? (Forced to single-thread)
  • Make output tests run faster by not relying on the cluster fixture for all tests.

@rhugonnet
Copy link
Contributor Author

rhugonnet commented Apr 27, 2024

@adehecq @ameliefroessl I think this PR is essentially finished and ready for your review 😄. (EDIT: Adding @erikmannerfelt too, as this joins some of the variete work he is doing!)
No need for API or doc changes for now, this would come along with the Xarray accessor which would wrap these functionalities (#383, #446), as described above. I tried to test things as exhaustively as possible for a controlled input, checks on input types and raised errors/warnings will live elsewhere with the wrapper functions (so in a later PR).

For tests, monitoring the memory usage during the .compute() call was especially tricky, but I'm happy we have quantitative checks that ensure the memory usage remains low (for potential future changes)! It works decently well with the Dask memusage plugin, the drawback is the errors raised during teardown (closing) of the Dask client. I silence some but others remain in the background and pollute a bit the pytest output. I looked around quite a bit (especially dask/distributed#3540 is interesting) and couldn't find a more elegant solution... I think the Plugin class of dask-memusage would have to be modified to close properly, but this is out of my expertise (I tried a bit without success, as it's a very short code: https://github.com/itamarst/dask-memusage/blob/master/dask_memusage.py).

Finally, the remaining issue is the discrepancy comparing the delayed reproject and the one on the full array. I thought that was an error in my implementation first, but after a lot of digging it seems related to rasterio.warp.reproject that doesn't give the same output when reprojecting to a different subset of a raster, even if aligned with the original output raster and with the exact same input raster. The errors can get quite big for bilinear/cubic (it's a bit worrying for some analysis, could mess up coregistration quite a bit for instance), I hope I made a mistake somewhere.
I continued an issue and added a reproducible example here: rasterio/rasterio#2052.

I have also opened issues or continued discussions in Rioxarray (maybe moving the Dask reprojection there? corteva/rioxarray#119 (comment)), Xarray (maybe covering the case of regular grid interpolation directly there? pydata/xarray#6799 (comment)) and Dask (maybe having a dask.array.sample(..., ignore_nan=) function mirroring their dask.dataframe.sample()? dask/dask#11077).

@rhugonnet rhugonnet requested review from erikmannerfelt and removed request for ameliefroessl April 27, 2024 21:46
Copy link

@ameliefroessl ameliefroessl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really nice tests!! Looking very good :)

geoutils/raster/delayed.py Outdated Show resolved Hide resolved
geoutils/raster/delayed.py Outdated Show resolved Hide resolved
geoutils/raster/delayed.py Show resolved Hide resolved
geoutils/raster/delayed.py Show resolved Hide resolved
geoutils/raster/delayed.py Show resolved Hide resolved
geoutils/raster/delayed.py Show resolved Hide resolved
geoutils/raster/delayed.py Outdated Show resolved Hide resolved
geoutils/raster/delayed.py Show resolved Hide resolved
geoutils/raster/delayed.py Show resolved Hide resolved
geoutils/raster/delayed.py Show resolved Hide resolved
geoutils/raster/delayed.py Outdated Show resolved Hide resolved
geoutils/raster/delayed.py Outdated Show resolved Hide resolved

# Map depth of overlap required for each interpolation method
# TODO: Double-check this window somewhere in SciPy's documentation
map_depth = {"nearest": 1, "linear": 2, "cubic": 3, "quintic": 5}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, this should be defined outside the function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why, it's not needed for tests, and I don't see any other use?

geoutils/raster/delayed.py Outdated Show resolved Hide resolved
geoutils/raster/delayed.py Outdated Show resolved Hide resolved
geoutils/raster/delayed.py Outdated Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants