Model inference outside of composite context #200
Hey @MikiFER, thanks a lot for writing the issue! This is intended behavior due to how the gradient modification is implemented. I will explain why this is the case by describing the inner workings.

Why is this intended behavior?
The foundation to modify the gradient is laid in the forward pass, which only happens if the modules have the forward hooks attached which introduce these modifications. If you do the forward pass before registering the hooks (i.e., outside the composite context), the modifications are not part of the recorded graph, and the backward pass will not produce the expected attribution.

Parameter Learning under Registered Rules

The current implementation does not overwrite the parameters of the batch-norm and conv layer in the canonization. Rather, it creates a new set of parameters, where the parameters of the conv layer are computed from the conv layer and the batch-norm layer. This creates a computational graph with the correct dependencies, so that meaningful gradients for the original parameters of the model can be provided when computing the gradient wrt. the feature attribution.
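To make the canonization idea above a bit more concrete, here is a minimal, self-contained sketch of the principle (hypothetical layer sizes; this is not Zennit's actual canonizer code): the merged conv parameters are computed from the original conv and batch-norm parameters, so autograd keeps the dependency on the originals.

```python
import torch
from torch import nn

# hypothetical illustration: express "canonized" conv parameters as a
# differentiable function of the original conv and batch-norm parameters
conv = nn.Conv2d(3, 8, 3, bias=True)
bn = nn.BatchNorm2d(8).eval()

scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
# merged weight/bias are *computed from* the original parameters, so autograd
# keeps the dependency on conv.weight, conv.bias, bn.weight and bn.bias
merged_weight = conv.weight * scale[:, None, None, None]
merged_bias = (conv.bias - bn.running_mean) * scale + bn.bias

x = torch.randn(1, 3, 16, 16)
out = torch.nn.functional.conv2d(
    x, merged_weight, merged_bias, stride=conv.stride, padding=conv.padding
)

# gradients flow back to the *original* parameters through the merged ones
out.sum().backward()
print(conv.weight.grad.shape, bn.weight.grad.shape)
```

Because merged_weight and merged_bias are derived tensors rather than new leaf parameters, a backward pass through them deposits gradients in the original conv and batch-norm parameters.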
However, since the gradient is overwritten inside the context, this means that the computed gradients are actually not gradients, but feature attributions. Even worse, in the current implementation, the modified gradients of parameters inside the composite context are not feature attributions, but only the partial derivative of their respective layer multiplied with the output relevance.

This means that if you would like to compute feature attributions during training, you need to call the forward pass two times: once to compute the gradients wrt. the parameters with the rules inactive, and once to compute the attributions with the rules active. However, there are two ways of deactivating the rules, and due to its implementation, the gradients wrt. parameters should, as far as I am thinking right now, be correctly computed even if the Canonizer is attached without the rules active.

```python
from zennit.attribution import Gradient
from zennit.composites import EpsilonGammaBox

# any composite supports second order gradients
composite = EpsilonGammaBox(low=-3., high=3.)

with Gradient(model, composite) as attributor:
    output, grad = attributor(input, torch.ones_like, create_graph=True)
    # temporarily disable all hooks registered by the attributor's composite
    with attributor.inactive():
        loss = lossfn(output, target)
        optim.zero_grad()
        loss.backward()
        optim.step()
```
Another way that should work if the canonizers work as intended:

```python
from zennit.attribution import Gradient
from zennit.composites import EpsilonGammaBox

# any composite supports second order gradients
composite = EpsilonGammaBox(low=-3., high=3.)

with Gradient(model, composite) as attributor:
    output, grad = attributor(input, torch.ones_like, create_graph=True)

# leaving the context should also disable all hooks registered by the attributor's
# composite, although the saved forward graph will still include the modifications
# introduced by the canonizer
loss = lossfn(output, target)
optim.zero_grad()
loss.backward()
optim.step()
```
Opportunity for Contribution

If you have the time, would you be willing to try this out by comparing the parameter-gradients of a model inside a composite context with the parameter-gradients of the unmodified model?

Hope this answers your questions. If some points were unclear, or if you have further questions, I would be happy to answer!

Best,
Hi @chr5tphr, thank you for the detailed explanation. I would like to give it a try to get myself a little more familiar with the library, since I want to use it to test something out. I'm afraid I will need some more explanation, though. I've read your response multiple times and I am still not 100% sure that I understand how it works, so here are a couple more questions:
```python
output = model(input)
grad, = torch.autograd.grad(output, input, torch.ones_like(output), create_graph=True)
```

That means you have called the grad function and computed the gradient of the parameters (in this case not a gradient but an attribution), and then you call the backward pass?
```python
with Gradient(model, composite) as attributor:
    output, grad = attributor(input, torch.ones_like)  # first forward pass to obtain attribution and consequential backward pass

optim.zero_grad()             # remove the "gradients" calculated during the attribution backward pass
model_output = model(input)   # second forward pass
loss = lossfn(output, target) # calculate loss
loss.backward()               # calculate the gradient
optim.step()                  # apply optimization step
optim.zero_grad()             # remove the calculated gradient (which I will not do in the test, so I can see their values)
```

Then you would like me to test whether the gradients of the parameters obtained at the end of this code block are the same as the gradients of the parameters obtained when using one of the two code blocks you sent in your response?
And is something like this also a correct way to use the attributions as part of the loss?

```python
with composite.context(model) as modified_model:
    output = modified_model(input)
    # get input and output only for some examples in the batch
    filtered_output, filtered_input = filter_output_and_input(output, input)
    attribution_1, _ = torch.autograd.grad(filtered_output, filtered_input, starting_relevance_1, create_graph=True)
    attribution_2, _ = torch.autograd.grad(filtered_output, filtered_input, starting_relevance_2, create_graph=True)
    loss = lossfn(output, attribution_1, attribution_2, target)
    optim.zero_grad()
    loss.backward()
    optim.step()
```
Hey @MikiFER ,
This is precisely what happens! In the modified backward pass, the gradient will act as the attribution. Thus, when rules are active, all gradient functions will receive the modified gradient, which is the attribution, instead of the real gradient. Note that if you choose not to modify the gradient of a layer, its true gradient function will be used, but it will still receive the modified gradient from the layer following it.
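As a rough, Zennit-independent illustration of this effect (Zennit attaches its modifications via hooks registered in the forward pass, as described above; the sketch below only uses a plain PyTorch backward hook): replacing the gradient at one layer changes what every preceding layer, and the input, receives as its "gradient".

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))

def replace_grad(module, grad_input, grad_output):
    # toy modification: pass back the absolute value of the incoming gradient;
    # everything *before* this layer now receives this modified quantity
    return tuple(g.abs() if g is not None else None for g in grad_input)

handle = model[2].register_full_backward_hook(replace_grad)

x = torch.randn(2, 4, requires_grad=True)
model(x).sum().backward()
print(x.grad)  # already reflects the modification applied at the last layer
handle.remove()
```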
Yes,
I would recommend having a look at torch.Tensor.backward and the autograd mechanics documentation of PyTorch if you would like to understand these mechanics a little better.
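For context, the relevant mechanic is that create_graph=True makes the computed gradient itself part of the graph, so it can enter a loss and be backpropagated again. A minimal sketch, independent of Zennit (all names are placeholders):

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
optim = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 4, requires_grad=True)
output = model(x)

# gradient wrt. the input, kept differentiable so it can be part of the loss
grad, = torch.autograd.grad(output, x, torch.ones_like(output), create_graph=True)

loss = output.mean() + grad.abs().mean()  # loss that depends on the gradient
optim.zero_grad()
loss.backward()   # second-order gradients reach the parameters
optim.step()
```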
Almost! You need to do a backward pass once inside the context, but with the gradient modification disabled, and once as usual, and then compare the gradients. Something along these lines:

```python
from zennit.attribution import Gradient
from zennit.composites import EpsilonGammaBox

canonizers = [...]  # specify canonizer
# any composite supports second order gradients
composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers)

model.eval()

# with composite, but with its hooks temporarily inactive
with Gradient(model, composite) as attributor:
    with attributor.inactive():
        output = model(input)
        loss = lossfn(output, target)
        optim.zero_grad()
        loss.backward()
        with_comp = [param.grad.clone() for param in optim.param_groups[0]['params']]

# business as usual, without the composite
output = model(input)
loss = lossfn(output, target)
optim.zero_grad()
loss.backward()
no_comp = [param.grad.clone() for param in optim.param_groups[0]['params']]

print(all(torch.allclose(a, b) for a, b in zip(with_comp, no_comp)))
```
Yes, that looks correct!
Hi @chr5tphr, thank you for the extensive reply. Sorry for the slow response, I was on vacation. Could you elaborate a little more on this sentence: "no gradients are accumulated in the tensor's .grad, but instead we get as many new tensors as we supplied as inputs int torch.autograd.grad."? Isn't the gradient of a parameter always of the same shape, so wouldn't passing multiple inputs (batches) to torch.autograd.grad result in them accumulating in the same tensor, rather than in saving multiple different .grad tensors? I would like to try to integrate the code you supplied into a test. Would the best location for it be test_canonizers.py or test_core.py? On a side note, I must inform you that I have a full-time job and only have free time to work on my research on Fridays, so my responses and coding may be slow.
Hey @MikiFER, sorry for the even more belated response; I was exceptionally busy with my PhD thesis and just was not able to find any time for Zennit.
If any input tensors in your computational graph require a gradient, they will be connected to any new output tensor in the resulting computational graph. Calling torch.autograd.grad then returns one new gradient tensor for each tensor passed as inputs, instead of accumulating anything into .grad.
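A small sketch of the difference in plain PyTorch (nothing Zennit-specific): backward() accumulates into each parameter's .grad, while torch.autograd.grad returns one fresh tensor per entry of inputs and leaves .grad untouched.

```python
import torch
from torch import nn

layer = nn.Linear(3, 1)
x = torch.randn(5, 3, requires_grad=True)

# backward(): gradients are accumulated in-place into .grad
layer(x).sum().backward()
print(layer.weight.grad)          # filled
layer(x).sum().backward()
print(layer.weight.grad)          # twice the value: accumulation

layer.weight.grad = None          # reset

# torch.autograd.grad(): returns new tensors, one per entry in `inputs`
grads = torch.autograd.grad(layer(x).sum(), [x, layer.weight])
print(layer.weight.grad)          # still None, nothing was accumulated
print([g.shape for g in grads])   # one gradient tensor per requested input
```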
Conceptually, I think the best location without canonizers would be test_core.py.
Thanks for letting me know and thanks a lot for your contribution! I have been rather busy myself, so no worries on that end at least from my side.
Hi @chr5tphr ,
I'm creating a new issue regarding the inference of a model.
I noticed that when I run inference on a model outside of (before) the composite context (which has the appropriate model canonizer), I do not obtain the same attribution as when the inference is done inside of the context. This has me concerned, because in order to properly learn the batch norm's parameters, inference should be done outside of the context: the context effectively turns the batch norm into an identity, so inference inside of the context would never update its parameter values. Is there something I'm not understanding correctly here?
Here is a code snippet I used to validate that the attribution is not the same. The same test was also conducted on a vgg16 model and yielded the same result.