
Add support for Lightning Streaming Dataset #1886

Open
tchaton opened this issue Feb 16, 2024 · 15 comments
Labels
datasets Geospatial or benchmark datasets

Comments

@tchaton

tchaton commented Feb 16, 2024

Summary

Dear people of TorchGeo,

I am hearing more folks in the community training in the cloud and struggling with their data streaming solution.

I have been working on this new framework: https://github.com/Lightning-AI/pytorch-lightning/tree/master/src/lightning/data.

I wondered if you would be interested to give it a spin.

Rationale

Make it easier for people to consume datasets in the cloud

Implementation

No response

Alternatives

No response

Additional information

No response

@adamjstewart
Collaborator

@robmarkcole has been exploring this and may be able to comment on how feasible it is to integrate this into TorchGeo.

@adamjstewart adamjstewart added the datasets Geospatial or benchmark datasets label Feb 17, 2024
@robmarkcole
Contributor

It is pretty straightforward in 2 steps: (1) process the dataset into the required binary format and host it on a cloud (I am using AWS; TBC which cloud would host these public datasets), then (2) each dataset has a complementary streaming version.

@adamjstewart
Collaborator

each dataset then has a complementary streaming version

This is what I'm trying to avoid. I don't want 2+ copies of all 75+ datasets in TorchGeo. I'm fine with adding a new subclass of GeoDataset or RasterDataset though.

@robmarkcole
Contributor

I'm not sure there is a way around it. In my implementation the data module accepts a 'streaming' arg and returns either the regular or the streaming dataset.

@tchaton
Author

tchaton commented Feb 17, 2024

Yes, unfortunately, having a copy is the only way to make things fast. This is maybe something we could help with. I will come back to you.

@isaaccorley
Collaborator

The fMoW dataset would be a good use case for this. It's hosted on S3 and is already preprocessed into individual image patches instead of larger tiles.

@adriantre
Contributor

Does Lightning Streaming Dataset take into account random reading of subregions within a file? In geospatial data formats there are some caveats to this when choosing between small file size (.jp2) and fast random reads (.geotiff).

@tchaton
Author

tchaton commented Feb 20, 2024

Hey @adriantre. Right now, the entire file is downloaded and the window is applied locally. However, I think it would be interesting to add slicing on S3 directly if possible.

@adriantre
Contributor

I see! When training with torchgeo with, say, batch_size = 8, it is common that each slice/sample is read from 8 different images. Each image is commonly 1 GB+ in size. The next batch is then read from either new images or the same ones.

Indeed, slicing on S3 directly is possible (for GDAL-compliant data formats).

@tchaton
Author

tchaton commented Feb 21, 2024

Hey @adriantre, thanks. That's fascinating. I will have a look and see if we can add support for window fetching directly. This would be super neat. Do you know how they do it under the hood? Otherwise, I will investigate.

But yeah, for the time being, this provides value when the files are smaller than the chunk size, e.g. when the dataset has already been pre-processed into smaller tiles.

@adamjstewart
Collaborator

When training using torchgeo, say batch_size = 8, it is common that each slice/sample is read from 8 different images.

This is only true if you use RandomGeoSampler. We implemented RandomBatchGeoSampler for this exact scenario. With the latter, each mini-batch will only consist of patches from a single image.

But I agree that windowed-reading support (such as implemented by GDAL/rasterio) within S3 would be nice for streaming.
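To make the sampling idea concrete, here is a plain-Python sketch of a batch sampler where every patch in a mini-batch comes from a single image. This illustrates the concept behind RandomBatchGeoSampler only; it is not torchgeo's actual API, and all names are made up for illustration:

```python
import random

def random_batch_windows(images, patch_size, batch_size, num_batches, seed=0):
    """Yield mini-batches of (image_id, row, col) patch origins, where every
    patch in a batch comes from the same image, so each batch opens one file."""
    rng = random.Random(seed)
    for _ in range(num_batches):
        # Pick one image per batch, then draw all patch origins from it.
        image_id, (height, width) = rng.choice(sorted(images.items()))
        yield [
            (image_id,
             rng.randrange(height - patch_size + 1),
             rng.randrange(width - patch_size + 1))
            for _ in range(batch_size)
        ]

# Two large "images" described only by their pixel dimensions.
images = {"scene_a.tif": (2048, 2048), "scene_b.tif": (1024, 4096)}
for batch in random_batch_windows(images, patch_size=256, batch_size=8, num_batches=2):
    assert len({image_id for image_id, _, _ in batch}) == 1  # one file per batch
```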

@adriantre
Contributor

Hey @adriantre, thanks. That's fascinating. I will have a look and see if we can add support for window fetching directly. This would be super neat. Do you know how they do it under the hood ? Otherwise, I will investigate.

To my understanding, it is enabled by the file formats dividing and storing big files as "smaller files" called blocks. The blocks can be accessed in parallel threads and can be accessed efficiently without scanning through the whole file, kind of like an index in a database (R-tree). The block size is an optimisation parameter that influences random-read speed, write speed, and file size, and the per-file-format driver knows how to utilise this. That's as far as my understanding goes though 😅

Might be some more in-depth info here:
http://ikcest-drr.osgeo.cn/tutorial/k1072
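The block arithmetic described above can be sketched in a few lines. Assuming square blocks of a fixed size (a simplification of what format drivers actually do), this computes which blocks a pixel window overlaps:

```python
def blocks_for_window(col_off, row_off, width, height, block_size=512):
    """Return the (block_row, block_col) indices a pixel window overlaps.
    A reader of a tiled format only needs to fetch these blocks, not the
    whole file."""
    first_col = col_off // block_size
    last_col = (col_off + width - 1) // block_size
    first_row = row_off // block_size
    last_row = (row_off + height - 1) // block_size
    return [(r, c)
            for r in range(first_row, last_row + 1)
            for c in range(first_col, last_col + 1)]

# A 256x256 patch that straddles a block boundary touches 4 blocks.
print(blocks_for_window(400, 400, 256, 256))  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```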

@tchaton
Author

tchaton commented Feb 22, 2024

Interesting, similar to what the streaming dataset does with its chunks. However, I am chatting with the AWS team about supporting a more native solution.

@adriantre
Contributor

What I did not mention, and which torchgeo to some degree relies on, is that the drivers for reading these datasets have wrappers that let us read windows from a dataset by specifying geospatial coordinates (not only pixel bounds). The returned dataset handle also lets us reproject the pixels to a desired spatial reference system and resolution.

I think it is worth keeping this in mind for the implementation in Lightning Streaming Dataset, to lower the barrier for torchgeo to adapt it to rasterio/GDAL.

@robmarkcole
Contributor

I've now created a couple of streaming datasets and have a common class that can handle them all. It just requires that the datasets are formatted in a standard way as a dict, which is actually the same format required for transforms anyway:

from typing import Any, Callable, Optional

import torch
from litdata import StreamingDataset
from rasterio.io import MemoryFile


class SegmentationStreamingDataset(StreamingDataset):
    """
    Segmentation dataset with streaming.

    Args:
        input_dir (str): Local directory or S3 location of the dataset
        transforms (Optional[Callable]): A transform that takes in a sample dict and returns a transformed version.
    """

    def __init__(self, *args: Any, transforms: Optional[Callable] = None, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.transforms = transforms

    def __getitem__(self, index: int) -> dict:
        data = super().__getitem__(index)
        image_name = data["name"]

        # Each sample stores raw image/mask bytes; decode them with rasterio
        # from an in-memory file rather than touching disk.
        with MemoryFile(data["image"]) as memfile:
            with memfile.open() as dataset:
                image = torch.from_numpy(dataset.read()).float()

        with MemoryFile(data["mask"]) as memfile:
            with memfile.open() as dataset:
                mask = torch.from_numpy(dataset.read()).long()

        sample = {"image": image, "mask": mask, "image_name": image_name}
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample
