-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
66 lines (53 loc) · 1.68 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import argparse
from src.attack import attack
from src.mnist import mnist_loaders, mnist_model
from src.fgsm import fgsm
from src.pgd import pgd
def argparser(
        batch_size=50, path=None,
        attack=None, epsilon=0.1, n_iters=40, *, argv=None):
    """Build the command-line parser for this script and parse the arguments.

    Each keyword argument only supplies the default value of the flag with the
    same name, so callers can change defaults without touching the flags.

    Args:
        batch_size: default for ``--batch_size`` (passed to ``mnist_loaders``).
        path: default for ``--path`` (checkpoint file loaded with ``torch.load``).
        attack: default for ``--attack`` (name of the attack, e.g. "fgsm" or "pgd").
        epsilon: default for ``--epsilon`` (used by fgsm and pgd).
        n_iters: default for ``--n_iters`` (used by pgd).
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; pass a list to drive the
            parser programmatically, e.g. from tests.

    Returns:
        argparse.Namespace with attributes ``batch_size``, ``path``, ``attack``,
        ``epsilon`` and ``n_iters``.
    """
    parser = argparse.ArgumentParser()
    # model
    parser.add_argument("--batch_size", type=int, default=batch_size,
                        help="batch size of the test loader")
    parser.add_argument("--path", default=path,
                        help="path of the model checkpoint to evaluate")
    # adversarial attack
    parser.add_argument("--attack", default=attack,
                        help="attack to run: fgsm or pgd")
    # fgsm, pgd
    parser.add_argument("--epsilon", type=float, default=epsilon,
                        help="perturbation budget (fgsm, pgd)")
    # pgd
    parser.add_argument("--n_iters", type=int, default=n_iters,
                        help="number of iterations (pgd only)")
    return parser.parse_args(argv)
def mean(l):
    """Return the arithmetic mean of the items in ``l``.

    ``l`` must support ``len()`` and be summable with the builtin ``sum()``
    (which starts from 0). An empty ``l`` raises ZeroDivisionError.
    """
    total = sum(l)
    count = len(l)
    return total / count
def main():
    """Evaluate a trained MNIST model before and after an adversarial attack.

    Parses the command line with ``argparser()``, loads the checkpoint at
    ``--path`` into ``mnist_model()``, runs the selected attack (``fgsm`` or
    ``pgd``) over the MNIST test loader, and prints the mean of the two
    accuracy collections the attack returns (clean, then attacked).

    Raises:
        ValueError: if ``--path`` is not given, or ``--attack`` is neither
            ``fgsm`` nor ``pgd``.
    """
    args = argparser()
    if args.path is None:
        raise ValueError("NotFound Path")

    # Resolve the attack before loading the checkpoint and dataset, so a bad
    # --attack value fails fast instead of after the expensive setup.
    # NOTE(review): `attack` imported from src.attack is never called in this
    # file; the selected attack function is invoked directly. Confirm whether
    # the intent was to route through src.attack.attack instead.
    if args.attack == "fgsm":
        attack_fn = fgsm
        kwargs = {"epsilon": args.epsilon}
    elif args.attack == "pgd":
        attack_fn = pgd
        kwargs = {"epsilon": args.epsilon, "n_iters": args.n_iters}
    else:
        raise ValueError("Unknown attack")

    model = mnist_model()
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # machine. load_state_dict copies the values into the model's existing
    # parameters, so the model keeps whatever device mnist_model() gave it.
    checkpoint = torch.load(args.path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    # Evaluation only: disable training-mode behaviour (dropout, batch-norm
    # statistic updates) in case mnist_model() contains such layers.
    model.eval()

    _, test_loader = mnist_loaders(args.batch_size)

    # The attack returns two collections (presumably per-batch accuracies,
    # clean and attacked — confirm against src.fgsm / src.pgd); report means.
    total_acc, total_acc_atk = attack_fn(test_loader, model, **kwargs)
    print(
        "Before Accuracy: {acc:.4f}, After Accuracy: {acc_atk:.4f}".format(
            acc=mean(total_acc), acc_atk=mean(total_acc_atk)
        )
    )


if __name__ == "__main__":
    main()