Patchwise training and inference #22

Open
tom-andersson opened this issue Jul 8, 2023 · 2 comments · May be fixed by #75
Labels: enhancement (New feature or request), thoughts welcome (Discussion and feedback is appreciated)

Comments

@tom-andersson
Collaborator

Some deepsensor users may have dense environmental data spanning large spatial areas. For example:

  • Problems with data spanning the whole globe
  • Problems with high-resolution satellite data

In such cases, training and inference with a ConvNP over the entire data region may be computationally prohibitive. Currently, the TaskLoader samples context and target data over the entire spatial region for which data is available, which could cause out-of-memory (OOM) errors. We therefore need to support chopping the data into smaller spatial patches.

Training
Supporting patchwise ConvNP training should just be a matter of updating the TaskLoader to slice the context and target datasets spatially into square (or rectangular) subregions before proceeding with the TaskLoader.__call__ sampling functionality for generating Task objects. I believe this should be quite simple: for xarray data this would be ds.sel(x1=slice(...), x2=slice(...)), while for pandas data it would be df.loc[slice(...), slice(...)].
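
A minimal sketch of the spatial slicing step, assuming gridded data with ascending x1/x2 coordinates and point data in a DataFrame indexed by (time, x1, x2). The helper names slice_patch_xr and slice_patch_df are hypothetical and for illustration only. Note that with a (time, x1, x2) MultiIndex, the pandas selection needs per-level slices via pd.IndexSlice (and a sorted index) rather than two positional slice(...) arguments, since the second positional argument to .loc selects columns:

```python
import numpy as np
import pandas as pd
import xarray as xr


def slice_patch_xr(ds, x1_bounds, x2_bounds):
    """Select the part of a gridded dataset that falls inside the patch bounds.

    Hypothetical helper. Assumes x1/x2 coordinates are ascending; for descending
    coordinates (e.g. latitude stored north-to-south) swap the slice endpoints.
    """
    return ds.sel(x1=slice(*x1_bounds), x2=slice(*x2_bounds))


def slice_patch_df(df, x1_bounds, x2_bounds):
    """Select rows of a (time, x1, x2)-indexed DataFrame inside the patch bounds.

    Hypothetical helper; the (time, x1, x2) index order is an assumption.
    """
    # Label-based slicing on a MultiIndex needs a lexsorted index
    df = df.sort_index()
    idx = pd.IndexSlice
    return df.loc[idx[:, x1_bounds[0]:x1_bounds[1], x2_bounds[0]:x2_bounds[1]], :]


# Toy gridded data: a 10 x 10 grid spanning [0, 1] in both dimensions
coords = np.linspace(0, 1, 10)
ds = xr.Dataset(
    {"var": (("x1", "x2"), np.random.rand(10, 10))},
    coords={"x1": coords, "x2": coords},
)
ds_patch = slice_patch_xr(ds, (0.0, 0.5), (0.0, 0.5))

# Toy off-grid point data at a single time
df = pd.DataFrame(
    {
        "time": pd.to_datetime(["2020-01-01"] * 4),
        "x1": [0.1, 0.4, 0.6, 0.9],
        "x2": [0.2, 0.3, 0.7, 0.8],
        "var": [1.0, 2.0, 3.0, 4.0],
    }
).set_index(["time", "x1", "x2"])
df_patch = slice_patch_df(df, (0.0, 0.5), (0.0, 0.5))
```

The patched objects could then be handed to the existing TaskLoader.__call__ sampling logic unchanged.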

Inference
Inference using the high-level DeepSensorModel.predict interface also needs support for patching. This requires functionality to stitch all the individual patch predictions together.

For on-grid xarray prediction, one solution might be to call .predict recursively over all the patches and then concatenate the resulting xr.Datasets into a single object. This would require some kind of patchify boolean flag to control this and to avoid infinite recursion in the inner call. Open to other ideas!
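
One possible shape for the stitching step, sketched for non-overlapping patches. Here predict_patchwise and dummy_predict are hypothetical names, and predict_fn merely stands in for a per-patch DeepSensorModel.predict call returning an xr.Dataset on the patch's grid. The per-patch outputs are reassembled with xr.combine_by_coords:

```python
import numpy as np
import xarray as xr


def predict_patchwise(ds, patch_shape, predict_fn):
    """Run `predict_fn` on non-overlapping patches and stitch the outputs.

    Hypothetical helper. `predict_fn` stands in for a per-patch model call and
    must return an xr.Dataset on the same x1/x2 grid as its input patch.
    """
    n1, n2 = patch_shape
    patch_preds = []
    for i in range(0, ds.sizes["x1"], n1):
        for j in range(0, ds.sizes["x2"], n2):
            patch = ds.isel(x1=slice(i, i + n1), x2=slice(j, j + n2))
            patch_preds.append(predict_fn(patch))
    # The patches tile the grid exactly (edge patches may be smaller), so
    # combining on the x1/x2 coordinate labels reassembles the full domain
    return xr.combine_by_coords(patch_preds)


coords = np.arange(6.0)
ds = xr.Dataset(
    {"var": (("x1", "x2"), np.arange(36.0).reshape(6, 6))},
    coords={"x1": coords, "x2": coords},
)


def dummy_predict(patch):
    # Placeholder "model": doubles the input and renames it to "mean"
    return (2 * patch).rename({"var": "mean"})


pred = predict_patchwise(ds, (4, 4), dummy_predict)
```

Keeping the patch loop in a separate wrapper like this, rather than inside .predict itself, is one way to sidestep the recursion question; whether that fits the existing predict signature is an open design choice.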

However, model predictions could differ substantially from one side of a patch border to the other (because each patch sees different context information). We may therefore need overlapping patches, with model predictions averaged in the regions where patches overlap.
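
A minimal sketch of overlap averaging on a plain 2D array, assuming a uniform (unweighted) average over every patch that covers a cell. The helpers patch_starts and predict_overlapping are hypothetical, and predict_fn stands in for a per-patch model call. Each cell accumulates a running sum of predictions and a count of covering patches, and the result is their ratio:

```python
import numpy as np


def patch_starts(n, size, stride):
    """Start indices for windows of `size` stepping by `stride` that cover [0, n)."""
    if size >= n:
        return [0]
    starts = list(range(0, n - size + 1, stride))
    if starts[-1] != n - size:
        starts.append(n - size)  # add a final window flush with the far edge
    return starts


def predict_overlapping(grid, patch_shape, stride, predict_fn):
    """Average per-patch predictions wherever overlapping patches cover a cell.

    Hypothetical helper. `predict_fn` maps a 2D array to a 2D prediction of the
    same shape. Requiring stride <= patch size guarantees every cell is covered.
    """
    if any(s > p for s, p in zip(stride, patch_shape)):
        raise ValueError("stride must not exceed patch size")
    total = np.zeros(grid.shape)
    count = np.zeros(grid.shape)
    for i in patch_starts(grid.shape[0], patch_shape[0], stride[0]):
        for j in patch_starts(grid.shape[1], patch_shape[1], stride[1]):
            window = (slice(i, i + patch_shape[0]), slice(j, j + patch_shape[1]))
            total[window] += predict_fn(grid[window])
            count[window] += 1
    return total / count


grid = np.arange(100.0).reshape(10, 10)

# A "model" whose output depends on the whole patch (the patch mean), so
# neighbouring patches disagree and the overlap averaging has an effect
pred = predict_overlapping(grid, (4, 4), (3, 3), lambda p: np.full_like(p, p.mean()))
```

A weighted variant (e.g. down-weighting predictions near each patch edge, where context is most truncated) would only change what is added to total and count.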

Patch size/location question
An open question is how the size and location of the patches should be determined. One option is to have the user pass the patch size in TaskLoader.__call__ or DeepSensorModel.predict, and then the location will be generated randomly unless further kwargs are passed to override this and specify exact x1/x2 spatial bounds.
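
A sketch of the "random location unless overridden" option, assuming the patch size is given per dimension in coordinate units. The function get_patch_bounds and its bounds argument are hypothetical illustrations of the proposed kwargs, not an existing TaskLoader or DeepSensorModel API:

```python
import numpy as np


def get_patch_bounds(x1_extent, x2_extent, patch_size, bounds=None, rng=None):
    """Return ((x1_min, x1_max), (x2_min, x2_max)) for one patch.

    Hypothetical helper. `patch_size` is a (size_x1, size_x2) pair in coordinate
    units. If `bounds` is given it overrides random placement entirely.
    """
    if bounds is not None:
        return bounds
    rng = np.random.default_rng() if rng is None else rng
    out = []
    for (lo, hi), size in zip((x1_extent, x2_extent), patch_size):
        if size > hi - lo:
            raise ValueError("patch_size exceeds the data extent")
        # Uniformly sample the lower edge so the whole patch stays in the domain
        start = rng.uniform(lo, hi - size)
        out.append((start, start + size))
    return tuple(out)


rng = np.random.default_rng(0)
random_bounds = get_patch_bounds((-90.0, 90.0), (-180.0, 180.0), (30.0, 60.0), rng=rng)
fixed_bounds = get_patch_bounds(
    (-90.0, 90.0), (-180.0, 180.0), (30.0, 60.0), bounds=((0.0, 30.0), (0.0, 60.0))
)
```

The resulting bounds could feed directly into the ds.sel / df.loc slicing described under Training.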

@tom-andersson added the labels enhancement (New feature or request) and help wanted (Extra attention is needed) on Jul 8, 2023
@tom-andersson added this to the v0.3.0 milestone on Jul 8, 2023
@tom-andersson added the label thoughts welcome (Discussion and feedback is appreciated) and removed help wanted (Extra attention is needed) on Aug 23, 2023
@nilsleh

nilsleh commented Sep 25, 2023

Hi @tom-andersson, I am attempting to prototype this for my use case. My question is how patch_size should be interpreted (ignoring the location of the patch for the moment and just going with random placement). Take the example of patch_size=100 and an xarray spanning the globe with dimensions [180, 360]:

  • Should a patch size of 100 simply yield an array of shape [100, 100]? And if a size is specified for each dimension, e.g. patch_size=(50, 120), should it yield [50, 120]?
  • Shouldn't the resolution of the data array be taken into account, given that one might have spatial data at different resolutions? I suppose one could also specify patch size not in absolute values but as a fraction of the domain, so patch_size=0.5 would yield a randomly placed [90, 180] array.
  • For a dataframe, how should patch size be interpreted, given that the data is not gridded? Would you compute spatial bounds on the xarray data arrays based on the patch size and then use those lat/lon bounds to filter the dataframe?
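
One way the three patch_size interpretations above could coexist is a small normalisation step that turns any spec into a number of grid cells, plus a conversion to coordinate units so the same extent can be used to filter off-grid dataframe data. This is a sketch under those assumptions; resolve_patch_shape and patch_extent are hypothetical names:

```python
def resolve_patch_shape(patch_size, grid_shape):
    """Convert a patch_size spec into (n_x1, n_x2) grid cells.

    Hypothetical conventions:
      int            -> square patch, e.g. 100 -> (100, 100)
      tuple          -> per-dimension sizes, e.g. (50, 120) -> (50, 120)
      float in (0,1] -> fraction of each dimension, e.g. 0.5 on (180, 360) -> (90, 180)
    """
    if isinstance(patch_size, float):
        if not 0.0 < patch_size <= 1.0:
            raise ValueError("fractional patch_size must lie in (0, 1]")
        shape = tuple(max(1, round(patch_size * n)) for n in grid_shape)
    elif isinstance(patch_size, int):
        shape = (patch_size, patch_size)
    else:
        shape = tuple(int(s) for s in patch_size)
    if any(s > n for s, n in zip(shape, grid_shape)):
        raise ValueError(f"patch {shape} is larger than grid {grid_shape}")
    return shape


def patch_extent(shape, resolution):
    """Patch size in coordinate units (e.g. degrees), given grid spacing per dim.

    These extents could then be used as spatial bounds to filter point data.
    """
    return tuple(n * r for n, r in zip(shape, resolution))


global_grid = (180, 360)
square = resolve_patch_shape(100, global_grid)
rect = resolve_patch_shape((50, 120), global_grid)
half = resolve_patch_shape(0.5, global_grid)
```

Expressing the patch in coordinate units via the grid resolution is one way to address the resolution concern: two datasets at different resolutions would then share the same physical patch extent even though their cell counts differ.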

Maybe you have some thoughts or preferences on this that I could use as a guide for a draft implementation. (I haven't thought about inference at all at this stage.)

@nilsleh linked a pull request on Oct 12, 2023 that will close this issue
@tom-andersson modified the milestones: v0.3.0, v0.4.0 on Oct 17, 2023
@acocac
Member

acocac commented Apr 9, 2024

Just to add: the Pangeo ML group has worked extensively on optimising n-dimensional arrays for AI/ML pipelines. For patchwise training, I suggest considering building on existing developments such as the xbatcher and zen3geo Python libraries.
