Model inference outside of composite context #200

Open
MikiFER opened this issue Sep 1, 2023 · 5 comments
Labels
question Further information is requested testing Improvements, additions, or issues with tests

Comments

@MikiFER

MikiFER commented Sep 1, 2023

Hi @chr5tphr ,

I'm creating a new issue regarding the inference of a model.
I noticed that when I run the model's forward pass outside of (before) the composite context (which has the appropriate model canonizer), I do not obtain the same attribution as when the forward pass is done inside the context. This concerns me: to properly learn the batch norm's parameters, inference should be done outside the context, because the context effectively turns the batch norm into an identity, so inference inside the context would never update its parameter values. Is there something I'm misunderstanding here?

Here is a code snippet I used to validate that attribution is not the same. Same test was also conducted on vgg16 model and yielded the same result.

import torch
from torchvision.models import resnet18

from zennit.composites import EpsilonPlusFlat
from zennit.torchvision import ResNetCanonizer

from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop
from torchvision.transforms import ToTensor, Normalize

import matplotlib.pyplot as plt

# define the base image transform
transform_img = Compose([
    Resize(256),
    CenterCrop(224),
])
# define the normalization transform
transform_norm = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
# define the full tensor transform
transform = Compose([
    transform_img,
    ToTensor(),
    transform_norm,
])
# load the image
image = Image.open('dornbusch-lighthouse.jpg')
# transform the PIL image and insert a batch-dimension
data = transform(image)[None]
data.requires_grad = True
# define target
target = torch.eye(1000)[[437]]


model = resnet18()

canonizers = [ResNetCanonizer()]
composite = EpsilonPlusFlat(canonizers=canonizers)

# Inference before context
model.eval()  # Put model in eval so batch-norm is frozen
model_out_before = model(data)
with composite.context(model) as modified_model:
    attribution_before, = torch.autograd.grad(model_out_before, data, target)

# Inference inside context
with composite.context(model) as modified_model:
    model_out_in = modified_model(data)
    attribution_in, = torch.autograd.grad(model_out_in, data, target)

relevance_before = attribution_before.cpu().sum(1).squeeze(0).numpy()
relevance_in = attribution_in.cpu().sum(1).squeeze(0).numpy()

plt.figure(figsize=(15, 5))
plt.subplot(1,3,1)
plt.imshow(transform_img(image))
plt.axis('off')
plt.subplot(1,3,2)
plt.imshow(relevance_before)
plt.axis('off')
plt.subplot(1,3,3)
plt.imshow(relevance_in)
plt.axis('off')
plt.show()
@chr5tphr
Owner

chr5tphr commented Sep 1, 2023

Hey @MikiFER

thanks a lot for writing the issue!

This is intended behavior due to how the gradient-modification is implemented. I will explain a little bit why this is the case by describing the inner-workings.

Why is this intended behavior?

composite.context calls composite.register upon entering the context, which sweeps through the model and finds matching rules for each layer, where each matching rule will cause the creation of a new hook, which is attached to a module with hook.register.

What happens in hook.register is the real mechanic of how the gradient is modified: a pre-forward and a forward hook are attached to the module, both of which introduce a new node (Identity) into the computational graph, one at the input of the module and one at its output. These nodes then get gradient hooks attached, which change the gradient at the time it passes through these additional nodes. The sole job of the gradient hook at the output node is to record the gradient at the output of the module so it can be passed to the input node's hook (this happens in Hook.post_forward, which attaches a wrapper that calls Hook.pre_backward). Then, when the gradient subsequently reaches the input node of the module, its respective gradient hook takes the gradient at the output and computes a custom, modified gradient, which it uses to overwrite its own gradient (this happens in Hook.pre_forward, which attaches a wrapper that calls Hook.backward).

This means that the foundation to modify the gradient is laid in the forward pass, which only happens if the modules have the forward-hooks attached which introduce these modifications.

If you do the forward pass before registering the hooks (i.e., using composite.context), the resulting computational graph will not have the changes included that change the gradient in the backward pass. Thus, your first case should simply provide you with the gradient, while the second case would correctly modify your model's gradient to give you the EpsilonPlusFlat feature attribution.
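The point that the gradient modification is laid down during the forward pass can be reproduced with plain PyTorch. The following is a minimal sketch, not Zennit's actual implementation: `ScaleGrad` and `forward_hook` are hypothetical names, and the factor of 10 is an arbitrary stand-in for a real attribution rule. It assumes PyTorch is installed.

```python
import torch


class ScaleGrad(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by 10 in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 10.0


def forward_hook(module, inputs, output):
    # returning a value from a forward hook replaces the module's output,
    # which inserts the extra node into the computational graph
    return ScaleGrad.apply(output)


layer = torch.nn.Linear(3, 1)
x = torch.ones(1, 3, requires_grad=True)

# forward pass BEFORE the hook is registered: the graph has no extra node
out_before = layer(x)

handle = layer.register_forward_hook(forward_hook)

# forward pass WHILE the hook is registered: the graph contains ScaleGrad
out_inside = layer(x)

# both backward passes run while the hook is still registered
grad_before, = torch.autograd.grad(out_before, x, torch.ones_like(out_before))
grad_inside, = torch.autograd.grad(out_inside, x, torch.ones_like(out_inside))

handle.remove()
```

Here `grad_before` is the ordinary gradient even though the hook is registered at the time of the backward pass, because the graph was built without the extra node. Only `grad_inside` is modified.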

Parameter Learning under Registered Rules

The current implementation does not overwrite the parameters of the batch-norm and conv-layer in the canonization. Rather, it creates a new set of parameters, where the parameters of the conv-layer are computed from the original conv-layer and batch-norm parameters. This creates a computational graph with the correct dependencies, which makes it possible to provide meaningful gradients for the original parameters of the model when computing the gradient wrt. the feature attribution.

However, since the gradient is overwritten inside the context, this means that the computed gradients are actually not gradients, but feature attributions. Even worse, in the current implementation, the modified gradients of parameters inside the composite context are not feature attributions, but only the partial derivative of their respective layer multiplied with the output relevance, i.e. $\frac{\partial g(a; w, b)}{\partial w_{ji}}R_j$ where $w_{ji}$ are the weights and $R_j$ is the feature attribution at the output of the layer.
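To make the expression above concrete, here is a worked instance for a dense (fully connected) layer. The choice of layer and the explicit form of $g$ are assumptions for illustration; the symbols follow the paragraph above, with $a_i$ denoting the inputs of the layer.

$$g_j(a; w, b) = \sum_i w_{ji} a_i + b_j \quad\Rightarrow\quad \frac{\partial g_j(a; w, b)}{\partial w_{ji}} R_j = a_i R_j$$

Under this assumption, the value stored for $w_{ji}$ inside the context would be the input activation scaled by the output relevance, which is neither the gradient of a loss nor a feature attribution.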

This means that if you would like to compute feature attributions during training, you need to run the forward pass twice: once to compute the gradients wrt. the parameters using torch.Tensor.backward, and once to compute the feature attribution, where the first pass does not activate the rules.

However, there are two ways of deactivating the rules, and due to its implementation, the gradients wrt. parameters should, as far as I am thinking right now, be correctly computed even if the Canonizer is attached without the rules active.
This specifically happens in the case where we deactivate the rules inside the composite context, which is a mechanic primarily intended for second-order-derivatives. But by saving the backward-graph (using create_graph=True), and doing two backward steps (and only one forward), we should be able to compute the feature attribution, and gradients, as intended:

from zennit.attribution import Gradient
from zennit.composites import EpsilonGammaBox

# any composites support second order gradients
composite = EpsilonGammaBox(low=-3., high=3.)

with Gradient(model, composite) as attributor:
    output, grad = attributor(input, torch.ones_like, create_graph=True)

    # temporarily disable all hooks registered by the attributor's composite
    with attributor.inactive():
        loss = lossfn(output, target)
        optim.zero_grad()
        loss.backward()
        optim.step()

Another way that should work if the canonizers work as intended:

from zennit.attribution import Gradient
from zennit.composites import EpsilonGammaBox

# any composites support second order gradients
composite = EpsilonGammaBox(low=-3., high=3.)

with Gradient(model, composite) as attributor:
    output, grad = attributor(input, torch.ones_like, create_graph=True)

# leaving the context should also disable all hooks registered by the attributor's composite,
# although the saved forward graph will still include the modifications introduced by the canonizer
loss = lossfn(output, target)
optim.zero_grad()
loss.backward()
optim.step()

Opportunity for Contribution

If you have the time, would you be willing to try this out and report on your observations? The comparison would be between the parameter gradients of a model inside a composite context with the .inactive() mechanic (or outside the context, reusing the saved graph) and the parameter gradients outside a composite context with an additional forward pass. Remember to also set model.eval(), such that the gradients of dropout/batchnorm/etc. are deterministic. If you feel up to the job, a next step would then be to include some tests, just to be extra sure. But no worries if you are not available. Thanks a lot!

Hope this answers your questions. If some points were unclear, or if you have further questions, I would be happy to answer!

Best,
Christopher

@MikiFER
Author

MikiFER commented Sep 1, 2023

Hi @chr5tphr, thank you for the detailed explanation. I would like to give it a try to get a little more familiar with the library, since I want to use it to test something out.

I'm afraid I will need some more explanation. I've read your response multiple times and I am still not 100% sure that I understand how it works, so here are a couple more questions:

  1. About these 2 new Identity nodes that "sandwich" the original module and are created in the forward pass. Let's call them pre-Identity and post-Identity based on their position relative to the original module from the perspective of the forward pass. So in the backward pass, post-Identity will receive the gradient from the pre-Identity node of the module following the module that we are currently observing. Will that gradient be the attribution or the real gradient? Pre-Identity will receive the gradient of the observed module (which will not be modified in any way?) and the modified gradient from the post-Identity, which will then be used to calculate the observed module's attribution and will be passed to the previous module?

  2. In your code examples you are first calling attributor with the input and the starting gradient (relevance) for the output layer, and with the option to create the computation graph. I believe that is just syntactic sugar for:

output = model(input)
# torch.autograd.grad returns a 1-tuple for a single input, hence the trailing comma
grad, = torch.autograd.grad(output, input, torch.ones_like(output), create_graph=True)

That means you have called the grad function and computed the gradient of the parameters (in this case it is not the gradient but the attribution), and then you call optimizer.zero_grad() outside of the context to remove that gradient. Then you use the model output to compute some loss and call backward on it, which again initiates the backward pass, but now there are no custom gradient hooks and identities that sandwich the modules, so the gradient is really the gradient. Essentially you have done two backward passes, one for obtaining the attribution and a second one for obtaining the real gradient used for learning. Why are you passing the option create_graph=True? As I understand it, that would be used when it is necessary to obtain the gradient of the attribution? And when you then call the backward function on the loss, the whole part of the computation graph used to calculate the attribution is not used (it's like you start from the middle of the graph)?

  3. I did not understand the part with two forward passes. Do you mean that I would need to do something like this:
with Gradient(model, composite) as attributor:
    output, grad = attributor(input, torch.ones_like)  # first forward pass to obtain attribution and consequential backward pass
optim.zero_grad()  # remove the "gradients" calculated during attribution backward pass
model_output = model(input)  # second forward pass
loss = lossfn(model_output, target)  # calculate loss on the output of the second forward pass
loss.backward()  # calculate the gradient
optim.step()  # apply optimization step
optim.zero_grad()  # remove the calculated gradient which I will not do in the test so I can see their values

Then you would like me to test whether the gradients of the parameters obtained at the end of this code block are the same as the gradients of the parameters obtained when using one of the two code blocks you sent in your response?

  4. For my particular use case I will need to compute multiple attributions (for different output classes and different starting relevance, and only some parts of the input will have that relevance calculated, let's say only 3/8 batch examples) and then use those attributions in combination with the normal task loss to optimize the model. Would that mean that in my case the code should look something like this?
with composite.context(model) as modified_model:
    output = modified_model(input)
    filtered_output, filtered_input = filter_output_and_input(output, input)  # Get input and output only for some examples in batch
    # a single input yields a 1-tuple, hence the trailing comma
    attribution_1, = torch.autograd.grad(filtered_output, filtered_input, starting_relevance_1, create_graph=True)
    attribution_2, = torch.autograd.grad(filtered_output, filtered_input, starting_relevance_2, create_graph=True)

loss = lossfn(output, attribution_1, attribution_2, target)
optim.zero_grad()
loss.backward()
optim.step()

@chr5tphr
Owner

chr5tphr commented Sep 4, 2023

Hey @MikiFER ,

  1. About these 2 new Identity nodes that "sandwich" ...

This is precisely what happens! In the modified backward pass, the gradient will act as the attribution. Thus, while rules are active, every gradient function receives the modified gradient, i.e. the attribution, instead of the real gradient. Note here that if you choose not to modify the gradient of a layer, its true gradient function will be used, but it will still obtain the modified gradient from its following layer.

2. In your code examples you are first calling attributor...

Yes, attributor is just syntactic sugar here. optim.zero_grad solely zeroes out the .grad attribute of all parameter tensors for the purpose of loss.backward(), since .backward accumulates (adds) gradients in each tensor's .grad attribute, for all tensors that have requires_grad=True. Since we are using torch.autograd.grad to compute the (modified) gradient of output wrt. input, no gradients are accumulated in the tensor's .grad; instead, we get as many new tensors as we supplied as inputs into torch.autograd.grad.

create_graph=True will create the graph to allow the computation of higher order gradients, as well as allow to compute gradients again on the same graph, which will normally be destroyed after a single backward. This way, you do not need to do two forward passes, because you can simply reuse the same graph twice, once with modifying the gradients, and once without modifying the gradients. In your specific use-case, you also need this for higher-order gradients, since you want to compute the gradient of an attribution wrt. the parameters.
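The reuse of a single forward graph for two backward passes can be tried without Zennit. This is a minimal sketch assuming PyTorch is installed; the linear layer, the random data and the squared-output loss are placeholders chosen for illustration, and no gradient modification is involved.

```python
import torch

model = torch.nn.Linear(4, 2)
data = torch.randn(3, 4, requires_grad=True)

# a single forward pass builds the graph used by both backward passes
output = model(data)

# first backward pass: gradient wrt. the input; create_graph=True keeps the
# forward graph alive and records this backward pass for higher-order gradients
input_grad, = torch.autograd.grad(
    output, data, torch.ones_like(output), create_graph=True
)

# second backward pass over the same forward graph: parameter gradients
loss = output.pow(2).mean()
loss.backward()
```

After this, `input_grad` is itself part of a graph (it depends on the layer's weight), and `model.weight.grad` holds the gradient of the loss. Without `create_graph=True` (or `retain_graph=True`) on the first call, the second backward pass through the same graph would be expected to raise an error.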

I would recommend having a look at torch.Tensor.backward and the autograd mechanics of PyTorch if you would like to understand these mechanics a little better.

3. I did not understand the part with two forward passes. Do you mean that I would need to do something like this:

Almost! You need to do a backward pass once inside the context, but with the gradient modification disabled, and once as usual, and then compare the gradients.

import torch
from zennit.composites import EpsilonGammaBox

canonizers = [...] # specify canonizer

# any composites support second order gradients
composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers)

model.eval()

# with composite: hooks and canonizers are registered, gradient modification is disabled
with composite.context(model) as modified:
    # composite.inactive() is the composite-level counterpart of attributor.inactive()
    with composite.inactive():
        output = modified(input)
        loss = lossfn(output, target)
        optim.zero_grad()
        loss.backward()
        # clone, so that a later zero_grad() cannot alter the stored gradients
        with_comp = [param.grad.clone() for param in optim.param_groups[0]['params']]

# business as usual: plain model, no composite registered
output = model(input)
loss = lossfn(output, target)
optim.zero_grad()
loss.backward()
no_comp = [param.grad.clone() for param in optim.param_groups[0]['params']]

print(all(torch.allclose(a, b) for a, b in zip(with_comp, no_comp)))

4. For my particular use case I will need ...

Yes, that looks correct!

@chr5tphr chr5tphr added testing Improvements, additions, or issues with tests question Further information is requested labels Sep 4, 2023
@MikiFER
Author

MikiFER commented Sep 13, 2023

Hi @chr5tphr, thank you for the extensive reply. Sorry for the slow response, I was on vacation.

Could you elaborate a little more on this sentence: "no gradients are accumulated in the tensor's .grad, but instead we get as many new tensors as we supplied as inputs into torch.autograd.grad."? Isn't the gradient of a parameter always of the same shape, so wouldn't passing multiple inputs (batches) to torch.autograd.grad result in them accumulating in the same tensor rather than in saving multiple different .grad tensors?

I would like to try to integrate the code you supplied into a test. Would the best location for it be test_canonizers.py or test_core.py?

On a side note, I must inform you that I have a full-time job and only have free time to work on my research on Fridays, so my responses and coding may be slow.

@chr5tphr
Owner

Hey @MikiFER

Sorry for the even more belated response. I was exceptionally busy with my PhD thesis and was not able to find any time for Zennit.

Could you elaborate a little more on this sentence: "no gradients are accumulated in the tensor's .grad, but instead we get as many new tensors as we supplied as inputs into torch.autograd.grad."? Isn't the gradient of a parameter always of the same shape, so wouldn't passing multiple inputs (batches) to torch.autograd.grad result in them accumulating in the same tensor rather than in saving multiple different .grad tensors?

If any input tensors in your computational graph requires a gradient, they will be connected to any new output tensor in the resulting computational graph. Calling Tensor.backward on some output tensor will resolve the full computational graph, find all input tensors that required a gradient from which the output tensor (transitively) resulted, and store their respective gradients in the .grad instance attribute of the tensor. The behavior of torch.autograd.grad is similar, but with two key differences: (1) instead of computing the gradient wrt. all input tensors that required a gradient, only the gradient wrt. the tensors that you supplied as a second argument will be computed, and (2) instead of accumulating these gradients in the respective tensor's .grad instance attribute, they will be returned as a tuple in the same order as you supplied them to torch.autograd.grad, so they will never be a sum of multiple gradient computations.
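The two differences described above can be observed on a toy example. This is a minimal sketch assuming PyTorch is installed; the tensors `w` and `x` are made up for illustration.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
x = torch.tensor([3.0, 4.0], requires_grad=True)

# Tensor.backward: gradients are accumulated (added) into .grad of every
# leaf tensor that requires a gradient
(w * x).sum().backward()
(w * x).sum().backward()
accumulated = w.grad.clone()  # sum of two identical gradients

# reset the accumulated gradients
w.grad = None
x.grad = None

# torch.autograd.grad: gradients only for the requested tensors, returned
# as a tuple in the order they were supplied; .grad is left untouched
grad_w, = torch.autograd.grad((w * x).sum(), (w,))
grad_w2, grad_x2 = torch.autograd.grad((w * x).sum(), (w, x))
```

Each call to `torch.autograd.grad` returns fresh tensors, so repeated calls never add up, while the two `backward` calls leave twice the gradient in `w.grad`.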

I would like to try to integrate the code you supplied into a test. Would the best location for it be test_canonizers.py or test_core.py?

Conceptually, I think the best location without canonizers would be test_core.py, and with canonizers in test_canonizers.py. To not be too redundant, you could also simply include the test without canonizers in test_canonizers.py.

On a side note, I must inform you that I have a full-time job and only have free time to work on my research on Fridays, so my responses and coding may be slow.

Thanks for letting me know and thanks a lot for your contribution! I have been rather busy myself, so no worries on that end at least from my side.
