Skip to content

Commit

Permalink
TYPE: Standardize resampling type (#1571)
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 committed Apr 2, 2024
1 parent 2a4aefe commit 9561764
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 16 deletions.
9 changes: 5 additions & 4 deletions datacube/api/core.py
Expand Up @@ -18,6 +18,7 @@
from datacube.storage import reproject_and_fuse, BandInfo
from datacube.utils import ignore_exceptions_if
from odc.geo import CRS, yx_, res_, resyx_, Resolution, XY
from odc.geo.warp import Resampling
from odc.geo.xr import xr_coords
from datacube.utils.dates import normalise_dt
from odc.geo.geom import intersects, box, bbox_union, Geometry
Expand Down Expand Up @@ -244,7 +245,7 @@ def load(self,
measurements: str | list[str] | None = None,
output_crs: Any = None,
resolution: int | float | tuple[int | float, int | float] | Resolution | None = None,
resampling: str | dict[str, str] | None = None,
resampling: Resampling | dict[str, Resampling] | None = None,
align: XY[float] | Iterable[float] | None = None,
skip_broken_datasets: bool = False,
dask_chunks: dict[str, str | int] | None = None,
Expand Down Expand Up @@ -878,7 +879,7 @@ def _cbk(*ignored):
@staticmethod
def load_data(sources: xarray.DataArray, geobox: GeoBox,
measurements: Mapping[str, Measurement] | list[Measurement],
resampling: str | dict[str, str] | None = None,
resampling: Resampling | dict[str, Resampling] | None = None,
fuse_func: FuserFunction | Mapping[str, FuserFunction | None] | None = None,
dask_chunks: dict[str, str | int] | None = None,
skip_broken_datasets: bool = False,
Expand Down Expand Up @@ -969,7 +970,7 @@ def __exit__(self, type_, value, traceback):


def per_band_load_data_settings(measurements: list[Measurement] | Mapping[str, Measurement],
resampling: str | Mapping[str, str] | None = None,
resampling: Resampling | Mapping[str, Resampling] | None = None,
fuse_func: FuserFunction | Mapping[str, FuserFunction | None] | None = None
) -> list[Measurement]:
def with_resampling(m, resampling, default=None):
Expand All @@ -982,7 +983,7 @@ def with_fuser(m, fuser, default=None):
m['fuser'] = fuser.get(m.name, default)
return m

if isinstance(resampling, str):
if resampling is not None and not isinstance(resampling, dict):
resampling = {'*': resampling}

if fuse_func is None or callable(fuse_func):
Expand Down
3 changes: 2 additions & 1 deletion datacube/storage/_load.py
Expand Up @@ -23,6 +23,7 @@
from odc.geo.geobox import GeoBox
from odc.geo.roi import roi_is_empty
from odc.geo.xr import xr_coords
from odc.geo.warp import Resampling
from datacube.model import Measurement
from datacube.drivers._types import ReaderDriver
from ..drivers.datasource import DataSource
Expand All @@ -47,7 +48,7 @@ def reproject_and_fuse(datasources: List[DataSource],
destination: np.ndarray,
dst_geobox: GeoBox,
dst_nodata: Optional[Union[int, float]],
resampling: str = 'nearest',
resampling: Resampling = 'nearest',
fuse_func: Optional[FuserFunction] = None,
skip_broken_datasets: bool = False,
progress_cbk: Optional[ProgressFunction] = None,
Expand Down
10 changes: 7 additions & 3 deletions datacube/utils/cog.py
Expand Up @@ -18,6 +18,7 @@
from .io import check_write_path
from odc.geo.geobox import GeoBox
from odc.geo.math import align_up
from odc.geo.warp import Resampling, resampling_s2rio

from deprecat import deprecat

Expand All @@ -38,7 +39,7 @@ def _write_cog(
nodata: Optional[float] = None,
overwrite: bool = False,
blocksize: Optional[int] = None,
overview_resampling: Optional[str] = None,
overview_resampling: Optional[Resampling] = None,
overview_levels: Optional[List[int]] = None,
ovr_blocksize: Optional[int] = None,
use_windowed_writes: bool = False,
Expand Down Expand Up @@ -118,7 +119,10 @@ def _write_cog(
fname, overwrite
) # aborts if overwrite=False and file exists already

resampling = rasterio.enums.Resampling[overview_resampling]
if isinstance(overview_resampling, str):
resampling = resampling_s2rio(overview_resampling)
else:
resampling = overview_resampling

if (blocksize % 16) != 0:
warnings.warn("Block size must be a multiple of 16, will be adjusted")
Expand Down Expand Up @@ -219,7 +223,7 @@ def write_cog(
overwrite: bool = False,
blocksize: Optional[int] = None,
ovr_blocksize: Optional[int] = None,
overview_resampling: Optional[str] = None,
overview_resampling: Optional[Resampling] = None,
overview_levels: Optional[List[int]] = None,
use_windowed_writes: bool = False,
intermediate_compression: Union[bool, str, Dict[str, Any]] = False,
Expand Down
3 changes: 3 additions & 0 deletions docs/about/whats_new.rst
Expand Up @@ -8,6 +8,9 @@ What's New
v1.9.next
=========

- Standardize resampling input supported to `odc.geo.warp.Resampling`.


v1.9.0-rc3 (27th March 2024)
============================

Expand Down
28 changes: 20 additions & 8 deletions tests/storage/test_storage_read.py
Expand Up @@ -4,6 +4,9 @@
# SPDX-License-Identifier: Apache-2.0
import numpy as np

import pytest
from rasterio.enums import Resampling

from datacube.storage._read import (
read_time_slice,
read_time_slice_v2,
Expand All @@ -27,14 +30,20 @@
)


nearest_resampling_parametrize = pytest.mark.parametrize(
"nearest_resampling", ['nearest', Resampling.nearest, Resampling.nearest.value]
)


def test_pick_read_scale():
assert pick_read_scale(0.7) == 1
assert pick_read_scale(1.3) == 1
assert pick_read_scale(2.3) == 2
assert pick_read_scale(1.99999) == 2


def test_read_paste(tmpdir):
@nearest_resampling_parametrize
def test_read_paste(nearest_resampling, tmpdir):
from datacube.testutils import mk_test_image
from datacube.testutils.io import write_gtiff
from pathlib import Path
Expand All @@ -46,7 +55,7 @@ def test_read_paste(tmpdir):

mm = write_gtiff(pp/'tst-read-paste-128x64-int16.tif', xx, nodata=None)

def _read(geobox, resampling='nearest',
def _read(geobox, resampling=nearest_resampling,
fallback_nodata=-999,
dst_nodata=-999,
check_paste=False):
Expand Down Expand Up @@ -112,7 +121,8 @@ def _read(geobox, resampling='nearest',
np.testing.assert_array_equal(xx[1::2, 1::2], yy)


def test_read_with_reproject(tmpdir):
@nearest_resampling_parametrize
def test_read_with_reproject(nearest_resampling, tmpdir):
from datacube.testutils import mk_test_image
from datacube.testutils.io import write_gtiff
from pathlib import Path
Expand All @@ -131,7 +141,7 @@ def test_read_with_reproject(tmpdir):
assert mm.geobox == tile

def _read(geobox,
resampling='nearest',
resampling=nearest_resampling,
fallback_nodata=None,
dst_nodata=-999):
with RasterFileDataSource(mm.path, 1, nodata=fallback_nodata).open() as rdr:
Expand Down Expand Up @@ -171,7 +181,8 @@ def _read(geobox,
assert nvalid > nempty


def test_read_paste_v2(tmpdir):
@nearest_resampling_parametrize
def test_read_paste_v2(nearest_resampling, tmpdir):
from datacube.testutils import mk_test_image
from datacube.testutils.io import write_gtiff
from datacube.testutils.iodriver import open_reader
Expand All @@ -184,7 +195,7 @@ def test_read_paste_v2(tmpdir):

mm = write_gtiff(pp/'tst-read-paste-128x64-int16.tif', xx, nodata=None)

def _read(geobox, resampling='nearest',
def _read(geobox, resampling=nearest_resampling,
fallback_nodata=-999,
dst_nodata=-999,
check_paste=False):
Expand Down Expand Up @@ -256,7 +267,8 @@ def _read(geobox, resampling='nearest',
np.testing.assert_array_equal(xx[1::2, 1::2], yy)


def test_read_with_reproject_v2(tmpdir):
@nearest_resampling_parametrize
def test_read_with_reproject_v2(nearest_resampling, tmpdir):
from datacube.testutils import mk_test_image
from datacube.testutils.io import write_gtiff
from datacube.testutils.iodriver import open_reader
Expand All @@ -268,7 +280,7 @@ def test_read_with_reproject_v2(tmpdir):
assert (xx != -999).all()
tile = AlbersGS.tile_geobox((17, -40))[:64, :128]

def _read(geobox, resampling='nearest',
def _read(geobox, resampling=nearest_resampling,
fallback_nodata=-999,
dst_nodata=-999):

Expand Down

0 comments on commit 9561764

Please sign in to comment.