
How to avoid nodata-only patches #1330

Open
adamjstewart opened this issue May 12, 2023 · 21 comments · May be fixed by #1881
Labels
datamodules PyTorch Lightning datamodules datasets Geospatial or benchmark datasets samplers Samplers for indexing datasets

Comments

@adamjstewart
Collaborator

adamjstewart commented May 12, 2023

Summary

Large scenes are often rotated due to CRS like so:

[Image: LANDSAT_1million_20170531]

When using our GeoSamplers, we often sample from the nodata regions around the edges.

Rationale

Sampling these nodata-only patches results in slower I/O and slower GPU training, despite these patches contributing nothing to the model.

Implementation

There are two places where we could potentially improve performance.

I/O

Best case scenario would be to avoid sampling these regions entirely. In #449, someone tried adding a check to the sampler that actually loads the patch, checks if it's entirely nodata pixels, and skips it if so. Unfortunately, this doesn't seem to work in parallel, and is slow since it needs to load each patch twice.

If there was a way to load the image in its native CRS (i.e., a square with no nodata pixels in the orientation it was taken in), this would solve all of our problems. I don't know of a way to do this.

GPU

This is actually easier to solve. We could add a feature to our data module base classes that removes all nodata-only images from each mini-batch inside transfer_batch_to_device or another step. This would result in variable batch sizes, but I don't think that's an issue.
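For illustration, a minimal sketch of that idea (not an existing TorchGeo feature; the "image" key and the assumption that nodata is encoded as 0 are placeholders):

import torch
from lightning.pytorch import LightningDataModule

class DropNodataDataModule(LightningDataModule):
    # Hedged sketch: runs after the batch has been moved to the device; drops
    # samples whose image tensor is entirely nodata (assumed to be 0 here).
    # Non-tensor entries are passed through unchanged in this sketch.
    def on_after_batch_transfer(self, batch, dataloader_idx):
        keep = (batch["image"] != 0).flatten(1).any(dim=1)
        return {
            key: value[keep] if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }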

Alternatives

No response

Additional information

This is a highly requested feature:

@adamjstewart adamjstewart added datasets Geospatial or benchmark datasets samplers Samplers for indexing datasets datamodules PyTorch Lightning datamodules labels May 12, 2023
@adamjstewart
Collaborator Author

https://gdal.org/programs/gdal_footprint.html looks very promising. It seems like it's possible to access this information. It's unclear how fast this would be or how we could make use of it in our R-tree.
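For reference, a hedged sketch of invoking it (requires GDAL >= 3.8; the paths are placeholders, not files from this issue):

import subprocess

# Compute the valid-data footprint of a raster with the gdal_footprint CLI.
subprocess.run(
    ["gdal_footprint", "-of", "GeoJSON", "scene.tif", "footprint.geojson"],
    check=True,
)
# The resulting GeoJSON contains (multi)polygons of the valid pixels, which
# could then be loaded with geopandas/fiona and attached to the index somehow.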

@maawoo

maawoo commented Jan 8, 2024

Hi! I just happened to come across this issue and wanted to add this as a relevant resource:
https://www.element84.com/geospatial/the-stactools-raster-footprint-utility/

@johnnv1

johnnv1 commented Feb 11, 2024

Personally, at the moment, I'm using read_masks before trying to load band data in order to skip nodata patches; that has been my "on the fly" approach. But I agree this isn't the best choice: even though reading the mask patch is much more lightweight (it's a byte array), it's still an extra read for every patch.


https://gdal.org/programs/gdal_footprint.html looks very promising. It seems like it's possible to access this information. It's unclear how fast this would be or how we could make use of it in our R-tree.

From the link: "The gdal_footprint utility can be used to compute the footprint of a raster file, taking into account nodata values (or more generally the mask band attached to the raster bands), and generating polygons/multipolygons corresponding to areas where pixels are valid, and write to an output vector file."

I believe it is something like:

  1. open raster - rasterio.open(...) as src
  2. load the nodata mask - src.read_masks(...)
  3. transform the mask into a polygon - cv2 contours? rasterio.feature?
  4. [optional here] transform the polygon coordinate to be based on the CRS

I'm not entirely familiar with the torchgeo codebase. Do the samplers have access to the raster itself before creating the sampler window/ROI? Perhaps an option is: for each raster, load the nodata mask once, create this polygon, and then make it possible for the sampler to only work within the "valid" region.


(🤔 Reading the nodata mask (the byte array) patch by patch is probably not a good option: the I/O overhead of doing that for every sample over the entire raster will take far longer than reading it all at once.)
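For illustration, a minimal sketch of steps 1-4 above using rasterio and shapely (using dataset_mask() rather than read_masks() to get a single combined mask; the helper name is made up):

import rasterio
from rasterio.features import shapes
from shapely.geometry import shape
from shapely.ops import unary_union

def valid_data_footprint(filepath):
    # 1./2. open the raster and load the combined nodata mask (uint8: 0 = nodata)
    with rasterio.open(filepath) as src:
        mask = src.dataset_mask()
        transform = src.transform
    # 3./4. polygonize the valid pixels; with the transform, the coordinates
    # are already in the raster's own CRS
    polygons = [
        shape(geom)
        for geom, value in shapes(mask, mask=mask > 0, transform=transform)
        if value > 0
    ]
    return unary_union(polygons)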

@adamjstewart
Collaborator Author

We tried something similar to this before (reading files inside the sampler) but rasterio had issues with parallel locking. Did you not encounter any issues with your solution when using multiple workers?

@johnnv1

johnnv1 commented Feb 11, 2024

I'm thinking of something like loading the nodata information at the dataset level, then just accessing it instead of each sampler looking into the files.


Edit:

As I said, I am not familiar with the code base

Assuming the dataset is a group of samples... what I'm thinking is:

When the first access occurs (or at sampling time), or at the indexing stage:

  • open the raster
  • retrieve the nodata mask as a boolean array
  • close the raster
  • compute the nodata polygon
  • store the polygon (or an oriented bounding box) on the dataset, keyed by index. It would just be an extra metadata property for each sample.

Then, when we go to sample item N of the dataset, we just access the Nth polygon/bounding box if available. I don't see it causing a memory issue if an oriented bounding box is used to store it, and it could be cached somehow for giant datasets.
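For illustration, a hedged sketch of reducing a valid-data footprint polygon to an oriented bounding box with shapely (the triangle is just a stand-in for a real footprint):

from shapely.geometry import Polygon

footprint = Polygon([(0, 0), (100, 10), (90, 110)])  # stand-in footprint polygon
obb = footprint.minimum_rotated_rectangle            # oriented bounding box (Polygon)
corners = list(obb.exterior.coords)[:-1]             # four corners to store per sample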

I don't know if the CRS of each raster is stored here, but it would be the same idea as a multi-CRS dataset, where we need to handle it anyway.

@adamjstewart
Collaborator Author

The problem is that R-tree does not support oriented bounding boxes. We would have to replace R-tree with something else. For the record, I'm completely fine with doing this, just don't know what else would work well for this use case.

Also, note that this will make instantiating the dataset extremely slow because it needs to read every single file just to populate the index. We'll have to benchmark this to see if it's better or worse than doing it in the sampler. We may be able to parallelize this with multiprocessing though.
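For illustration, a hedged sketch of that parallelization (valid_data_footprint is the illustrative helper sketched earlier in this thread; filepaths is a placeholder list of raster paths):

from concurrent.futures import ProcessPoolExecutor

filepaths = ["scene1.tif", "scene2.tif"]  # placeholder paths
with ProcessPoolExecutor() as pool:
    footprints = dict(zip(filepaths, pool.map(valid_data_footprint, filepaths)))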

@johnnv1

johnnv1 commented Feb 11, 2024

Yeah, it's probably best to do this the first time we try to access the sample rather than when indexing the dataset... since it's using rtree, I believe it should be relying on some of rtree's functionality... a couple of doubts:

  • doesn't rtree support storing a generic single parameter/value, e.g. the theta value?
  • isn't just storing theta in a list/tensor an option?

@adamjstewart
Collaborator Author

We can store theta, but we can't check for overlap using theta. We would have to write our own check to find the valid regions where we can sample from.
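For illustration, a hedged sketch of what such a custom check could look like with shapely (none of these names are existing torchgeo APIs):

from shapely.affinity import rotate
from shapely.geometry import box

def rotated_bounds_intersect(minx, maxx, miny, maxy, theta_deg, query):
    # Rebuild the rotated tile rectangle from its axis-aligned bounds plus theta,
    # then test it against the (axis-aligned) query box.
    tile = rotate(box(minx, miny, maxx, maxy), theta_deg, origin="center")
    query_box = box(query.minx, query.miny, query.maxx, query.maxy)
    return tile.intersects(query_box)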

@johnnv1

johnnv1 commented Feb 13, 2024

Yes, which I believe leads to the same case/function as #1190

@adriantre
Contributor

Geosampler outside of valid data mask

There are currently two situations where nodata will be sampled.

  • Datasets with multiple CRS, as already mentioned in this issue
  • Datasets with "internal" nodata, like Sentinel-2, because each satellite pass is cropped into a fixed grid (MGRS); see the example above

Reprojection will not solve point two.

A possible solution for both cases:
Step 1. Get (quick) access to the footprint of valid pixels, or have it precomputed/available.
Step 2. Verify actual intersection with the footprint when fetching samples in RasterDataset.__getitem__, after finding hits (filepaths).

For Step 1
We could start by adding support for user-provided footprints, e.g. an attribute override in RasterDataset that retrieves these footprints, either from a separate file or from metadata in the raster.

Example on how to retrieve footprint in Sentinel-2

import rasterio
from shapely import wkt

with rasterio.open("/<product_id>.SAFE/MTD_MSIL1C.xml") as src:
    valid_pixels_footprint = wkt.loads(src.tags()['FOOTPRINT'])

For Step 2
A sample whose footprint is found not to intersect with the query would need to be replaced with another sample. Is there a better way to do that than collate_fn?

I see this as the simplest solution. Alternatives like replacing rtree as the index are a bigger task that should be viewed in relation to #409, in my opinion.

@adamjstewart
Collaborator Author

Step 1: It's still not clear to me how this valid footprint would be used. Also, I would prefer a solution that works for all datasets, not just Sentinel-2
Step 2: This could also be done in the sampler (although may have issues with multiprocessing) or in the LightningDataModule. See the GPU Implementation I suggested in my original post for the latter. Downside is that we're still doing the I/O.

@adriantre
Contributor

adriantre commented Feb 13, 2024

Step 1: It's still not clear to me how this valid footprint would be used.

Here is an example of how, but I don't know where in the code it belongs. You would need access to the query representing the bounds of the sample/patch, so my logic runs in __getitem__.

Also, I would prefer a solution that works for all datasets, not just Sentinel-2

One solution, as you have mentioned, is to use gdal_footprint and precompute the footprints for all images: store them to some file and read them back somewhere in the flow, e.g. in RasterDataset (ideally cached, or read once per raster).

import rasterio
from shapely import wkt
from shapely.geometry import box

def extract_footprint(filepath):
    # Sentinel-2 example showing how to find the footprint
    # In the general case, could run gdal_footprint beforehand and save to some file
    with rasterio.open(filepath) as src:
        valid_pixels_footprint = wkt.loads(src.tags()['FOOTPRINT'])
        # reproject to the dataset CRS here if it differs
    return valid_pixels_footprint

def query_intersects_with_footprint(query, filepath):
    valid_pixels_footprint = extract_footprint(filepath)
    bbox = box(query.minx, query.miny, query.maxx, query.maxy)
    return bbox.intersects(valid_pixels_footprint)


class RasterDataset(GeoDataset):
    def __getitem__(self, query: BoundingBox):
        hits = self.index.intersection(tuple(query), objects=True)
        filepaths = cast(list[str], [hit.object for hit in hits])

        # drop filepaths whose valid-pixel footprint does not intersect the sample bbox
        filepaths = [
            path for path in filepaths
            if query_intersects_with_footprint(query, path)
        ]

        if not filepaths:
            return None

        # ... rest of existing method returns samples
        # then collate_fn can replace/remove None

def concat_samples_replace_none(samples):
    """
    Based on https://stackoverflow.com/questions/57815001/pytorch-collate-fn-reject-sample-and-yield-another
    """
    len_original = len(samples)  # original batch length
    samples = [sample for sample in samples if sample is not None]  # filter out the Nones

    if len(samples) < len_original:
        # if there are samples missing, just reuse existing members;
        # doesn't work if every sample in a batch is rejected
        diff = len_original - len(samples)
        samples = samples + samples[:diff]

    return concat_samples(samples)  # original collate_fn used by IntersectionDataset

@adamjstewart
Collaborator Author

"the query representing the bounds of the sample/patch" is created in the sampler. If we can decide whether or not it's a valid location to sample from before getting to the __getitem__, we can avoid loading the data or ask for new locations instead.

@adriantre
Contributor

Makes sense! Something like this?

RasterDataset still needs to do the same check, since the query provided by GeoSampler does not tell __getitem__ which hit/filepath to read from, right?

class RandomGeoSampler(GeoSampler):
    def __iter__(self) -> Iterator[BoundingBox]:
        for _ in range(len(self)):
            # Choose a random tile, weighted by area
            idx = torch.multinomial(self.areas, 1)
            hit = self.hits[idx]
            bounds = BoundingBox(*hit.bounds)

            # Choose a random index within that tile
            bounding_box = get_random_bounding_box(bounds, self.size, self.res)

            if not query_intersects_with_footprint(bounding_box, hit.object):
                # this bounding_box is outside of the valid-pixel footprint of the raster
                continue

            yield bounding_box

@adamjstewart
Collaborator Author

We don't need to do it in RasterDataset because all hits will be merged (stitched together), so if the sampler says the query is valid, it's valid.

Your implementation of query_intersects_with_footprint still uses a (possibly rotated) bounding box. Doesn't that suffer from the exact same issue as the original issue?

@adriantre
Contributor

adriantre commented Feb 13, 2024

The query/bounding_box returned by GeoSampler is always rectified (non-rotated) and in the common CRS of the dataset, right?

Assuming the above is correct, my query_intersects_with_footprint checks the intersection between this rectified box and the reprojected footprint.

import rasterio
from geopandas import GeoSeries
from shapely import wkt
from shapely.geometry import box


def extract_footprint(filepath):
    with rasterio.open(filepath) as src:
        valid_pixels_footprint = wkt.loads(src.tags()['FOOTPRINT'])
    return valid_pixels_footprint

def query_intersects_with_footprint(filepath, query, common_crs):
    # Reproject the vector footprint to the same CRS the GeoSampler grid is based on,
    # using geopandas for simplicity
    valid_pixels_footprint_reprojected = (
        GeoSeries([extract_footprint(filepath)], crs=4326)
        .to_crs(common_crs)
    )

    bbox = box(query.minx, query.miny, query.maxx, query.maxy)

    return valid_pixels_footprint_reprojected.intersects(bbox).all()  # all() to get a scalar bool from the GeoSeries

@adriantre
Contributor

adriantre commented Feb 13, 2024

So this approach needs access to the common CRS used by RasterDataset, and some way to map from the query to a file where it can read the footprint vector.

We could also compare against the unary_union of all footprints for all hits.

@adamjstewart
Collaborator Author

The sampler is given the dataset index, which is already in a common CRS, no need to warp anything yourself. The problem is that images are rotated with respect to almost any CRS, and have significant nodata pixels around the border.

I think I understand your implementation better now. You're not checking the bbox of the image, but of the patch. This should work. I'm just not sure how fast gdal_footprint will be. I've been meaning to add a new subcommand to the torchgeo CLI to help with benchmarking I/O rates before and after changes like this. This might be a good time to implement it.

If you can get gdal_footprint working, even without caching, I would love to review a PR for it. If it's slow, we can make it optional.

@adriantre
Contributor

The sampler is given the dataset index, which is already in a common CRS, no need to warp anything yourself. The problem is that images are rotated with respect to almost any CRS, and have significant nodata pixels around the border.

Yes, I think we are trying to explain the same thing here.

I pushed a minimal working (🤞) example that fixes this for RandomGeoSampler. Instead of using gdal_footprint I opted for a run-time equivalent in rasterio.

@adamjstewart
Collaborator Author

I'll try to review when I get a chance. I'll likely need to implement an I/O benchmarking subcommand to see how much this affects I/O rates before merging, which I won't have time to get to until March. So don't hold your breath, but I promise I'm interested and will review in detail as soon as I can.

@adamjstewart
Collaborator Author

Update: we actually don't need a new I/O benchmarking subcommand; Lightning has built-in support for this: https://lightning.ai/docs/pytorch/stable/tuning/profiler.html
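For reference, a minimal sketch of turning the profiler on (task and datamodule stand for any existing TorchGeo LightningModule/LightningDataModule; "simple" and "advanced" are the documented profiler shortcuts):

from lightning.pytorch import Trainer

# Hedged sketch: the built-in profiler reports time per hook, which separates
# data-loading (I/O) time from time spent on the GPU.
trainer = Trainer(profiler="simple", max_epochs=1)
trainer.fit(task, datamodule=datamodule)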

So all we really need is to:

  1. Decide on one or more exemplary datasets
  2. Add a data module to TorchGeo
  3. Add documentation to our Contributing guide on which flags to use

For 1, in our preliminary TorchGeo paper, we sampled 100 random Landsat scenes and one CDL map, each in a different CRS. There are a million things we can play around with (COGs, block size, resolution, CRS, etc.). We may want to develop a list of multiple options:

  1. Worst case: all files in different CRS, different resolution, heavily compressed, large block size, nodata pixels
  2. Average case: all files in different CRS, COGs, nodata pixels (this is what we used in our paper)
  3. Best case: all files in same CRS/res, no nodata pixels

Essentially, we would be developing a set of benchmark datasets not for benchmarking models, but for benchmarking I/O. I wonder if such a thing already exists. This might actually make for an interesting paper if it doesn't exist. Let me ask around with the GDAL folks.

But anyway, for your contribution, a single dataset should be sufficient. Initially I framed this from the perspective of "as long as your PR doesn't make things significantly slower, it's fine". However, after more thought, it's possible your implementation is actually significantly faster, as it allows us to skip many regions we would have otherwise sampled. So when benchmarking I/O, we should definitely take this into account. I.e., the best measure of I/O speeds is how long it takes GridGeoSampler to iterate over the entire dataset, not how long it takes for a specific number of patches to be loaded.
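For illustration, a hedged sketch of that measurement (dataset stands for any RasterDataset; the patch size and stride are placeholders):

import time

from torch.utils.data import DataLoader
from torchgeo.datasets import stack_samples
from torchgeo.samplers import GridGeoSampler

# Time a full pass of GridGeoSampler over the dataset, which is the I/O
# measure described above.
sampler = GridGeoSampler(dataset, size=256, stride=256)
loader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)

start = time.perf_counter()
n_patches = sum(batch["image"].shape[0] for batch in loader)
print(f"{n_patches} patches in {time.perf_counter() - start:.1f} s")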
