
Second order computations for nn.Upsample #254

Open
FrederikWarburg opened this issue Apr 12, 2022 · 1 comment

Comments

@FrederikWarburg

Hi,

I need to compute the approximate Hessian for a decoder network. The decoder consists of Conv2d and Upsample layers. Currently, BackPACK does not support nn.Upsample. Since it is a parameter-free layer, it might not be too difficult to implement?

Here I define my model and a data point.

import torch

from backpack import backpack, extend
from backpack.extensions import DiagGGNExact

model = torch.nn.Sequential(
    torch.nn.Conv2d(1,8, kernel_size=3, padding=1),
    torch.nn.MaxPool2d(2),
    torch.nn.ReLU(),
    torch.nn.Conv2d(8,8, kernel_size=3, padding=1),
    torch.nn.Upsample(scale_factor=2, mode="nearest"),
    torch.nn.ReLU(),
    torch.nn.Conv2d(8,1, kernel_size=3, padding=1),
    torch.nn.Flatten(),
)
lossfunc = torch.nn.MSELoss()

model = extend(model)
lossfunc = extend(lossfunc)

X = torch.zeros(1,1,8,8)
print(model(X).shape)

b = X.shape[0]
loss = lossfunc(model(X), X.view(b, -1))

with backpack(DiagGGNExact()):
    loss.backward()

for param in model.parameters():
    print(param.diag_ggn_exact)

Running this raises the following error:

NotImplementedError: Extension saving to diag_ggn_exact does not have an extension for Module <class 'torch.nn.modules.upsampling.Upsample'>

Could you help implement this feature?

@f-dangel
Owner

Hi, thanks for your feature request.

We have an example of how to add new parameterized layers to first-order extensions; it's a good starting point. Since nn.Upsample has no parameters, you only have to implement how information for DiagGGNExact is backpropagated through the layer.

To do that, you would

  • create a class DiagGGNUpsample that inherits from ModuleExtension
  • implement its backpropagate function to multiply the backpropagated quantity by nn.Upsample's transposed Jacobian.
  • register the module extension in BackPACK's DiagGGN extension so that BackPACK knows to call it when the extension encounters an nn.Upsample module.
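The core of the second step is applying nn.Upsample's transposed Jacobian to the backpropagated quantity. As a minimal sketch (not BackPACK's actual extension code), here is what that operation looks like for nearest-neighbor upsampling with scale_factor=2: each input pixel is copied to a 2x2 output block, so the transposed Jacobian sums each 2x2 block back down. The helper name upsample_nearest2x_jac_t is made up for illustration; the result is checked against autograd's vector-Jacobian product.

```python
import torch

def upsample_nearest2x_jac_t(v):
    """Apply the transposed Jacobian of Upsample(scale_factor=2, mode='nearest').

    v: tensor of shape (..., 2H, 2W), the quantity backpropagated from the
    output side. Each input pixel fans out to a 2x2 output block, so the
    transposed Jacobian sums every 2x2 block back to one input pixel.
    """
    *lead, H2, W2 = v.shape
    return v.reshape(*lead, H2 // 2, 2, W2 // 2, 2).sum(dim=(-3, -1))

# Sanity check against autograd's vector-Jacobian product.
x = torch.randn(1, 3, 4, 4, requires_grad=True)
y = torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
v = torch.randn_like(y)
(vjp,) = torch.autograd.grad(y, x, grad_outputs=v)
print(torch.allclose(vjp, upsample_nearest2x_jac_t(v)))  # True
```

A ModuleExtension's backpropagate would perform essentially this contraction on the incoming backpropagated quantity, just with an extra leading dimension for the backpropagated vectors.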

It would be great if you gave it a shot and submitted a PR! I can provide more pointers to help.

Best, Felix
