
Patchwise training support #75

Draft: nilsleh wants to merge 27 commits into base: main
Conversation

@nilsleh commented Oct 12, 2023

This PR aims to close #22 by implementing an option to run patchwise training.

The current approach expects the patch size as a sequence of normalized coordinates for the x1 and x2 dimensions. The current patch sampling strategy is uniform random sampling.

My current thinking on supporting patchwise training is the following (a rough sketch follows the list):

  • add a class method to TaskLoader which samples a uniform point in the normalized coordinate frame and uses the patch size to define a "bounding box" around that sampled point
  • subsequently check that context and target data are available within the sampled box; a caveat here is the while loop, which risks running for a very long time if samples are scarce
  • if a bbox fulfills those criteria, it is used for the context and target sampling
    • for xarray objects this is a slice/isel statement
    • for DataFrames it's more involved, because slicing expects ordered coordinate columns, which cannot be guaranteed
  • the sampled bbox is used to sample both the context and target set
  • if no patch_size is specified in the task loader call, nothing changes; the default is None, so everything should run as before
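A minimal sketch of the bbox sampling idea (illustrative only; sample_patch_bbox and its arguments are placeholder names, not the actual implementation):

import numpy as np

def sample_patch_bbox(patch_size, coord_bounds, rng):
    """Sample a patch bounding box around a uniformly sampled centre point.

    patch_size: (size_x1, size_x2) in normalized coordinates; assumed to fit
        within the data bounds.
    coord_bounds: (x1_min, x1_max, x2_min, x2_max) of the available data.
    Returns [x1_min, x1_max, x2_min, x2_max] of the patch.
    """
    size_x1, size_x2 = patch_size
    x1_lo, x1_hi, x2_lo, x2_hi = coord_bounds
    # Keep the centre far enough from the edges that the patch fits inside the data
    c1 = rng.uniform(x1_lo + size_x1 / 2, x1_hi - size_x1 / 2)
    c2 = rng.uniform(x2_lo + size_x2 / 2, x2_hi - size_x2 / 2)
    return [c1 - size_x1 / 2, c1 + size_x1 / 2,
            c2 - size_x2 / 2, c2 + size_x2 / 2]

bbox = sample_patch_bbox((0.3, 0.3), (0.0, 1.0, 0.0, 1.0), np.random.default_rng(0))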

TODO:

  • correct visualization of the context encoding
  • work with unnormalized coordinates
  • patchwise inference that does stitching
  • unit tests for the sampling strategy with numpy arrays, since the interpolated coordinates are not necessarily known beforehand to respect the randomly sampled bounds

@nilsleh marked this pull request as draft on October 12, 2023, 12:51
@tom-andersson (Collaborator)

Thank you very much for opening this PR @nilsleh. Addressing #22 will be a significant addition to DeepSensor's functionality. It is really appreciated that you've taken the time to try tackling this.

I will start adding some high-level line comments. But firstly, a general point: In DeepSensor, I distinguish between 'slicing' a variable and 'sampling' a variable. In the TaskLoader the variables (xarray/pandas data) are first temporally sliced to specific date/s, and these smaller, sliced xarray/pandas objects are then passed to be sampled. 'Sampling' is the process of selecting a subset of (x, y) points from the set of all (x, y) points (with various sampling schemes available via the context_sampling and target_sampling kwargs). I see that you've added the slicing as part of sample_df and sample_da. We should instead consider the patching behaviour simply as a spatial slicing pre-processing step. I would do this like so:

# Temporal slice (already in TaskLoader)
context_slices = [
    self.time_slice_variable(var, date, delta_t)
    for var, delta_t in zip(self.context, self.context_delta_t)
]
target_slices = [
    self.time_slice_variable(var, date, delta_t)
    for var, delta_t in zip(self.target, self.target_delta_t)
]

# Spatial slice (to be added in this PR)
context_slices = [
    self.spatial_slice_variable(var, window)
    for var in context_slices
]
target_slices = [
    self.spatial_slice_variable(var, window)
    for var in target_slices
]
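
For reference, a rough sketch of what spatial_slice_variable could look like (a hypothetical implementation, assuming window = [x1_min, x1_max, x2_min, x2_max] in the standardised x1/x2 coordinates; the pandas branch uses boolean masking to sidestep the index-ordering issue raised above):

import pandas as pd
import xarray as xr

def spatial_slice_variable(var, window):
    """Slice an xarray/pandas variable to a spatial window.

    window: [x1_min, x1_max, x2_min, x2_max] in standardised coordinates.
    """
    x1_min, x1_max, x2_min, x2_max = window
    if isinstance(var, (xr.Dataset, xr.DataArray)):
        # Assumes ascending x1/x2 dimension coordinates
        return var.sel(x1=slice(x1_min, x1_max), x2=slice(x2_min, x2_max))
    elif isinstance(var, (pd.DataFrame, pd.Series)):
        # Boolean masking avoids relying on a sorted (x1, x2) index
        x1 = var.index.get_level_values("x1")
        x2 = var.index.get_level_values("x2")
        mask = (x1 >= x1_min) & (x1 <= x1_max) & (x2 >= x2_min) & (x2 <= x2_max)
        return var[mask]
    raise ValueError(f"Unsupported variable type: {type(var)}")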

General comments about PRs

  • Please run pip install -r requirements.dev.txt
  • From project root, run pytest to check unit tests
  • From project root, run black deepsensor/ and black tests/
  • Ensure any docstrings adhere to Google style so that sphinx can generate docs from them.


:return sequence of patch spatial extent as [lat_min, lat_max, lon_min, lon_max]
"""
# assumption of normalized spatial coordinates between 0 and 1
@tom-andersson (Collaborator)

We can't assume data is bounded in [0, 1]. This is not guaranteed or enforced in any part of the DeepSensor data processing pipeline. Instead, we need a new method, run during the TaskLoader init, which computes the global min/max coordinate values of the context/target data, and then the central point of the patch should be sampled uniformly in this range.
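
Something along these lines (a sketch only; compute_coord_bounds is a hypothetical name), run over self.context + self.target at init time:

import numpy as np
import xarray as xr

def compute_coord_bounds(variables):
    """Global min/max x1/x2 coordinate values over all context/target variables."""
    x1_min = x2_min = np.inf
    x1_max = x2_max = -np.inf
    for var in variables:
        if isinstance(var, (xr.Dataset, xr.DataArray)):
            x1, x2 = var["x1"].values, var["x2"].values
        else:  # pandas object with x1/x2 levels in its index
            x1 = var.index.get_level_values("x1").to_numpy()
            x2 = var.index.get_level_values("x2").to_numpy()
        x1_min, x1_max = min(x1_min, x1.min()), max(x1_max, x1.max())
        x2_min, x2_max = min(x2_min, x2.min()), max(x2_max, x2.max())
    return x1_min, x1_max, x2_min, x2_max

The central point of the patch would then be sampled uniformly within the returned bounds.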

@nilsleh (Author)

Okay, my understanding was that the TaskLoader only works on already normalized/standardized data and that the coordinate bounds were normalized to [0, 1], but that is good to know, thanks!

@tom-andersson (Collaborator)

By default, the DataProcessor linearly normalises the coords of the first data variable it is provided with to lie in [0, 1], but subsequent variables may exceed that data range. Thus, although the data coords will typically lie in [0, 1], there is nothing constraining this to always hold.
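
To illustrate with made-up numbers:

import numpy as np

# The first variable's coords define the linear map onto [0, 1] ...
lat_a = np.array([50.0, 55.0, 60.0])
lo, hi = lat_a.min(), lat_a.max()

def normalise(lat):
    return (lat - lo) / (hi - lo)

print(normalise(lat_a))  # [0.  0.5 1. ]

# ... but a later variable with a larger extent exceeds [0, 1] under the same map
lat_b = np.array([45.0, 60.0, 65.0])
print(normalise(lat_b))  # [-0.5  1.   1.5]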

@@ -1161,58 +1269,26 @@ def sample_variable(var, sampling_strat, seed):
f"with the `links` attribute if using the 'gapfill' sampling strategy"
)
@tom-andersson (Collaborator)

I noticed a bunch of missing code here. I'm confused because your branch is only 2 commits behind main (and these 2 commits are irrelevant). Is there any chance you rejected some changes from main when merging?

@nilsleh (Author)

Sorry, I noticed that and committed a fix already, but I think you were reviewing in the meantime; it should be fixed now.

@tom-andersson (Collaborator)

Is the below section in this diff the missing section? @nilsleh could you help me understand what's happening to the below functionality please? Has it been moved somewhere, is it no longer required? Thanks.

@@ -881,6 +974,9 @@ def task_generation(
"split" sampling strategy for linked context and target set pairs.
The remaining observations are used for the target set. Default is
0.5.
patch_size: Sequence[float], optional
Desired patch size in lat/lon used for patchwise task generation. Useful when considering
@tom-andersson (Collaborator)

There are a few references to lat/lon specifically. Please instead use the DeepSensor standardised coordinate names x1/x2 in comments and variables. The TaskLoader operates only on standardised/normalised data.

@@ -1226,7 +1302,7 @@ def sample_variable(var, sampling_strat, seed):
     X_c_offrid_all = np.concatenate(X_c_offgrid, axis=1)
     Y_c_aux = (
         self.sample_offgrid_aux(
-            X_c_offrid_all, self.time_slice_variable(self.aux_at_contexts, date)
+            X_c_offrid_all, self.time_slice_variable(self.aux_at_contexts, date), sample_patch_size
@tom-andersson (Collaborator)

We shouldn't need to spatially slice the off-grid aux data; this will happen implicitly, because the context data used for sampling the self.aux_at_contexts xarray data will already have been spatially sliced.

lon_side = lon_extend / 2

# sample a point that satisfies the boundary and target conditions
continue_looking = True
@tom-andersson (Collaborator)

I would remove the continue_looking logic entirely. Firstly, it's fine if the patch contains no context data; DeepSensor models should be able to handle this. The main risk here is that the patch contains no target data, which can lead to NaNs when passed to the ConvNP.loss_fn. However, it is much, much easier to check for Tasks with no target data as a training pre-processing step. This would be a separate PR or something we expect the user to be aware of.
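
For reference, such a pre-processing check could be as simple as this sketch (assuming the dict-like Task structure, where task["Y_t"] is a list of target observation arrays):

import numpy as np

def has_target_data(task):
    """True if every target set in the Task contains at least one observation."""
    return all(np.asarray(y).size > 0 for y in task["Y_t"])

# Drop tasks with empty target sets before training
train_tasks = [t for t in train_tasks if has_target_data(t)]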

@nilsleh (Author)

Is there an assumption of a common coordinate range between the context and the target? Because if so, we could gather the coordinate bound extent of the target variable and use that for the random window sampling.

@tom-andersson (Collaborator)

No, unfortunately we can't assume that. We'll have to loop over all the self.context and self.target variables, updating the min/max data coordinate bounds.

target_slices[target_idx] = target_var
# sample common patch size for context and target set
if self.patch_size is not None:
sample_patch_size = self.sample_patch_size_extent()
@tom-andersson (Collaborator)

I would suggest we don't make patch_size a class attribute like this; it should only exist in the scope of __call__ here, along the lines of the sketch below.
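
A structural sketch only (all method names are placeholders):

class TaskLoaderSketch:
    """Illustrative only: patch_size as a __call__ kwarg rather than an attribute."""

    def __call__(self, date, patch_size=None, **kwargs):
        window = None
        if patch_size is not None:
            # Sampled per call; nothing is stored on self
            window = self.sample_patch_window(patch_size)
        return self.task_generation(date, window=window, **kwargs)

    def sample_patch_window(self, patch_size):
        raise NotImplementedError  # e.g. the bbox-sampling sketch above

    def task_generation(self, date, window=None, **kwargs):
        raise NotImplementedError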

@tom-andersson (Collaborator)

Hi @nilsleh, I've submitted my review. I hope you don't mind all the feedback - do let me know if you don't have the time and would prefer me to take over.

As a side note, to close #22 we'll need to solve the 'inference' part by adding a patch-processing feature to DeepSensorModel.predict.
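
For reference, the stitching step of that inference feature could look roughly like this (a sketch of simple overlap-averaging; not DeepSensor API):

import numpy as np

def stitch_patch_predictions(patch_preds, patch_slices, out_shape):
    """Average overlapping patch predictions onto one full grid.

    patch_preds: list of 2D prediction arrays, one per patch.
    patch_slices: list of (row_slice, col_slice) locating each patch on the grid.
    out_shape: (H, W) of the full prediction grid.
    """
    total = np.zeros(out_shape)
    counts = np.zeros(out_shape)
    for pred, (rows, cols) in zip(patch_preds, patch_slices):
        total[rows, cols] += pred
        counts[rows, cols] += 1
    counts[counts == 0] = 1  # avoid division by zero in uncovered cells
    return total / counts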

f"Must be one of [None, 'random', 'sliding']."
)

if patch_strategy is None:
@nilsleh (Author)

I moved the logic to the __call__ function; however, there is quite a bit of code redundancy because of:

  • checking separate sampling strategies
  • checking whether one supplies a single date or a sequence of dates, which determines whether a Task or a list[Task] is returned

So that could be made more concise, e.g. with the kind of dispatch helper sketched below.
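
A minimal sketch (generate_tasks is a hypothetical helper, not part of the PR):

def generate_tasks(task_loader, dates, **kwargs):
    """Single dispatch point for date vs. sequence-of-dates, for any patch strategy."""
    if isinstance(dates, (list, tuple)):
        return [task_loader(d, **kwargs) for d in dates]
    return task_loader(dates, **kwargs)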


# TODO it would be better to do this with pytest.fixtures
# but could not get to work so far
task = tl(
@nilsleh (Author)

It would be better to have fixtures that generate the data setup, so that we can test different configurations: single date, list of dates, different context and target sampling strategies, etc. A sketch of that follows below.
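
For example (a sketch; the patch_size kwarg is this PR's proposed interface, and the exact TaskLoader arguments may differ):

import numpy as np
import pandas as pd
import pytest
import xarray as xr

from deepsensor.data import TaskLoader

@pytest.fixture
def task_loader():
    """Build a TaskLoader over small random gridded data, shared across tests."""
    da = xr.DataArray(
        np.random.rand(5, 30, 30),
        dims=["time", "x1", "x2"],
        coords={
            "time": pd.date_range("2020-01-01", periods=5),
            "x1": np.linspace(0, 1, 30),
            "x2": np.linspace(0, 1, 30),
        },
        name="var",
    )
    return TaskLoader(context=da, target=da)

@pytest.mark.parametrize("dates", [["2020-01-01"], ["2020-01-01", "2020-01-02"]])
def test_patchwise_tasks(task_loader, dates):
    tasks = task_loader(
        dates, context_sampling="all", target_sampling="all", patch_size=(0.5, 0.5)
    )
    assert len(tasks) == len(dates)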

loss = np.mean(epoch_losses)
self.assertFalse(np.isnan(loss))

def test_sliding_window_training(self):
@nilsleh (Author)

I think the trainer tests could also benefit from fixtures, because at the moment there is a lot of code duplication. Maybe fixtures that generate a list of train_tasks which are then run with a ConvCNP and Trainer, e.g.:
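
For example, building on the task_loader fixture sketched above (the kwargs are assumptions, as before):

import pytest

@pytest.fixture
def train_tasks(task_loader):
    """A small list of patchwise training tasks shared across trainer tests."""
    dates = ["2020-01-01", "2020-01-02", "2020-01-03"]
    return task_loader(
        dates, context_sampling="all", target_sampling="all", patch_size=(0.5, 0.5)
    )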


Successfully merging this pull request may close these issues: Patchwise training and inference