Pytorch implementation #1

Open · wants to merge 22 commits into dev/pytorch

Conversation

@Dariusrussellkish commented Dec 13, 2020

Hi all, great talk and paper.

I did the preliminary work of porting this to PyTorch. There are a few niceties that could still be implemented, like specifying the batch dimension, some customization of the reduction, and supporting both TF and torch implementations without requiring both as dependencies, the way huggingface/transformers does.

Otherwise it's all there for NIG. I didn't implement the Dirichlet_SOS loss since it wasn't clear where it would be used. I'll work on porting the NeurIPS examples, but since that will take a while, I figured it would be useful to give the base torch code for now.

Of note: I found some numerical instabilities/issues with the Student t distribution when the model makes very confident regressions. Just to test the torch version, I ran SGD on a simple contrived linear regression and found that as the model achieved a strong fit, its probabilities via the Student t went > 1 (and the NLL went < 0). It obviously doesn't hinder training, but it seems a bit off to have a calculation produce probabilities > 1.

PyTorch has its own implementation of StudentT, which suffers from the same instabilities but is numerically different. I went with directly porting the TF code for numerical consistency.
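
For context, the Student t NLL here comes from a continuous density rather than a probability mass, so values above 1 are mathematically possible whenever the predicted scale is small; a quick check with torch.distributions.StudentT, independent of either port:

import torch
from torch.distributions import StudentT

# The Student t pdf at its mode exceeds 1 once the scale is small enough:
# here it is about 3.68, so log_prob > 0 and the NLL is negative.
dist = StudentT(df=3.0, loc=0.0, scale=0.1)
print(dist.log_prob(torch.tensor(0.0)).exp())  # tensor(3.6755...)

So a negative NLL is a property of the density itself, though the instabilities at extreme evidential parameters are still worth guarding against.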

Something like the following will recreate this instability:

import torch
import evidential_deep_learning as edl
import numpy as np

# Toy regression data: the target is pi times the mean of an 8-dim input.
xs = [torch.rand(8) for i in range(10_000)]
ys = [x.mean() * np.pi for x in xs]

class BasicNetwork(torch.nn.Module):
    def __init__(self, n_in=8, n_tasks=1):
        super(BasicNetwork, self).__init__()
        self.l1 = torch.nn.Linear(n_in, 6)
        # Evidential head emitting the NIG parameters (gamma, nu, alpha, beta).
        self.l2 = edl.pytorch.layers.DenseNormalGamma(6, 1)

    def forward(self, x):
        x = self.l1(x)
        x = torch.nn.functional.relu(x)
        x = self.l2(x)
        return x

model = BasicNetwork()
optim = torch.optim.Adam(model.parameters(), lr=1e-4)
for i in range(1):  # a single epoch is enough to reproduce the effect
    for x, y in zip(xs, ys):
        output = model(x)
        nll, reg = edl.pytorch.losses.EvidentialRegression(y, output)
        with torch.no_grad():
            print((-1 * nll).exp().item())  # exp(-NLL): climbs above 1 as the fit tightens
            print(nll.item())               # correspondingly goes negative
            print(reg.item())
            print(y)
            print(output)
            print((y - output[0]).abs().item())  # absolute error of the predicted mean
        loss = nll + reg
        print(loss.item())
        print()
        loss.backward()
        optim.step()
        optim.zero_grad()

@aamini (Owner) commented Dec 14, 2020

First of all, thank you for contributing the new pytorch implementation! We really appreciate this initiative on your part.

I would like to take some time to review your updates before merging into master / PyPI. I can provide some comments on the PR and (if it's okay with you) commit some suggestions to your branch before merging. In the meantime, I believe your code will serve as a great starting point for others who would like to try the method in PyTorch.

@Dariusrussellkish (Author)

Hi! I'd actually suggest you make a pytorch dev branch and I'll re-PR into that. I always wonder why that isn't an option on GitHub, since it's more logical. That way I can also revert all of the neurips2020 import refactoring and leave all of the original code untouched.

Please do give suggestions - I'm especially not happy at the moment with how the two implementations handle the namespace. I think there's potentially a much cleaner way (for the user) to go about this, and it's definitely not ready to merge into main as-is.

@aamini changed the base branch from main to dev/pytorch on December 14, 2020 at 20:16
@aamini (Owner) commented Dec 14, 2020

Good point, I just created the pytorch dev branch (dev/pytorch) and changed the target branch of your PR to point to it instead of master.

I agree on the namespace issue. I think one way to handle this cleanly is how Keras used to handle multiple backends, reading the backend from an OS environment variable (doc). A similar approach is adopted by pyrender. Alternatively, we could adopt an approach similar to how matplotlib works (it has a base method that allows switching backends).
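
For illustration, a minimal sketch of the matplotlib-style idea, assuming the package exposes tf and pytorch subpackages as in this PR; set_backend and the exact layout here are hypothetical rather than settled API:

import importlib
import evidential_deep_learning as edl

def set_backend(name):
    # Hypothetical switcher: rebind edl.layers / edl.losses to the chosen
    # backend's modules ('tf' and 'pytorch' subpackage names are assumed).
    subpackage = {"tf": "tf", "torch": "pytorch"}[name]
    backend = importlib.import_module(f"evidential_deep_learning.{subpackage}")
    edl.layers = backend.layers
    edl.losses = backend.losses

A Keras-style variant would instead read the same choice from an environment variable once at import time.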

This reverts commit 24cc9aa.
This reverts commit f69c59f.
Automatically detects torch and tf availability
- error when neither is available
- when only one is available, allows only that backend
- when both are available, default to tf backend
- set_backend('tf'|'torch') manipulates edl.losses and edl.layers namespaces
@Dariusrussellkish (Author) commented Dec 14, 2020

  1. Reverted everything in the NeurIPS directory.
  2. See commit ce9f606. This is a bit messy, but it uses the matplotlib 'backend' idea: it checks whether torch or tf can be imported and then manipulates the edl.layers and edl.losses names to point at the correct implementation. I'm sure there is a cleaner way.
  3. I implemented Dirichlet_SOS since @benjiachong seems to want it. It's a very 1:1 port and I'm not sure it's valid. I'm unfamiliar with tf, so I'm not sure which dimension axis=1 is really referring to and what that would conventionally be in pytorch (which usually orders dimensions as batch, [channel], data, [data2], ...); see the sketch below.
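
For reference on the axis question: in both TF and PyTorch a classification head conventionally emits a (batch, num_classes) tensor, so axis=1 is the class dimension and maps to dim=1 in torch. Below is a standalone sketch of the sum-of-squares Dirichlet loss from Sensoy et al. (2018), written from the published formula rather than from this port (dirichlet_sos is an illustrative name):

import torch

def dirichlet_sos(y, alpha):
    # Sum-of-squares Dirichlet loss (Sensoy et al. 2018); illustrative only.
    # y:     one-hot targets, shape (batch, num_classes)
    # alpha: Dirichlet parameters (> 0), shape (batch, num_classes)
    S = alpha.sum(dim=1, keepdim=True)    # Dirichlet strength, shape (batch, 1)
    p = alpha / S                         # expected class probabilities
    err = (y - p) ** 2                    # squared error against the one-hot target
    var = p * (1 - p) / (S + 1)           # variance of each class probability
    return (err + var).sum(dim=1).mean()  # sum over classes, mean over batch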

I think the next steps would be to first validate the pytorch code and then find a cleaner way to handle the namespace.

@Dariusrussellkish (Author)

Dirichlet UQ for discrete classification is implemented and validated, @benjiachong.

@wonjeongchoi

Hi, @Dariusrussellkish.
As you commented on 14 Dec 2020, I have also encountered the problem of the NLL loss going negative.
I found the problem occurs when the evidence values (nu, alpha) grow large, so the variance becomes small.

Did you solve this problem in the PyTorch implementation?

Thank you.
