
Numerical instability in ResNet50 heatmaps #148

Open
rodrigobdz opened this issue Jun 15, 2022 · 12 comments
Labels
model compatibility Compatibility for new or variations of existing models

Comments

@rodrigobdz
Contributor

rodrigobdz commented Jun 15, 2022

Calculating the relevance on ResNet50 seems to be prone to numerical instability, producing heatmaps where all attribution is concentrated in a few spots because the values in those spots have become far larger than the rest. See the heatmap in the bug reproduction section. I can confirm that this unexpected behavior also occurs with different composites and different images.

I have also seen this issue on VGG16 in my own LRP implementation depending on the heuristic used in the stabilize function.

Bug reproduction

Code based on snippet provided in #76 (comment).

Minimal reproducible example:

import cv2
import numpy
import torch
from matplotlib import pyplot as plt
from torchvision.models import resnet50
from zennit.composites import EpsilonGammaBox
from zennit.image import imgify
from zennit.torchvision import ResNetCanonizer


# use the gpu if requested and available, else use the cpu
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


class BatchNormalize:
    def __init__(self, mean, std, device=None):
        self.mean = torch.tensor(mean, device=device)[None, :, None, None]
        self.std = torch.tensor(std, device=device)[None, :, None, None]

    def __call__(self, tensor):
        return (tensor - self.mean) / self.std


# mean and std of ILSVRC2012 as computed for the torchvision models
norm_fn = BatchNormalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225), device=device)

batch_size = 1
# the maximal input shape, needed for the ZBox rule
shape = (batch_size, 3, 224, 224)

# the highest and lowest pixel values for the ZBox rule
low = norm_fn(torch.zeros(*shape, device=device))
high = norm_fn(torch.ones(*shape, device=device))


model = resnet50(pretrained=True)
model.eval()

# create the composite
composite = EpsilonGammaBox(low=-high, high=high, canonizers=[ResNetCanonizer()])

R = None
with composite.context(model) as modified_model:
    # compute attribution
    # Returns a numpy array in BGR color space, not RGB
    img = cv2.imread('castle.jpg')

    # Convert from BGR to RGB color space
    img = img[..., ::-1]

    # img.shape is (224, 224, 3), where 3 corresponds to RGB channels
    # Divide by 255 (max. RGB value) to normalize pixel values to [0,1]
    img = img/255.0
    
    data = norm_fn(
        torch.FloatTensor(
            img[numpy.newaxis].transpose([0, 3, 1, 2])*1
        )
    )
    data.requires_grad = True

    output = modified_model(data)
    output[0].max().backward()

    # print absolute sum of attribution
    print(data.grad.abs().sum().item())

    # relevance scores
    R = data.grad

    # show maximum and minimum attribution
    print(torch.aminmax(R))

    heatmap = imgify(
        R.detach().cpu().sum(1),
        symmetric=True,
        grid=True,
        cmap='seismic',
    )
    
    plt.imshow(heatmap)

Input(s):

  • Input image:

    castle.jpg

Outputs:

  • Text:

    755226.0625
    torch.return_types.aminmax(
    min=tensor(-7540.4922),
    max=tensor(2985.0886))
  • Heatmap:

    resnet50-heatmap

Additional information

The bug is not limited to the castle.jpg image; it can also be reproduced using the following image. See the corresponding heatmap below.

  • Input image:

    castle2

  • Heatmap:

    resnet50-castle2-heatmap

@chr5tphr
Owner

chr5tphr commented Jun 15, 2022

Hey Rodrigo,
thank you for raising this issue!
Just a short message for now, I will investigate this more thoroughly later.

I have seen this behavior before with the Gamma rule used in combination with the ResNet canonization (which replaces the residual connections and weights them by contribution), in implementations unrelated to Zennit.
I think this might be a problem with this setup specifically.
It was therefore rather low on my priority list, as I perceived it more as a compatibility issue between EpsilonGammaBox and ResNet.

I know for a fact that the same instability also happens if you choose the ZBox bounds too low, much lower than the actual bounds of the data.

Do you have a specific setup for VGG16? At which epsilon did this happen? For VGG16 I have not seen it before, except for the aforementioned bounds.

Could you also list the specific Composites for which you have seen this behaviour?
I have not seen this happening for anything that did not use the Gamma rule (see the example heatmaps for different composites for ResNet50 in the README).

@rodrigobdz
Contributor Author

rodrigobdz commented Jun 15, 2022

Thank you for the prompt reply, Christopher!

Issue Reproduction

The snippet above reproduces the instability with the following setup:

  • Model: ResNet50
  • Composite: EpsilonGammaBox

Could you also list the specific Composites for which you have seen this behaviour?

The following snippet reproduces it with a second setup: a NameMapComposite that assigns ZBox, Gamma, and Epsilon rules by layer depth, following the lrp-tutorial composite.

Bug reproduction

import cv2
import numpy
import torch
from torch.nn import AvgPool2d, Conv2d, Linear
from torchvision.models import resnet50
from zennit.composites import EpsilonGammaBox, NameMapComposite
from zennit.core import BasicHook, collect_leaves, stabilize
from zennit.rules import Epsilon, Gamma, ZBox
from zennit.torchvision import ResNetCanonizer
from matplotlib import pyplot as plt

from zennit.image import imgify


# the LRP-Epsilon from the tutorial
class GMontavonEpsilon(BasicHook):
    def __init__(self, stabilize_epsilon=1e-6, epsilon=0.25):
        super().__init__(
            input_modifiers=[lambda input: input],
            param_modifiers=[lambda param, _: param],
            output_modifiers=[lambda output: output],
            gradient_mapper=(lambda out_grad, outputs: out_grad / stabilize(
                outputs[0] + epsilon * (outputs[0] ** 2).mean() ** .5, stabilize_epsilon)),
            reducer=(lambda inputs, gradients: inputs[0] * gradients[0])
        )

# use the gpu if requested and available, else use the cpu
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


class BatchNormalize:
    def __init__(self, mean, std, device=None):
        self.mean = torch.tensor(mean, device=device)[None, :, None, None]
        self.std = torch.tensor(std, device=device)[None, :, None, None]

    def __call__(self, tensor):
        return (tensor - self.mean) / self.std


# mean and std of ILSVRC2012 as computed for the torchvision models
norm_fn = BatchNormalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225), device=device)
batch_size = 1
# the maximal input shape, needed for the ZBox rule
shape = (batch_size, 3, 224, 224)

# the highest and lowest pixel values for the ZBox rule
low = norm_fn(torch.zeros(*shape, device=device))
high = norm_fn(torch.ones(*shape, device=device))


model = resnet50(pretrained=True)
model.eval()

# only these get rules, linear layers will be attributed by the gradient alone
# target_types = (Conv2d, AvgPool2d)
target_types = (Conv2d, AvgPool2d, Linear)
# lookup module -> name
child_name = {module: name for name, module in model.named_modules()}
# the layers in sequential order without any containers etc.
layers = list(enumerate(collect_leaves(model)))

# list of tuples [([names..], rule)] as used by NameMapComposite
name_map = [
    ([child_name[module] for n, module in layers if n == 0 and isinstance(module, target_types)], ZBox(low=low, high=high)),
    ([child_name[module] for n, module in layers if 1 <= n <= 16 and isinstance(module, target_types)], Gamma(0.25)),
    ([child_name[module] for n, module in layers if 17 <= n <= 30 and isinstance(module, target_types)], GMontavonEpsilon(stabilize_epsilon=0, epsilon=0.25)),
    ([child_name[module] for n, module in layers if 31 <= n and isinstance(module, target_types)], Epsilon(0)),
]

# create the composite from the name map
composite = NameMapComposite(name_map, canonizers=[ResNetCanonizer()])

R = None
with composite.context(model) as modified_model:
    # compute attribution
    # Returns a numpy array in BGR color space, not RGB
    img = cv2.imread('castle.jpg')

    # Convert from BGR to RGB color space
    img = img[..., ::-1]

    # img.shape is (224, 224, 3), where 3 corresponds to RGB channels
    # Divide by 255 (max. RGB value) to normalize pixel values to [0,1]
    img = img/255.0
    
    data = norm_fn(
        torch.FloatTensor(
            img[numpy.newaxis].transpose([0, 3, 1, 2])*1
        )
    )
    data.requires_grad = True

    output = modified_model(data)
    output[0].max().backward()

    # print absolute sum of attribution
    print(data.grad.abs().sum().item())

    R = data.grad

    heatmap = imgify(
        R.detach().cpu().sum(1),
        symmetric=True,
        grid=True,
        cmap='seismic',
    )
    
    plt.imshow(heatmap)

Input(s):

  • Input image:

    castle.jpg

Outputs:

  • Text:

    2.66506890811329e+25
  • Heatmap:

    output


Root cause of numerical instability

From my observations, I've narrowed down the issue to the denominators in the equations below.

Equations

Generic LRP rule:

$$R_j = \sum_k \frac{a_j \, \rho(w_{jk})}{\epsilon + \sum_{0,j} a_j \, \rho(w_{jk})} \, R_k$$

Montavon, Grégoire, Alexander Binder, Sebastian Lapuschkin, Wojciech Samek, and Klaus-Robert Müller. "Layer-wise relevance propagation: an overview." Explainable AI: Interpreting, Explaining and Visualizing Deep Learning (2019): 193-209.

(equation screenshot from Samek et al., 2021, not reproduced here)

Samek, Wojciech, Grégoire Montavon, Sebastian Lapuschkin, Christopher J. Anders, and Klaus-Robert Müller. "Explaining deep neural networks and beyond: A review of methods and applications." Proceedings of the IEEE 109, no. 3 (2021): 247-278.

Whether the issue arises depends on the implementation of the stabilize method, and it is input-dependent.

I have tested the following heuristics (a small numeric illustration follows the list below):

Heuristic implementations

epsilon: float = 0.1
dividend: torch.Tensor = torch.Tensor([-epsilon, 5, -5, -10])
# tensor([ -0.1000,   5.0000,  -5.0000, -10.0000])
  1. Heuristic from zennit: Add epsilon to the absolute value of the dividend conserving the sign:

    dividend + ((dividend == 0.).to(dividend) + dividend.sign()) * epsilon

    Example:

    dividend + ((dividend == 0.).to(dividend) + dividend.sign()) * epsilon
    # tensor([ -0.2000,   5.1000,  -5.1000, -10.1000])
  2. Heuristic from lrp-tutorial: Scale epsilon according to dividend's magnitude using quadratic mean

    dividend + epsilon * (dividend**2).mean()**.5 + 1e-9

    Example:

    dividend + epsilon * (dividend**2).mean()**.5 + 1e-9
    # tensor([ 0.5124,  5.6124, -4.3876, -9.3876])
  3. Vanilla: Add epsilon to dividend without heuristics

    dividend + epsilon

    Example:

    dividend + epsilon
    # tensor([ 0.0000,  5.1000, -4.9000, -9.9000])
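
To illustrate the input dependence, here is a small sketch (with made-up values, not taken from the model) of how the vanilla heuristic can nearly cancel a denominator entry that happens to lie close to -epsilon, while the sign-preserving zennit heuristic pushes every entry away from zero:

import torch

epsilon = 0.1
# illustrative values: the first entry happens to lie close to -epsilon
dividend = torch.tensor([-0.1001, 5.0, -5.0, -10.0])

# vanilla heuristic: adding epsilon nearly cancels the first entry
vanilla = dividend + epsilon
# tensor([-1.0000e-04,  5.1000e+00, -4.9000e+00, -9.9000e+00])

# zennit heuristic: every entry's magnitude grows by epsilon, away from zero
sign_preserving = dividend + ((dividend == 0.).to(dividend) + dividend.sign()) * epsilon
# tensor([ -0.2001,   5.1000,  -5.1000, -10.1000])

relevance = torch.ones(4)
print(relevance / vanilla)          # the first entry explodes to roughly -1e4
print(relevance / sign_preserving)  # all entries stay bounded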


Additional insights

Disclaimer: The following images have been generated with my own implementation of LRP. Nevertheless, the error can also be reproduced using zennit by modifying the heuristics in the stabilize function.

It is worth noting that heatmaps which exhibit only a few visual relevance concentrations, as shown in the ResNet50 heatmaps in the first comment, do have non-zero relevance scores elsewhere; their visibility strongly depends on the plotting settings (see the example below).

Comparison

Heatmap of relevance scores with numerical instability:

image

Same relevance scores with numerical instability, adjusted plotting settings:

image

Heatmap with input in the same plot:

image
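
One way to obtain such an adjusted view from the snippet in the first comment (a sketch, not how the images above were generated) is to clip the relevance map to a symmetric percentile bound before passing it to imgify, so that a handful of extreme values no longer dominates the color scale:

import torch
from zennit.image import imgify

def clipped_heatmap(R, percentile=99.0):
    # reduce the channel dimension, as in the snippet from the first comment
    relevance = R.detach().cpu().sum(1)
    # symmetric bound taken at the given percentile of the absolute relevance
    bound = torch.quantile(relevance.abs().flatten(), percentile / 100.0).item()
    return imgify(
        relevance.clamp(min=-bound, max=bound),
        symmetric=True,
        grid=True,
        cmap='seismic',
    )

Calling clipped_heatmap(R) on the relevance from the first snippet then reveals the lower-magnitude relevance that is otherwise invisible next to the outliers.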

Comparing stabilize functions

At which epsilon did this happen?

I used epsilon=1e-6 in the stabilize function. See the following hyperparameter grid search I conducted with the composite from lrp-tutorial:

Composite

image

Heatmaps with no heuristics in stabilize function

image

Heatmaps with lrp-tutorial heuristics in stabilize function

image

@chr5tphr
Owner

Hey Rodrigo,

thank you for your insights!
After investigating the instability, I have verified my previous expectation that this is heavily dependent on the Gamma rule, which can lead to vanishing contributions.

This is also why it is so dependent on the heuristic used in the stabilize function.

In my analysis, I have found that the ResNet residual connections can also lead to vanishing contributions when they are attributed with Norm(), which then also requires the stabilizer to be changed accordingly.

My conclusion is that there is no global solution that could be implemented in Zennit.
However, I think exposing more control over the stabilizer in the built-in Composites, especially the epsilon value for Norm(), will greatly help the user; Gamma is already exposed in the built-in composites. I have implemented multiple heuristics for stabilize, and the general gist seems to be that a lower gamma calls for a higher epsilon (e.g. through a heuristic). I will think a little about whether it is better to pass a function, or the parameters for a feature-rich stabilize function; maybe making both possible would give the most freedom.

As a side note, I have also noticed that for VGG the necessary epsilon grows when the bias is not included, as I am currently preparing the 'no-bias' feature.
The obvious hypothesis here is that when the bias is positive, it acts like a stabilizer itself, similar to a higher epsilon.

Here is also the snippet I used for my analysis. I am thinking of adding something similar as a tutorial, to also help debug conservativity.

Python code
#!/usr/bin/env python3
import torch
from PIL import Image
from torchvision.models import resnet18, vgg11_bn
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, ToTensor

from zennit.attribution import Gradient
from zennit.core import collect_leaves, RemovableHandleList
from zennit.composites import EpsilonGammaBox
from zennit.torchvision import ResNetCanonizer, VGGCanonizer
from zennit.image import imsave


def trace_hook(target):
  def trace(module, input, output):
      output.retain_grad()
      target.append((module, output))
  return trace


def main():
  transform_img = Compose([
      Resize(256),
      CenterCrop(224),
  ])
  transform_norm = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
  transform = Compose([
      transform_img,
      ToTensor(),
      transform_norm,
  ])

  image = Image.open('dornbusch-lighthouse.jpg')
  data = transform(image)[None]

  modules = []

  model = resnet18(pretrained=True).eval()
  canonizer = ResNetCanonizer()

  # model = vgg11_bn(pretrained=True).eval()
  # canonizer = VGGCanonizer()

  low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))
  # low, high = (-3., 3.)
  composite = EpsilonGammaBox(
      low=low,
      high=high,
      zero_params=['bias'],
      gamma=4.,
      epsilon=1e-6,
      canonizers=[canonizer]
  )
  target = torch.eye(1000)[[437]]

  with Gradient(model=model, composite=composite) as attributor:
      handles = RemovableHandleList(
          module.register_forward_hook(trace_hook(modules))
          for module in collect_leaves(model)
      )
      output, attribution = attributor(data, target)
      handles.remove()
  print(f'Prediction: {output.argmax(1)[0].item()}')

  lines = [
      f'{n + 1:03d} {module.__class__.__name__:17s}: '
      f'{(output.grad.max() - output.grad.min()).item():.2e} {output.grad.sum().item():.2e}'
      for n, (module, output) in enumerate(modules)
  ]
  lines.insert(
      0,
      f'000 {"input":17s}: {(attribution.max() - attribution.min()).item():.2e} {attribution.sum().item():.2e}'
  )

  print('\n'.join(lines))

  imsave('heatmap.png', attribution.sum(1)[0], symmetric=True, cmap='coldnhot')


if __name__ == '__main__':
  main()

Do you agree with my observation?

I would then proceed to provide stronger stabilizer customization for Epsilon and Norm, maybe through a Stabilizer class which can be instantiated with arbitrary parameters, something along the lines of:

class Stabilizer:
    def __init__(self, epsilon=1e-6, clip=False, mean_scale=False, dim=None):
        self.epsilon = epsilon
        self.clip = clip
        self.mean_scale = mean_scale
        self.dim = dim

    def __call__(self, input):
        sign = ((input == 0.).to(input) + input.sign())
        epsilon = self.epsilon
        if self.mean_scale:
            dim = self.dim
            if self.dim is None:
                dim = tuple(range(1, input.ndim))
            epsilon = epsilon * ((input ** 2).mean(dim=dim, keepdim=True) ** .5)
        if self.clip:
            return sign * input.abs().clip(min=epsilon)
        return input + sign * epsilon

And then I will maybe just check inside the Epsilon and Norm rules whether epsilon is callable or a float, and act accordingly.
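
A minimal sketch of that check (hypothetical helper, not Zennit's actual API) could look like this:

def resolve_stabilizer(epsilon):
    # accept either a ready-made callable or a plain float epsilon
    if callable(epsilon):
        return epsilon
    return Stabilizer(epsilon=epsilon)

# inside Epsilon/Norm, the denominator would then be stabilized uniformly:
# stabilized = resolve_stabilizer(self.epsilon)(denominator)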

@rachtibat
Contributor

Hey,

Thank you Rodrigo for putting so much effort into solving the instability and thanks to Christopher for taking the time to answer all the issues in depth.
Although I cannot offer a solution, I'd like to add that it would also be interesting to observe the intermediate attributions by accessing the .grad attribute of hidden layer outputs. I have also observed that the instability in heatmaps already begins at hidden layers and continues to grow. Unfortunately, I cannot experiment with the code right now, as I have no access to my PC.

Best,
Reduan

@chr5tphr
Owner

chr5tphr commented Jun 17, 2022

As community service, here is the output for resnet18 with the standard values, but with the bias deactivated. The first column after the layer name is the distance from the maximum to the minimum value, and the second value is the sum (which should stay 1 if the attribution is conservative).

Default parameters (broken)

image

Gradient Stats:

#   Layer-Name         Max-Min  Sum
--- ------------------ -------- ---------
000 input            : 6.69e+10 -5.21e+06
001 Conv2d           : 5.46e+11 -5.57e+06
002 BatchNorm2d      : 5.46e+11 -5.34e+06
003 ReLU             : 5.46e+11 -5.34e+06
004 MaxPool2d        : 4.41e+11 -5.34e+06
005 Conv2d           : 5.26e+10 +2.69e+09
006 BatchNorm2d      : 5.26e+10 +2.69e+09
007 ReLU             : 5.26e+10 +2.69e+09
008 Conv2d           : 1.57e+09 +2.69e+09
009 BatchNorm2d      : 1.57e+09 +2.69e+09
010 Sum              : 2.52e+09 -1.66e+03
011 ReLU             : 2.52e+09 -1.66e+03
012 Conv2d           : 8.71e+07 +1.24e+07
013 BatchNorm2d      : 8.71e+07 +1.24e+07
014 ReLU             : 8.71e+07 +1.24e+07
015 Conv2d           : 5.47e+06 +1.24e+07
016 BatchNorm2d      : 5.47e+06 +1.24e+07
017 Sum              : 1.07e+07 -8.50e+01
018 ReLU             : 1.07e+07 -8.50e+01
019 Conv2d           : 4.65e+06 +5.90e+04
020 BatchNorm2d      : 4.65e+06 +5.90e+04
021 ReLU             : 4.65e+06 +5.90e+04
022 Conv2d           : 5.57e+05 +5.90e+04
023 BatchNorm2d      : 5.57e+05 +5.90e+04
024 Conv2d           : 3.23e+05 -5.90e+04
025 BatchNorm2d      : 3.23e+05 -5.90e+04
026 Sum              : 6.36e+05 -1.30e+01
027 ReLU             : 6.36e+05 -1.30e+01
028 Conv2d           : 1.93e+05 -1.81e+04
029 BatchNorm2d      : 1.93e+05 -1.81e+04
030 ReLU             : 1.93e+05 -1.81e+04
031 Conv2d           : 6.89e+03 -1.81e+04
032 BatchNorm2d      : 6.89e+03 -1.81e+04
033 Sum              : 9.06e+03 +7.66e-01
034 ReLU             : 9.06e+03 +7.66e-01
035 Conv2d           : 1.79e+03 -7.57e+02
036 BatchNorm2d      : 1.79e+03 -7.57e+02
037 ReLU             : 1.79e+03 -7.57e+02
038 Conv2d           : 9.47e+02 -7.57e+02
039 BatchNorm2d      : 9.47e+02 -7.57e+02
040 Conv2d           : 3.00e+02 +7.58e+02
041 BatchNorm2d      : 3.00e+02 +7.58e+02
042 Sum              : 1.04e+03 +9.90e-01
043 ReLU             : 1.04e+03 +9.90e-01
044 Conv2d           : 8.52e+01 -6.31e+01
045 BatchNorm2d      : 8.52e+01 -6.31e+01
046 ReLU             : 8.52e+01 -6.31e+01
047 Conv2d           : 8.35e+00 -6.31e+01
048 BatchNorm2d      : 8.35e+00 -6.31e+01
049 Sum              : 1.52e+01 +1.00e+00
050 ReLU             : 1.52e+01 +1.00e+00
051 Conv2d           : 7.99e-01 +5.45e-01
052 BatchNorm2d      : 7.99e-01 +5.45e-01
053 ReLU             : 7.99e-01 +5.45e-01
054 Conv2d           : 4.66e-01 +5.45e-01
055 BatchNorm2d      : 4.66e-01 +5.45e-01
056 Conv2d           : 2.44e-01 +4.55e-01
057 BatchNorm2d      : 2.44e-01 +4.55e-01
058 Sum              : 5.57e-01 +1.00e+00
059 ReLU             : 5.57e-01 +1.00e+00
060 Conv2d           : 7.46e-02 +9.42e-01
061 BatchNorm2d      : 7.46e-02 +9.42e-01
062 ReLU             : 7.46e-02 +9.42e-01
063 Conv2d           : 8.87e-03 +9.42e-01
064 BatchNorm2d      : 8.87e-03 +9.42e-01
065 Sum              : 9.18e-03 +1.00e+00
066 ReLU             : 9.18e-03 +1.00e+00
067 AdaptiveAvgPool2d: 1.34e-01 +1.00e+00
068 Linear           : 1.00e+00 +1.00e+00
gamma=4 and epsilon=1e-6 (better)

image

Gradient Stats:

#   Layer-Name         Max-Min  Sum
--- ------------------ -------- --------
000 input            : 6.66e-03 1.00e+00
001 Conv2d           : 8.45e-02 1.00e+00
002 BatchNorm2d      : 8.45e-02 1.00e+00
003 ReLU             : 8.45e-02 1.00e+00
004 MaxPool2d        : 7.11e-02 1.00e+00
005 Conv2d           : 1.71e-01 4.72e-01
006 BatchNorm2d      : 1.71e-01 4.72e-01
007 ReLU             : 1.71e-01 4.72e-01
008 Conv2d           : 8.12e-02 4.72e-01
009 BatchNorm2d      : 8.12e-02 4.72e-01
010 Sum              : 1.19e-01 1.00e+00
011 ReLU             : 1.19e-01 1.00e+00
012 Conv2d           : 4.05e-02 2.71e-01
013 BatchNorm2d      : 4.05e-02 2.71e-01
014 ReLU             : 4.05e-02 2.71e-01
015 Conv2d           : 1.78e-01 2.71e-01
016 BatchNorm2d      : 1.78e-01 2.71e-01
017 Sum              : 2.08e-01 1.00e+00
018 ReLU             : 2.08e-01 1.00e+00
019 Conv2d           : 1.94e-03 6.31e-01
020 BatchNorm2d      : 1.94e-03 6.31e-01
021 ReLU             : 1.94e-03 6.31e-01
022 Conv2d           : 1.32e-02 6.31e-01
023 BatchNorm2d      : 1.32e-02 6.31e-01
024 Conv2d           : 1.03e-02 3.69e-01
025 BatchNorm2d      : 1.03e-02 3.69e-01
026 Sum              : 1.87e-02 1.00e+00
027 ReLU             : 1.87e-02 1.00e+00
028 Conv2d           : 3.88e-03 1.06e-01
029 BatchNorm2d      : 3.88e-03 1.06e-01
030 ReLU             : 3.88e-03 1.06e-01
031 Conv2d           : 1.42e-02 1.06e-01
032 BatchNorm2d      : 1.42e-02 1.06e-01
033 Sum              : 2.10e-02 1.00e+00
034 ReLU             : 2.10e-02 1.00e+00
035 Conv2d           : 4.04e-03 8.76e-01
036 BatchNorm2d      : 4.04e-03 8.76e-01
037 ReLU             : 4.04e-03 8.76e-01
038 Conv2d           : 8.00e-03 8.76e-01
039 BatchNorm2d      : 8.00e-03 8.76e-01
040 Conv2d           : 1.60e-03 1.24e-01
041 BatchNorm2d      : 1.60e-03 1.24e-01
042 Sum              : 7.00e-03 1.00e+00
043 ReLU             : 7.00e-03 1.00e+00
044 Conv2d           : 4.62e-03 1.54e-01
045 BatchNorm2d      : 4.62e-03 1.54e-01
046 ReLU             : 4.62e-03 1.54e-01
047 Conv2d           : 4.39e-03 1.54e-01
048 BatchNorm2d      : 4.39e-03 1.54e-01
049 Sum              : 6.18e-03 1.00e+00
050 ReLU             : 6.18e-03 1.00e+00
051 Conv2d           : 1.13e-02 8.91e-01
052 BatchNorm2d      : 1.13e-02 8.91e-01
053 ReLU             : 1.13e-02 8.91e-01
054 Conv2d           : 1.28e-02 8.91e-01
055 BatchNorm2d      : 1.28e-02 8.91e-01
056 Conv2d           : 3.19e-03 1.09e-01
057 BatchNorm2d      : 3.19e-03 1.09e-01
058 Sum              : 1.45e-02 1.00e+00
059 ReLU             : 1.45e-02 1.00e+00
060 Conv2d           : 1.16e-02 9.42e-01
061 BatchNorm2d      : 1.16e-02 9.42e-01
062 ReLU             : 1.16e-02 9.42e-01
063 Conv2d           : 8.87e-03 9.42e-01
064 BatchNorm2d      : 8.87e-03 9.42e-01
065 Sum              : 9.18e-03 1.00e+00
066 ReLU             : 9.18e-03 1.00e+00
067 AdaptiveAvgPool2d: 1.34e-01 1.00e+00
068 Linear           : 1.00e+00 1.00e+00

@rachtibat
Contributor

rachtibat commented Jun 17, 2022

Hi Chris,

Thanks, that's interesting. It confirms the correlation between heatmap quality and hidden concept attribution.

@maxdreyer

maxdreyer commented Jun 18, 2022

I might add that this could be a problem of the gamma rule, which (as it is) does not fit the skip connections of the ResNet model.
Usually, the gamma rule stabilizes the relevance propagation, as it favors positive contributions (by a factor of 1+gamma) and thereby enlarges the denominator (if the denominator is positive). Exploding relevances due to a small denominator are then less likely.

Formula (gamma rule):

$$R_j = \sum_k \frac{a_j \left( w_{jk} + \gamma\, w_{jk}^{+} \right)}{\sum_{0,j} a_j \left( w_{jk} + \gamma\, w_{jk}^{+} \right)} \, R_k$$

However, if the denominator is negative, it becomes more positive with larger gamma. The denominator can then become smaller in magnitude, and relevances can explode. This usually does not happen, since a negative denominator means a negative pre-activation, which is set to zero by ReLU non-linearities.
With skip connections (where pre-activations are added together), however, a negative pre-activation can receive relevance when it is added to a larger positive pre-activation.

So the problem is that the gamma rule assumes that only neurons with positive outputs have relevance/contributed in the forward pass. One way to fix it would be to make the gamma rule more generic by checking the sign of the output and then favoring either positive contributions (if the output is positive) or negative contributions (if the output is negative). This way, the denominator always grows in magnitude.

A fix like that would probably also stabilize relevance propagation when other non-linearities are used, where negative outputs receive relevance/contribute in the forward pass. Come to think of it, this could also be interesting for other rules, e.g. alpha-beta: for alpha1_beta0, all positive contributions receive relevance even if the output is negative. Would it not be more sensible if all negative contributions received relevance instead?
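
To make this concrete, here is a small numeric sketch (illustrative values only, not from the model above) of how a negative denominator shrinks in magnitude as gamma grows when the inputs are non-negative, so gamma only boosts the positive part:

import torch

z_pos = torch.tensor(2.0)   # sum of positive contributions
z_neg = torch.tensor(-3.0)  # sum of negative contributions, so the pre-activation is negative

for gamma in (0.0, 0.25, 0.5):
    # with non-negative inputs, gamma only amplifies the positive part of the denominator
    denominator = (1 + gamma) * z_pos + z_neg
    print(gamma, denominator.item())
# 0.0 -> -1.0, 0.25 -> -0.5, 0.5 -> 0.0: the denominator approaches zero and the quotient explodes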

@maxdreyer

I have run a small experiment, and making the gamma rule symmetric (for positive and negative outputs, indicated by *) seems to yield much more reasonable results. The total relevance at the input stays bounded and the heatmaps look sensible.

lizard

Here is my implementation of the updated rule:

import torch

from zennit.core import BasicHook, stabilize


class GammaSymmetric(BasicHook):

    def __init__(self, gamma=0.25):
        self.gamma = gamma

        def gradient_map(out_grad, outputs):
            P = outputs[0] + outputs[1]  # positive contributions
            N = outputs[2] + outputs[3]  # negative contributions

            alpha = 1 + gamma * (P >= N.abs()) # gamma for positive contr.
            beta = 1 + gamma * (P < N.abs())  # gamma for negative contr.

            out = P * alpha + N * beta

            factors = [alpha, alpha, beta, beta]

            return [fac * out_grad / stabilize(out) for fac in factors]

        super().__init__(
            input_modifiers=[
                lambda input: input.clamp(min=0),
                lambda input: input.clamp(max=0),
                lambda input: input.clamp(min=0),
                lambda input: input.clamp(max=0),
            ],
            param_modifiers=[
                lambda param, _: param.clamp(min=0),
                lambda param, name: param.clamp(max=0) if name != 'bias' else torch.zeros_like(param),
                lambda param, _: param.clamp(max=0),
                lambda param, name: param.clamp(min=0) if name != 'bias' else torch.zeros_like(param),
            ],
            output_modifiers=[lambda output: output] * 4,
            gradient_mapper=(lambda out_grad, outputs: gradient_map(out_grad, outputs)),
            reducer=(
                lambda inputs, gradients: (
                        (inputs[0] * gradients[0] + inputs[1] * gradients[1]) +
                        (inputs[2] * gradients[2] + inputs[3] * gradients[3])
                )
            )
        )
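
As a usage sketch (my guess at how one could plug the rule in; not necessarily how the heatmaps above were produced), it could be assigned to convolutional layers through a LayerMapComposite:

from torch.nn import Conv2d, Linear
from zennit.composites import LayerMapComposite
from zennit.rules import Epsilon
from zennit.torchvision import ResNetCanonizer

# assign the custom rule to convolutions, a plain Epsilon to linear layers
composite = LayerMapComposite(
    layer_map=[
        (Conv2d, GammaSymmetric(gamma=0.25)),
        (Linear, Epsilon(epsilon=1e-6)),
    ],
    canonizers=[ResNetCanonizer()],
)

The composite would then be used with composite.context(model) as in the snippets above.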

@rodrigobdz
Contributor Author

rodrigobdz commented Jun 19, 2022

Thank you all for taking the time to look into this issue so thoroughly. I have yet to go through the snippets and some of the comments in detail, but I am really inspired by the fruitful conversation we have going on right now.

A few thoughts came to mind while reading the observations made so far:

  1. Vanishing contributions might be a different issue: whenever the numerical instability happens, the relevances elsewhere do not actually vanish, as it might seem from looking at the initial heatmap. Rather, they are merely very small in magnitude relative to the scores that are blown out of proportion.

    See the Comparison section under Additional insights in #148 (comment).

  2. Are we sure that the numerical instability affects exclusively the gamma rule?

    From an implementation perspective, the stabilize function is widely used across zennit's codebase.

    From a theoretical standpoint, the generic LRP rule from Section 10.2.2 in [1] underlies the LRP-zero/epsilon/gamma rules, so they all share the same kind of denominator.

  3. An initial step towards a solution could be to warn the user that the heatmap might not be correct whenever the relevances blow out of proportion, i.e., when conservativity is violated. For example, when the sum is not (roughly) equal to one, as shown by @chr5tphr in #148 (comment). A minimal sketch of such a check follows after this list.

    The warning itself certainly does not solve the actual issue but it would be more user-friendly. A step further would be to provide recommendations together with the warning 💯.

  4. The in-depth explanation of the possible root cause of the numerical instability sounds plausible. I am wary about the following suggestion, as I fear that it would go against the intrinsic purpose of the Gamma rule, which is to mainly favor positive contributions:

    One way to fix it would be to make the gamma rule more generic, by (1) checking the sign of the output and then either favoring positive (if output pos.) or negative contributions (if output neg.).
    Source: #148 (comment)

    The parameter γ controls by how much positive contributions are favored. As γ increases, negative contributions start to disappear.

    Source: [1]

    Nevertheless, the plots by @maxdreyer in #148 (comment) seem to correctly favor positive over negative contributions. 🤔
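
Regarding point 3, a minimal sketch of such a check (a hypothetical helper, not part of zennit) could be:

import warnings

def check_conservation(relevance, expected=1.0, rtol=0.1):
    # the relevance sum at the input should roughly equal the relevance injected at the output
    total = relevance.sum().item()
    if abs(total - expected) > rtol * abs(expected):
        warnings.warn(
            f'relevance sum {total:.3e} deviates from the expected {expected:.3e}; '
            'the attribution may be numerically unstable'
        )

For example, check_conservation(attribution, expected=1.0) after attributing with a one-hot output gradient, as in the layer-wise trace above.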


Food for thought


References:

[1] Montavon, Grégoire, Alexander Binder, Sebastian Lapuschkin, Wojciech Samek, and Klaus-Robert Müller. "Layer-wise relevance propagation: an overview." Explainable AI: interpreting, explaining and visualizing deep learning (2019): 193-209.


@leanderweber

Specifically for ResNet50, there is also a pretty nasty artifact in modified-backprop explanations caused by the 1x1-convolution, stride-2 downsampling shortcuts. Short of excluding those shortcuts, there is not really a way to get around that. I suspect that this may further contribute to the concentrated attributions you are reporting (in addition to the issues with the gamma rule). This artifact should also appear for other composites, but not for other models.

Below are some example heatmaps showing that artifact for PascalVOC.

resnet50-heatmap-artifact

@chr5tphr
Owner

chr5tphr commented Jun 24, 2022

Hey everyone,

coming back to @maxdreyer's suggestion, it is true that the Gamma rule is only defined for positive inputs.
This previously was also the case for the ZPlus rule, which we decided to make general in #5.
As it turns out, there is a generalized gamma version (Supplement C.2).
Given that the original Gamma should not be used for negative inputs, and that the generalized version reduces to the original version when there are no negative inputs, I think I will modify the Gamma rule to the generalized version, as we did for the ZPlus rule.

@rodrigobdz
Contributor Author

rodrigobdz commented Jul 9, 2022

For reference only:

I have just made my LRP implementation lrp-pf-auc public, and I'd like to share these Jupyter notebooks (1 and 2) with additional insights on the numerical instability.

@chr5tphr chr5tphr added the model compatibility Compatibility for new or variations of existing models label Aug 11, 2023