
[BUG] Loss signatures: CE Loss failure because of additional params argument #854

Open

VukW opened this issue Apr 23, 2024 · 1 comment

Labels: bug (Something isn't working), enhancement (New feature or request)

@VukW (Contributor) commented Apr 23, 2024

Describe the bug

During loss computation, it is assumed that every loss function takes three parameters: prediction, target, and params. However, this is not true for the CE loss, which takes only prediction and target, so using loss_function: ce fails.
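For illustration, a minimal sketch of the mismatch; the two-argument CE signature and the three-argument call are taken from the description above, while the tensor shapes are just example values:

import torch


def CE(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # current two-argument cross entropy
    return torch.nn.functional.cross_entropy(prediction, target)


prediction = torch.randn(4, 3)      # logits for 4 samples, 3 classes
target = torch.randint(0, 3, (4,))  # integer class labels
params = {}                         # config dict passed to every loss

# the generic loss call passes three arguments, so this raises:
# TypeError: CE() takes 2 positional arguments but 3 were given
CE(prediction, target, params)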

To Reproduce

Steps to reproduce the behavior: try to train any model with loss_function: ce.

Expected behavior

Training with loss_function: ce runs without errors, just like with any other supported loss.

Additional context

The straightforward solution is just to add an unused params argument to the CE function. However, I believe doing this would cause linter / Codacy failures, as the parameter would be defined but never used. In that case, the best option IMO is to create a standard class interface for losses:

from abc import ABC, abstractmethod

import torch


class LossInterface(ABC):
    @staticmethod
    @abstractmethod
    def calc(predictions: torch.Tensor, targets: torch.Tensor, params: dict) -> torch.Tensor:
        raise NotImplementedError()


class DCCE(LossInterface):
    @staticmethod
    def calc(predictions: torch.Tensor, targets: torch.Tensor, params: dict) -> torch.Tensor:
        # move the existing DCCE calculation logic here
        ...


# in loss_and_metric.py it can be used as:
# loss_function.calc(predictions, targets, params)

and the same for all the other losses, as sketched below. This way every loss would have the same signature and could be used interchangeably; if the signature of any loss function differed, both Codacy and the IDE would warn the developer that something is wrong.
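Once every loss implements the interface, dispatch and the call site become uniform. A hedged sketch (the registry name and keys below are illustrative, not GaNDLF's actual code; DCCE is the class defined above):

# maps the loss_function string from the config to an implementation
AVAILABLE_LOSSES = {
    "dcce": DCCE,
    # ... every other loss registered with the same interface
}

def get_loss(params: dict):
    # an unknown or misspelled loss name fails loudly at lookup time
    return AVAILABLE_LOSSES[params["loss_function"]]

# in loss_and_metric.py the call site is then identical for every loss:
# loss = get_loss(params).calc(predictions, targets, params)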

@sarthakpati (Collaborator) commented

The solution makes complete sense to me. We should do this for all the losses, not just ce. And on that note, perhaps ce is a bit ambiguous, and we should make it explicit: either CEL (i.e., cross entropy loss), BCEL (i.e., binary cross entropy loss), or BCEL_logits (i.e., binary cross entropy with logits).
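For reference, these three names would map to three distinct torch primitives, so the explicit naming removes real ambiguity. A sketch using the LossInterface proposed above (class names are just this suggestion, nothing merged yet):

import torch
import torch.nn.functional as F


class CEL(LossInterface):
    # multi-class cross entropy: raw logits plus integer class targets
    @staticmethod
    def calc(predictions: torch.Tensor, targets: torch.Tensor, params: dict) -> torch.Tensor:
        return F.cross_entropy(predictions, targets)


class BCEL(LossInterface):
    # binary cross entropy: expects predictions already passed through a sigmoid
    @staticmethod
    def calc(predictions: torch.Tensor, targets: torch.Tensor, params: dict) -> torch.Tensor:
        return F.binary_cross_entropy(predictions, targets)


class BCEL_logits(LossInterface):
    # binary cross entropy on raw logits; numerically more stable than BCEL
    @staticmethod
    def calc(predictions: torch.Tensor, targets: torch.Tensor, params: dict) -> torch.Tensor:
        return F.binary_cross_entropy_with_logits(predictions, targets)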

Thoughts?

@sarthakpati added the bug (Something isn't working) and enhancement (New feature or request) labels on Apr 24, 2024