Numerical stability in PyTorch affects the effectiveness of attacks? #63

Open
rowedenny opened this issue Feb 20, 2020 · 11 comments

rowedenny commented Feb 20, 2020

Hello,
Thanks for the amazing work on advertorch; since it is very similar to cleverhans, I have recommended it to a lot of my friends.

I am using this library to evaluate the robustness of my model, and I came across this issue in another repo:
https://github.com/kleincup/DEEPSEC/issues/3

Basically, it reveals a numerical instability in PyTorch. So I tested whether it may also occur in this library by comparing the output of the model with and without
logits = logits - torch.max(logits, dim=1, keepdim=True)[0]

For a network with four convolutional layers and two fully connected layers on MNIST, the classification accuracy under FGSM with eps=0.3 is 24.9%, but when I add the line above to fix the numerical stability, the accuracy under attack drops to 2.0%.

To the best of my knowledge, the latter seems more reasonable.

I understand this is not related to the implementation of advertorch itself, yet from what I see, none of the models in test_utils.py takes numerical stability into consideration. So if this issue may also occur in advertorch, may I propose fixing this in the tutorial?

gwding (Collaborator) commented Feb 20, 2020

@rowedenny Thanks for raising this issue! I think it is definitely a potential problem, but I haven't encountered it personally. I tried to reproduce the problem with the following script:

import os

import torch
import torch.nn as nn

from advertorch.attacks import LinfPGDAttack
from advertorch.utils import predict_from_logits
from advertorch_examples.utils import TRAINED_MODEL_PATH
from advertorch_examples.utils import get_mnist_test_loader


class LeNet5(nn.Module):

    def __init__(self, offset=0., minus_max=False):
        super(LeNet5, self).__init__()
        self.offset = offset
        self.minus_max = minus_max
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1, stride=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=1)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(2)
        self.linear1 = nn.Linear(7 * 7 * 64, 200)
        self.relu3 = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(200, 10)

    def forward(self, x):
        out = self.maxpool1(self.relu1(self.conv1(x)))
        out = self.maxpool2(self.relu2(self.conv2(out)))
        out = out.view(out.size(0), -1)
        out = self.relu3(self.linear1(out))
        out = self.linear2(out)
        out = out + self.offset  # simulate large-magnitude logits
        if self.minus_max:
            # stabilization: shift logits so the per-row max is zero
            out = out - torch.max(out, dim=1, keepdim=True)[0]
        return out


if __name__ == '__main__':
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    filename = "mnist_lenet5_clntrained.pt"

    model = LeNet5()
    model.load_state_dict(
        torch.load(os.path.join(TRAINED_MODEL_PATH, filename)))
    model.to(device)
    model.eval()


    # (offset, minus_max) settings to probe
    situations = [
        (0, False),
        (1000, False),
        (1000000, False),
        (100000000, False),
        (10000000000, False),
        (1000000000000, False),
        (0, True),
        (1000, True),
        (1000000, True),
        (100000000, True),
        (10000000000, True),
        (1000000000000, True),
    ]


    batch_size = 100
    loader = get_mnist_test_loader(batch_size=batch_size)
    for cln_data, true_label in loader:
        break
    cln_data, true_label = cln_data.to(device), true_label.to(device)


    for eps in [0.1, 0.2, 0.3]:
        eps_iter = eps / 4
        nb_iter = 100
        adversary = LinfPGDAttack(
            model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=eps,
            nb_iter=nb_iter, eps_iter=eps_iter,
            rand_init=False, clip_min=0.0, clip_max=1.0,
            targeted=False)

        for offset, minus_max in situations:
            # craft adversarial examples against the offset/stabilized model
            model.offset = offset
            model.minus_max = minus_max
            adv = adversary.perturb(cln_data, true_label)

            # evaluate the adversarial examples on the unmodified model
            model.offset = 0.
            model.minus_max = False
            advpred = predict_from_logits(model(adv))
            print(
                "eps: {}, offset: {}, minus_max: {}, acc: {:.2f}".format(
                    eps, offset, minus_max,
                    (advpred == true_label).float().mean().item()
                ))

and got the following output

eps: 0.1, offset: 0, minus_max: False, acc: 0.58
eps: 0.1, offset: 1000, minus_max: False, acc: 0.58
eps: 0.1, offset: 1000000, minus_max: False, acc: 0.59
eps: 0.1, offset: 100000000, minus_max: False, acc: 0.79
eps: 0.1, offset: 10000000000, minus_max: False, acc: 0.85
eps: 0.1, offset: 1000000000000, minus_max: False, acc: 0.85
eps: 0.1, offset: 0, minus_max: True, acc: 0.58
eps: 0.1, offset: 1000, minus_max: True, acc: 0.58
eps: 0.1, offset: 1000000, minus_max: True, acc: 0.58
eps: 0.1, offset: 100000000, minus_max: True, acc: 0.64
eps: 0.1, offset: 10000000000, minus_max: True, acc: 0.73
eps: 0.1, offset: 1000000000000, minus_max: True, acc: 0.73
eps: 0.2, offset: 0, minus_max: False, acc: 0.00
eps: 0.2, offset: 1000, minus_max: False, acc: 0.00
eps: 0.2, offset: 1000000, minus_max: False, acc: 0.00
eps: 0.2, offset: 100000000, minus_max: False, acc: 0.11
eps: 0.2, offset: 10000000000, minus_max: False, acc: 0.16
eps: 0.2, offset: 1000000000000, minus_max: False, acc: 0.16
eps: 0.2, offset: 0, minus_max: True, acc: 0.00
eps: 0.2, offset: 1000, minus_max: True, acc: 0.00
eps: 0.2, offset: 1000000, minus_max: True, acc: 0.00
eps: 0.2, offset: 100000000, minus_max: True, acc: 0.01
eps: 0.2, offset: 10000000000, minus_max: True, acc: 0.03
eps: 0.2, offset: 1000000000000, minus_max: True, acc: 0.03
eps: 0.3, offset: 0, minus_max: False, acc: 0.00
eps: 0.3, offset: 1000, minus_max: False, acc: 0.00
eps: 0.3, offset: 1000000, minus_max: False, acc: 0.00
eps: 0.3, offset: 100000000, minus_max: False, acc: 0.00
eps: 0.3, offset: 10000000000, minus_max: False, acc: 0.00
eps: 0.3, offset: 1000000000000, minus_max: False, acc: 0.00
eps: 0.3, offset: 0, minus_max: True, acc: 0.00
eps: 0.3, offset: 1000, minus_max: True, acc: 0.00
eps: 0.3, offset: 1000000, minus_max: True, acc: 0.00
eps: 0.3, offset: 100000000, minus_max: True, acc: 0.00
eps: 0.3, offset: 10000000000, minus_max: True, acc: 0.00
eps: 0.3, offset: 1000000000000, minus_max: True, acc: 0.00

It seems that the problem only exists in extreme cases.

So I'm curious: what logit values cause this problem in your case, and which attack did you use?

gwding (Collaborator) commented Feb 20, 2020

Another observation: if the offset is large enough, then even minus_max won't bring the accuracy back down to the original value. I suspect this is due to floating-point precision, but I haven't investigated thoroughly.
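
A minimal sketch of the suspected float32 effect: once the offset dwarfs the logits, their low-order bits are rounded away before the subtraction runs, so minus_max cannot recover them.

import torch

# float32 carries ~7 significant decimal digits, so adding a huge
# constant quantizes the logits before the max-subtraction happens
logits = torch.tensor([[2.3456, -1.2345, 0.5678]])
for offset in [0.0, 1e3, 1e8, 1e12]:
    shifted = logits + offset
    recovered = shifted - shifted.max(dim=1, keepdim=True)[0]
    print(f"offset={offset:g}: recovered={recovered.tolist()}")

At offset 1e8 and above, all three logits collapse onto nearby representable float32 values (the spacing between floats near 1e8 is 8), so the stabilized model is effectively predicting from quantized logits, which matches the residual accuracy gap above.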

rowedenny (Author) commented Feb 20, 2020

Thanks for your quick response. I used FGSM with eps=0.3.

I also followed the cleverhans tutorial and saw the same issue there. Although this issue was discovered by Carlini, it is not mentioned in cleverhans, so I am concerned about how to evaluate the robustness of my model reliably.

In addition, if it is necessary, I would like to provide my model to reproduce the case.

import torch
import torch.nn as nn


class ConvNet(nn.Module):
    def __init__(self, num_classes, return_intermediate=False):
        super(ConvNet, self).__init__()
        self.return_intermediate = return_intermediate

        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2))
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.layer3 = nn.Sequential(
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
        )

        self.layer4 = nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        self.layer5 = nn.Sequential(
            nn.Linear(256, num_classes)
        )

        self.layers = [self.layer1, self.layer2, self.layer3, self.layer4]

    def forward(self, x):
        intermediates = []

        output = self.layer1(x)
        intermediates.append(output)

        output = self.layer2(output)
        intermediates.append(output)
        output = output.view(-1, 128 * 4 * 4)

        output = self.layer3(output)
        intermediates.append(output)

        output = self.layer4(output)
        intermediates.append(output)

        logits = self.layer5(output)
        # stabilization: shift logits so the per-row max is zero
        logits = logits - torch.max(logits, dim=1, keepdim=True)[0]

        if self.return_intermediate:
            # the original code returned an undefined name `relus`;
            # collect the post-layer activations instead
            return logits, intermediates
        else:
            return logits

gwding (Collaborator) commented Feb 20, 2020

@rowedenny It would be nice if you could provide your model checkpoint; I'd be interested to see what makes the problem happen.

BTW, I also tried FGSM; it has properties similar to the PGD attacks on this specific model:

eps: 0.1, offset: 0, minus_max: False, acc: 0.79
eps: 0.1, offset: 1000, minus_max: False, acc: 0.79
eps: 0.1, offset: 1000000, minus_max: False, acc: 0.86
eps: 0.1, offset: 100000000, minus_max: False, acc: 0.88
eps: 0.1, offset: 10000000000, minus_max: False, acc: 0.90
eps: 0.1, offset: 1000000000000, minus_max: False, acc: 0.90
eps: 0.1, offset: 0, minus_max: True, acc: 0.79
eps: 0.1, offset: 1000, minus_max: True, acc: 0.79
eps: 0.1, offset: 1000000, minus_max: True, acc: 0.79
eps: 0.1, offset: 100000000, minus_max: True, acc: 0.82
eps: 0.1, offset: 10000000000, minus_max: True, acc: 0.84
eps: 0.1, offset: 1000000000000, minus_max: True, acc: 0.84
eps: 0.2, offset: 0, minus_max: False, acc: 0.17
eps: 0.2, offset: 1000, minus_max: False, acc: 0.30
eps: 0.2, offset: 1000000, minus_max: False, acc: 0.39
eps: 0.2, offset: 100000000, minus_max: False, acc: 0.63
eps: 0.2, offset: 10000000000, minus_max: False, acc: 0.48
eps: 0.2, offset: 1000000000000, minus_max: False, acc: 0.48
eps: 0.2, offset: 0, minus_max: True, acc: 0.18
eps: 0.2, offset: 1000, minus_max: True, acc: 0.18
eps: 0.2, offset: 1000000, minus_max: True, acc: 0.18
eps: 0.2, offset: 100000000, minus_max: True, acc: 0.22
eps: 0.2, offset: 10000000000, minus_max: True, acc: 0.29
eps: 0.2, offset: 1000000000000, minus_max: True, acc: 0.29
eps: 0.3, offset: 0, minus_max: False, acc: 0.05
eps: 0.3, offset: 1000, minus_max: False, acc: 0.07
eps: 0.3, offset: 1000000, minus_max: False, acc: 0.11
eps: 0.3, offset: 100000000, minus_max: False, acc: 0.28
eps: 0.3, offset: 10000000000, minus_max: False, acc: 0.08
eps: 0.3, offset: 1000000000000, minus_max: False, acc: 0.08
eps: 0.3, offset: 0, minus_max: True, acc: 0.00
eps: 0.3, offset: 1000, minus_max: True, acc: 0.00
eps: 0.3, offset: 1000000, minus_max: True, acc: 0.00
eps: 0.3, offset: 100000000, minus_max: True, acc: 0.01
eps: 0.3, offset: 10000000000, minus_max: True, acc: 0.02
eps: 0.3, offset: 1000000000000, minus_max: True, acc: 0.02

rowedenny (Author) commented Feb 20, 2020

Much appreciated. I was starting to doubt whether I am the only one who actually encounters this issue. I would honestly appreciate your help in investigating it and resolving my concern.

I have attached the model checkpoint and a jupyter notebook in the following google drive.
https://drive.google.com/drive/folders/1dFD7n0JotIDuZTl4MWS0M9XiLr8O3S24?usp=sharing

You only need to update MODEL_PATH to your local directory of the model checkpoint.

I can consistently reproduce this issue.
Chunk 3 is expected to output
[] Model Robustness Evaluation: Accuracy 27.370
and chunk 4 is expected to output
[] Model Robustness Evaluation: Accuracy 2.290

gwding (Collaborator) commented Feb 21, 2020

@rowedenny Confirming that I can reproduce the problem. In your model, the gap between the max and min logit values is larger than in mine; maybe you trained it for a very long time?

It seems that something like logits = logits / 10. would also work (see the sketch below).
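
A sketch of why scaling down could help, under the assumption that the failure mode is gradient underflow: when the gap between the top logit and the rest is large enough, exp(-gap) underflows in float32, the softmax becomes exactly one-hot, and the cross-entropy gradient with respect to the logits is exactly zero, leaving FGSM/PGD nothing to follow.

import torch
import torch.nn as nn

target = torch.tensor([0])
for gap in [20.0, 50.0, 120.0]:
    # the correct class leads by `gap`; exp(-120) underflows in float32
    logits = torch.tensor([[gap, 0.0, 0.0]], requires_grad=True)
    loss = nn.CrossEntropyLoss()(logits, target)
    grad = torch.autograd.grad(loss, logits)[0]
    print(f"gap={gap:g}: grad={grad.tolist()}")

Dividing the logits by 10 shrinks the gap before the loss and restores a nonzero gradient.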

It could be worth adding something like this to advertorch, with some design considerations.

rowedenny (Author) commented Feb 21, 2020

maybe you trained it for a very long time?

Yes. My MNIST model is trained for 100 epochs; if I remember correctly, the cleverhans tutorial trains for 6 epochs, which may explain why they did not see this issue. I guess it is a similar case for advertorch.

It seems that something like logits = logits / 10. would also work.

I am not sure how this may affect attacks like C&W: if we scale the logits but the attacker is not "aware" of the scale, couldn't the effectiveness of the attack be affected?

It's a great honor talking with you, and I am pleased that your confirmation shows I am not the only one who sees this.

gwding (Collaborator) commented Feb 22, 2020

@rowedenny Thanks for the kind words, and for bringing up this interesting observation! This also reminds me of another paper, Defensive Distillation is Not Robust to Adversarial Examples, where they actually use the logit-scaling method to attack defensive distillation, which I think has problems similar to your model's. The first author, the C in C&W, is also the person who raised the DEEPSEC issue you quoted at the beginning.

I'll dig a bit more into this and see if there's a way to add some functionality to advertorch for this.

rowedenny (Author) commented

I'll dig a bit more into this and see if there's a way to add some functionality to advertorch for this.

I think a quick fix is to shift the logits. Since subtracting the maximum makes the largest value zero, may I propose checking whether torch.max(logit, dim=1) equals 0, and if not, performing the subtraction?
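
A literal sketch of that check (the helper name is made up for illustration, not advertorch API):

import torch


def maybe_shift_logits(logits):
    # if the row-wise max is not already zero, subtract it so the
    # largest logit in each row becomes exactly zero
    row_max = logits.max(dim=1, keepdim=True)[0]
    if torch.all(row_max == 0):
        return logits
    return logits - row_max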

gwding (Collaborator) commented Feb 24, 2020

Not sure if this is the optimal solution, but it at least sounds like a reasonable choice to have. I would suggest adding a "loss function wrapper" (it could be a decorator) in advertorch.utils, so that it does this preprocessing before the logits go into a common loss function, say CrossEntropyLoss.
Feel free to start a PR, or I can also make one and include you as the reviewer. @rowedenny
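
A minimal sketch of what such a wrapper might look like; the name stabilize_logits is hypothetical and nothing below is existing advertorch API:

import torch.nn as nn


def stabilize_logits(loss_fn):
    # hypothetical wrapper: max-shift the logits before the wrapped loss.
    # The shift leaves cross-entropy mathematically unchanged but keeps
    # the float values small.
    def wrapped(logits, target):
        logits = logits - logits.max(dim=1, keepdim=True)[0]
        return loss_fn(logits, target)
    return wrapped


# it would then plug into an attack's loss_fn argument, e.g.:
loss_fn = stabilize_logits(nn.CrossEntropyLoss(reduction="sum"))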

rowedenny (Author) commented Feb 24, 2020

Hi @gwding, I think we can inherit the idea from cleverhans, which checks whether the prediction comes from logits or from probabilities after softmax. I assume it is a similar case here. Please refer to the following code:

https://github.com/tensorflow/cleverhans/blob/master/cleverhans/model.py#L228
And I would love to jointly take a look at this solution; a rough sketch of such a check follows below.

BTW, I dropped you an email and look forward to hearing from you.
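
A rough sketch in the spirit of that cleverhans check, assuming that outputs which are nonnegative and row-sum to one are post-softmax probabilities (this is not the actual cleverhans code):

import torch


def looks_like_probs(outputs, atol=1e-4):
    # heuristic: nonnegative rows that sum to ~1 suggest softmax output,
    # in which case an attack should recover or request the raw logits
    row_sums = outputs.sum(dim=1)
    nonneg = bool(torch.all(outputs >= 0))
    sums_to_one = bool(torch.allclose(
        row_sums, torch.ones_like(row_sums), atol=atol))
    return nonneg and sums_to_one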
