
Make loss functions and regularizers classes that inherit from torch.nn #131

Open

iancze opened this issue Feb 3, 2023 · 3 comments

iancze commented Feb 3, 2023

Currently our loss functions are coded as straightforward functions operating on torch inputs. Some loss functions take additional parameters beyond the quantity being evaluated, for example,

def entropy(cube, prior_intensity): ...
where prior_intensity is a reference cube used in the evaluation of the actual target, cube.

This works fine, but can get cumbersome, especially when we are interfacing with a bunch of loss functions all at once (as in a cross-validation loop).

Can we take a page from the way PyTorch designs its loss functions and make most, if not all, of our loss functions classes that inherit from torch.nn.Module? This would create objects that can easily be instantiated with default parameter values and would give every loss a uniform call signature. For example, see MSELoss.
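
As a rough sketch of what this could look like for the entropy example above (the class name, constructor signature, and import path below are placeholders, not an existing API):

import torch
import torch.nn as nn

from mpol.losses import entropy  # the existing functional loss quoted above; import path assumed


class EntropyLoss(nn.Module):
    """Hypothetical nn.Module wrapper around the existing functional entropy loss."""

    def __init__(self, prior_intensity: torch.Tensor) -> None:
        super().__init__()
        # The reference cube is fixed at instantiation, so callers only pass `cube`.
        self.prior_intensity = prior_intensity

    def forward(self, cube: torch.Tensor) -> torch.Tensor:
        # Delegate to the existing functional implementation.
        return entropy(cube, self.prior_intensity)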

This may have additional benefits (with reduce, say) if we think about batching and applications to multiple GPUs.

Does it additionally make sense to include the lambda terms as parameters of the loss object, too? @kadri-nizam do you have any thoughts from experience w/ your VAE architecture?

kadri-nizam commented

I think making versions of the loss functions as torch modules is a great idea. I'd still keep the functional definitions separate and import them when defining the module, as that is more flexible (for developing and testing). I believe this is how PyTorch implements it: the logic for the losses lives in torch.nn.functional and gets used in the nn.Module versions.
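
For reference, that split is visible in PyTorch's public API; a minimal check:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.rand(8)
y = torch.zeros(8)

# nn.MSELoss is a thin nn.Module wrapper; the actual math lives in F.mse_loss.
assert torch.allclose(nn.MSELoss()(x, y), F.mse_loss(x, y))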

Does it additionally make sense to include the lambda terms as parameters of the loss object, too?

The lambda parameter doesn't change throughout the optimization run, right? If so, I'd pass it as an argument at instantiation.

iancze commented Mar 5, 2023

Yes, the 'lambda' parameter remains fixed during an optimization run. In a cross-validation loop you'd want to try several different lambda values, so in that situation I guess you would need to re-instantiate the loss functions.

kadri-nizam commented

Hi Ian,

Thank you for a productive discussion today! Here's an example of how I implemented the loss functions in my fork:

import torch
import torch.nn as nn


class TV(nn.Module):
    """Total variation regularizer, scaled by a fixed strength λ."""

    def __init__(self, λ: float, /) -> None:
        super().__init__()
        self.λ = λ

    def __repr__(self):
        return f"TV(λ={self.λ})"

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # The module just scales the stateless functional form by λ.
        return self.λ * TV.functional(image)

    @staticmethod
    def functional(image: torch.Tensor) -> torch.Tensor:
        # Squared finite differences along rows and columns, cropped so both
        # terms share the same (H-1, W-1) shape before being combined.
        row_diff = torch.diff(image[:, :-1], dim=0).pow(2)
        column_diff = torch.diff(image[:-1, :], dim=1).pow(2)
        return torch.add(row_diff, column_diff).sqrt().sum()

The purpose of the functional static method is to make testing easier -- you can just call TV.functional directly, without having to instantiate the class first.
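
For example (a sketch only; it assumes the TV class above, and the image and λ values are made up), this also connects to the cross-validation point from earlier, since the loss is simply re-instantiated for each candidate λ:

import torch

image = torch.rand(128, 128)  # stand-in for a model image

# For testing, the stateless form needs no setup:
raw_tv = TV.functional(image)

# In a cross-validation loop, re-instantiate the loss for each candidate λ:
for lam in (1e-4, 1e-3, 1e-2):
    loss = TV(lam)
    print(loss, loss(image).item())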

I defined an abstract base class in my fork to specify requirements that a loss module in the repo must meet, but this is optional.
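
Purely for illustration (the name and required interface below are not from the repo), such a base class might look like:

from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class Regularizer(nn.Module, ABC):
    """Hypothetical base class pinning down the interface a loss module must provide."""

    @staticmethod
    @abstractmethod
    def functional(image: torch.Tensor) -> torch.Tensor:
        """Stateless form of the loss, convenient for testing."""

    @abstractmethod
    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """λ-scaled loss evaluated on an image."""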
