Densenet canonizations #171

Draft · wants to merge 11 commits into master
Conversation

gumityolcu

Hello,

Here is a summary of the contributions:

  1. The epsilon values of the batch_norm layers used to be left as they are, even though they need to be set to zero for perfect canonization. I fixed this. Note that the epsilon parameter cannot be added to the batch_norm_params field of the canonizer, because it is a literal, not a torch Variable: if you add it there, the code will try to access batch_norm.eps.data, which does not exist, when restoring it. Therefore, I use a new class variable "batch_norm_eps" to remember it during canonization.
  2. CompositeCanonizer now returns the list of handles reversed. If two canonizers attach to the same module, they need to be detached in the reverse order in which they were applied in order to restore the original values. I opted to reverse the list in the class because detaching the given handles in the returned order seemed more user friendly, and I couldn't think of a use case where this would cause problems.
  3. MergeBatchNormtoRight canonizer is added. This merges a batch normalization layer into a linear layer that comes after it. If the linear layer is a convolutional layer with padding, this is not straightforward: a full feature map needs to be added to the output of the layer instead of a simple bias. This is done by adding forward hooks (see the merge sketch after this list).
  4. ThreshReLUMergeBatchNorm is added. This canonizer detects BN->ReLU->Linear chains and changes the activation function to one that depends on the batch norm parameters, so that the BatchNorm is effectively applied after the activation; the BatchNorm is then merged into the linear layer that follows it (see the identity check after this list). This is as described in https://github.com/AlexBinder/LRP_Pytorch_Resnets_Densenet/blob/master/canonization_doc.pdf
  5. Furthermore, BN->ReLU->AvgPool->Linear chains are detected and canonized with the same method, because batch normalization commutes with average pooling.
  6. Full proposed canonizers are added to torchvision.py. Another addition is DenseNetAdaptiveAvgPoolCanonizer, which is needed before applying the other canonizers to DenseNets: it makes the final ReLU and AvgPool operations of torchvision DenseNet objects explicit nn.Module objects. By default, these are applied in the forward method of the model, not as nn.Module children.
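To illustrate item 3, below is a minimal, hypothetical sketch of the algebra behind merging a BatchNorm2d into a following Conv2d. The function name, the explicit input-size argument, and the eager bias-map computation are illustrative choices, not the actual MergeBatchNormtoRight implementation:

import torch


def merge_bn_into_following_conv(bn, conv, input_hw):
    """Fold an eval-mode BatchNorm2d into the Conv2d that follows it (sketch).

    With bn(x) = scale * x + shift, scale = gamma / sqrt(var + eps) and
    shift = beta - scale * mean, we get conv(bn(x)) = conv'(x) + bias_map,
    where conv' has its weight multiplied by `scale` along the input channels.
    Without padding, bias_map is constant per output channel (a plain bias);
    with zero padding it varies near the borders, so it has to be added to the
    output as a full feature map (the PR does this with a forward hook).
    """
    with torch.no_grad():
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        shift = bn.bias - bn.running_mean * scale
        original_weight = conv.weight.clone()

        # bias feature map: the original convolution applied (without its own
        # bias) to a constant input holding `shift` at every spatial position
        const = shift[None, :, None, None].expand(1, bn.num_features, *input_hw)
        bias_map = torch.nn.functional.conv2d(
            const, original_weight, bias=None,
            stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
        )

        # fold the scale into the convolution weight (input-channel dim is 1)
        conv.weight.mul_(scale[None, :, None, None])
    return bias_map


# quick numerical check in eval mode
bn = torch.nn.BatchNorm2d(3).eval()
bn.running_mean.normal_()
bn.running_var.uniform_(0.5, 2.0)
conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    expected = conv(bn(x))
    bias_map = merge_bn_into_following_conv(bn, conv, (16, 16))
    assert torch.allclose(conv(x) + bias_map, expected, atol=1e-5)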
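The canonization in item 4 rests on the observation that, for a positive per-channel scale, the ReLU can absorb the batch norm's shift while the scale moves behind the activation, where it can be merged into the next linear layer. A simplified identity check, assuming a positive BatchNorm scale and not necessarily the exact form used in the linked document:

import torch

# eval-mode batch norm as an affine map: s = gamma / sqrt(var + eps), t = beta - s * mean
s = torch.rand(8) + 0.1          # assumed positive per-channel scale
t = torch.randn(8)               # per-channel shift
x = torch.randn(32, 8)

lhs = torch.relu(s * x + t)      # ReLU applied after the batch norm
rhs = s * torch.relu(x + t / s)  # shifted ReLU, with the scale moved behind the activation
assert torch.allclose(lhs, rhs, atol=1e-6)
# the remaining per-channel scale s can then be folded into the following linear layer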

Thank you very much and I am looking forward to any kind of feedback!

Galip Ümit Yolcu and others added 11 commits September 15, 2022 16:26
- The returned handles are reversed.
- This way, when two canonizers change the same parameter, removing the handles in the returned order restores the original model
- The epsilon parameter is set to 0 during canonization
- Parameter dimensions are checked before merging, to prevent attempts to merge incompatible layers, as in DenseNets.
- Minor change in MergeBatchNorm: set batch_norm.eps=0 in the register method instead of merge_batch_norm
- Add MergeBatchNormtoRight canonizer
  - Merges a BatchNorm into the linear layer to its right.
  - If the convolution has padding, one needs to compute a feature map and add it to the output of the convolution
    to account for the batch norm bias
- The canonizer did not work correctly when the convolution has a bias; this is now handled
- The hook function was made lighter by discarding unneeded overhead computation
- Minor change in MergeBatchNormtoRight: remove unused variable
- Add DenseNetAdaptiveAvgPoolCanonizer: makes the last adaptive average pooling of torchvision DenseNets an explicit nn.Module object
- Add ThreshReLUMergeBatchNorm: canonizer for BatchNorm -> ReLU -> Linear chains. Adds backward and forward hooks to the ReLU in order to turn it into the ThreshReLU defined in https://github.com/AlexBinder/LRP_Pytorch_Resnets_Densenet/blob/master/canonization_doc.pdf
- Add SequentialThreshCanonizer: a composite canonizer that applies DenseNetAvgPoolCanonizer, SequentialMergeBatchNorm, ThreshReLUMergeBatchNorm
- Add ThreshSequentialCanonizer: a composite canonizer that applies DenseNetAvgPoolCanonizer, ThreshReLUMergeBatchNorm, SequentialMergeBatchNorm

The last two canonizers are the recommended canonizers for torchvision DenseNet implementations. The standard SequentialMergeBatchNorm is still needed to do away with the initial BN->Conv in the architecture. The two canonizers result in different implementations of the same function: in practice, dense blocks contain BN->ReLU->Conv->BN->ReLU->Conv chains, which leaves the possibility of applying SequentialMergeBatchNorm inside the dense blocks if it is applied first. In practice, both canonizations get rid of the artifacts in the attribution maps; SequentialThreshCanonizer seems to be better quantitatively. A usage sketch follows the commit list below.
…orrectCompositeCanonizer in the ThreshSequentialCanonizer and SequentialThreshCanonizer classes
Docs: Fix docstrings in MergeBatchNormtoRight and ThreshReLUMergeBatchNorm
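For context, applying one of the proposed canonizers with zennit might look roughly like the following. The canonizer class name is taken from this PR's torchvision.py and is not part of a released zennit version, so that import is an assumption; the rest uses zennit's standard composite/attributor API:

import torch
from torchvision.models import densenet121

from zennit.attribution import Gradient
from zennit.composites import EpsilonPlusFlat
from zennit.torchvision import DenseNetSeqThreshCanonizer  # proposed in this PR, hypothetical import

model = densenet121(weights=None).eval()  # torchvision >= 0.13; use pretrained=False on older versions

canonizer = DenseNetSeqThreshCanonizer()
composite = EpsilonPlusFlat(canonizers=[canonizer])

data = torch.randn(1, 3, 224, 224, requires_grad=True)
target = torch.eye(1000)[[0]]  # one-hot relevance at the chosen output class

with Gradient(model=model, composite=composite) as attributor:
    output, attribution = attributor(data, target)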
@chr5tphr (Owner) left a comment

Hey Galip,

sorry for the very long hold-up. Let's try to finalize this. Ultimately, we need to rebase this. Maybe you can first introduce the changes and then rebase.

Comment on lines +342 to +343
module.canonization_params = {}
module.canonization_params["bias_kernel"] = bias_kernel

let's store these in the canonizer itself, similar to MergeBatchNorm.linear_params and .batch_norm_params

Comment on lines +346 to +353
module.bias.data = (original_weight * shift).sum(dim=1) + original_bias

# change batch_norm parameters to produce identity
batch_norm.running_mean.data = torch.zeros_like(batch_norm.running_mean.data)
batch_norm.running_var.data = torch.ones_like(batch_norm.running_var.data)
batch_norm.bias.data = torch.zeros_like(batch_norm.bias.data)
batch_norm.weight.data = torch.ones_like(batch_norm.weight.data)
batch_norm.eps = 0.

these need to be adapted to the new approach (see current version of MergeBatchNorm)


module.canonization_params = {}
module.canonization_params["bias_kernel"] = bias_kernel
return_handles.append(module.register_forward_hook(MergeBatchNormtoRight.convhook))

For the sake of not using Hooks, maybe we can wrap and overwrite the forward function (similar to the ResNet Canonizer)?
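For reference, shadowing the class-level forward with an instance attribute (so that deleting the attribute restores the original method) could look roughly like this; the patched forward body and the attribute name are purely illustrative:

import types

import torch


def patched_forward(self, x):
    # illustrative wrapper: run the original class-level forward, then add a
    # precomputed bias feature map stored on the instance
    return torch.nn.Conv2d.forward(self, x) + self.canonization_bias_map


conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
conv.canonization_bias_map = 0.0  # placeholder value for the sketch

# the instance attribute shadows the class attribute
conv.forward = types.MethodType(patched_forward, conv)

# deleting the instance attributes reverts to the class-level forward
del conv.forward
del conv.canonization_bias_map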

Comment on lines +336 to +338
temp_module = torch.nn.Conv2d(in_channels=module.in_channels, out_channels=module.out_channels,
kernel_size=module.kernel_size, padding=module.padding,
padding_mode=module.padding_mode, bias=False)

let's indent one line per kwarg
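For illustration, the constructor call from the snippet above with one keyword argument per line (module being the existing convolution from the surrounding canonizer code):

temp_module = torch.nn.Conv2d(
    in_channels=module.in_channels,
    out_channels=module.out_channels,
    kernel_size=module.kernel_size,
    padding=module.padding,
    padding_mode=module.padding_mode,
    bias=False,
)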


if isinstance(module, torch.nn.Conv2d):
if module.padding == (0, 0):
module.bias.data = (original_weight * shift[index]).sum(dim=[1, 2, 3]) + original_bias

this needs to be adapted to object.__setattr__(module, 'bias', (original_weight * shift[index]).sum(dim=[1, 2, 3]) + original_bias)

of instance, which is why deleting instance attributes with the same name reverts them to the original
function.
'''
self.module.features = Sequential(*list(self.module.features.children())[:-2])

If I remember correctly, you can slice Sequential, i.e. self.module.features = self.module.features[:-2]
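A quick check that torch.nn.Sequential indeed supports slicing and returns a new Sequential:

import torch

features = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
)

trimmed = features[:-2]                      # slicing returns a new nn.Sequential
print(type(trimmed).__name__, len(trimmed))  # Sequential 1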

'''
return DenseNetAdaptiveAvgPoolCanonizer()

def register(self, module, attributes):

missing docstring

for key in self.attribute_keys:
delattr(self.module, key)

def forward(self, x):

missing docstring

return out


class DenseNetSeqThreshCanonizer(CompositeCanonizer):

missing docstring

))


class DenseNetThreshSeqCanonizer(CompositeCanonizer):

missing docstring
