[Feature request]: Flexibility in label transformations #160

Open · 1 of 8 tasks
laserkelvin opened this issue Mar 19, 2024 · 0 comments

laserkelvin commented Mar 19, 2024

Feature/behavior summary

Given that properties from different datasets can span large dynamic ranges and/or be strongly non-Gaussian, we should design a framework for modifying and transforming labels, ideally just before loss calculation. As part of this, it may be advantageous to calculate dataset-wide statistics on the fly, with caching.

Request attributes

  • Would this be a refactor of existing code?
  • Does this proposal require new package dependencies?
  • Would this change break backwards compatibility?
  • Does this proposal include a new model?
  • Does this proposal include a new dataset?
  • Does this proposal include a new task/workflow?

Related issues

#75 pertains to an issue with normalization not being applied; this solution would supersede it.

Solution description

One solution would be to implement this as a subclass of AbstractTransform, which mutates data in place:

class AbstractLabelTransform(AbstractTransform):
    def apply(self, *args, **kwargs):
        # mutate the label(s) of a data sample in place
        ...

    def cache_statistic(self, key, value):
        # stash a dataset-wide statistic (e.g. a mean) for reuse
        ...

    def save(self, path):
        # persist cached statistics to disk
        ...

On-the-fly statistics could be calculated with a moving average or similar, then cached to disk keyed on the dataset class and the dataset path. The main wrinkle is synchronization: in DDP scenarios, we'd want to make sure the statistics are the same across every data loader worker, which could probably be handled with a reduction call.
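A minimal sketch of what that accumulation plus reduction could look like (RunningLabelStats and its methods are hypothetical, not existing matsciml code):

import torch
import torch.distributed as dist

class RunningLabelStats:
    """Accumulate raw moments of a label so mean/std can be derived on the fly."""

    def __init__(self) -> None:
        self.total = 0.0     # running sum of label values
        self.total_sq = 0.0  # running sum of squared label values
        self.count = 0

    def update(self, values: torch.Tensor) -> None:
        # fold one batch worth of label values into the running moments
        self.total += values.float().sum().item()
        self.total_sq += values.float().pow(2).sum().item()
        self.count += values.numel()

    def sync(self) -> None:
        # all-reduce the raw moments so every DDP rank derives identical
        # statistics; a no-op outside of distributed runs
        if dist.is_available() and dist.is_initialized():
            packed = torch.tensor([self.total, self.total_sq, float(self.count)])
            dist.all_reduce(packed, op=dist.ReduceOp.SUM)
            self.total, self.total_sq, count = packed.tolist()
            self.count = int(count)

    @property
    def mean(self) -> float:
        return self.total / self.count

    @property
    def std(self) -> float:
        return max(self.total_sq / self.count - self.mean ** 2, 0.0) ** 0.5

Reducing raw sums rather than per-rank means keeps the combined statistics exact even when ranks see different numbers of samples.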

We can then implement concrete versions of the transforms:

class NormalTransform(AbstractLabelTransform):
    # rescales based on mean/std
    ...

class MinMaxTransform(AbstractLabelTransform):
    # rescales to [min, max] of a specified value, or of the dataset
    ...

class LambdaTransform(AbstractLabelTransform):
    # a bit dicey, but applies an arbitrary function to a key
    ...

class ExponentialTransform(AbstractLabelTransform):
    # many properties have long-tailed distributions
    ...

The idea would be that you could freely compose these such that different labels can be transformed in different ways.
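For example, composition could be as simple as applying a list of transforms in order (the label keys, statistics, and constructor arguments here are purely illustrative, following the NormalTransform sketch above):

transforms = [
    NormalTransform(key="formation_energy", mean=-1.2, std=0.8),
    ExponentialTransform(key="band_gap"),
    LambdaTransform(key="forces", func=lambda x: 0.1 * x),
]

def transform_labels(data: dict) -> dict:
    for t in transforms:
        data = t.apply(data)  # each transform only touches its own key
    return data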

Alternatively:

  • As a pl.Callback, since it has access to the discrete before/after-step hooks, which could be helpful for getting at batch data (see the sketch after this list).
  • We could keep the existing normalization steps used in _compute_losses; however, caching and the like wouldn't be as flexible there.
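A sketch of the callback route, assuming the transform API above (hook signature per recent PyTorch Lightning releases):

import pytorch_lightning as pl

class LabelTransformCallback(pl.Callback):
    """Apply label transforms to each batch right before the training step."""

    def __init__(self, transforms: list) -> None:
        self.transforms = transforms

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx) -> None:
        # callback hook return values are ignored, so the batch must be
        # mutated in place
        for t in self.transforms:
            t.apply(batch)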

Additional notes

A task list based on the transform-based solution (convert to issues/PRs for tracking):

laserkelvin added the enhancement label Mar 19, 2024
laserkelvin self-assigned this Mar 20, 2024