Skip to content

Commit

Permalink
Add tests for Raster.get_mask (#529)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhugonnet committed Mar 28, 2024
1 parent 02692f5 commit 914ddf3
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
10 changes: 5 additions & 5 deletions geoutils/raster/raster.py
Expand Up @@ -2136,14 +2136,14 @@ def get_nanarray(self, return_mask: bool = False) -> NDArrayNum | tuple[NDArrayN

def get_mask(self) -> NDArrayBool:
"""
Get mask from the raster.
The mask is always returned as a boolean array, even if there is no mask to .data and thus .data.mask = a
single False value, nomask property of masked arrays.
Get mask of invalid values from the raster.
If the raster is not loaded, reads only the mask from disk to optimize memory usage.
:return:
The mask is always returned as a boolean array, even if there is no mask associated to .data (nomask property
of masked arrays).
:return: The mask of invalid values in the raster.
"""
# If it is loaded, use NumPy's getmaskarray function to deal with False values
if self.is_loaded:
Expand Down
18 changes: 18 additions & 0 deletions tests/test_raster.py
Expand Up @@ -357,6 +357,24 @@ def test_load_only_mask(self, example: str) -> None:
assert not r_notloaded.is_loaded
assert np.array_equal(mask_notloaded, mask_loaded)

@pytest.mark.parametrize("example", [landsat_b4_path, aster_dem_path, landsat_rgb_path]) # type: ignore
def test_get_mask(self, example: str) -> None:
"""
Test that getting mask works properly (similar to _load_only_mask).
"""

# Load raster with and without loading
r_loaded = gu.Raster(example, load_data=True)
r_notloaded = gu.Raster(example)

# Get the mask for the two options
mask_loaded = r_loaded.get_mask()
mask_notloaded = r_notloaded.get_mask()

# Data should not be loaded and masks should be equal
assert not r_notloaded.is_loaded
assert np.array_equal(mask_notloaded, mask_loaded)

@pytest.mark.parametrize("example", [landsat_b4_path, aster_dem_path]) # type: ignore
def test_to_rio_dataset(self, example: str):
"""Test the export to a rasterio dataset"""
Expand Down

0 comments on commit 914ddf3

Please sign in to comment.