/
utils_eval.py
47 lines (35 loc) · 1.26 KB
/
utils_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""
Utilities for PGD plus evaluation.
Based on code from https://github.com/yaircarmon/semisup-adv
"""
import os
import numpy as np
from models.wrn_madry import Wide_ResNet_Madry
from models.resnet import *
from models.small_cnn import SmallCNN
import torch
from torch.nn import Sequential, Module
def get_model(name, num_classes=10, normalize_input=False):
    """Build a model from a dash-separated architecture name.

    The name is split on ``'-'`` and dispatched on its first field:

    * ``'wrn-<depth>-<widen>'`` -> ``Wide_ResNet_Madry(depth=<depth>,
      num_classes=num_classes, widen_factor=<widen>)``; depth and widen must
      be integers. Any fields after the third are ignored.
    * ``'small'`` -> ``SmallCNN()``.
    * ``'resnet'`` -> ``ResNet18()``.

    Args:
        name: Architecture name, e.g. ``'wrn-28-10'``, ``'small'``, ``'resnet'``.
        num_classes: Number of output classes. Only forwarded to the WRN
            constructor.
        normalize_input: If True, the model is wrapped as
            ``Sequential(NormalizeInput(), model)`` so inputs are standardised
            before the network sees them.

    Returns:
        The constructed model.

    Raises:
        ValueError: If the architecture is unknown, or a ``wrn`` name is
            missing its depth/widen fields or they are not integers.
    """
    name_parts = name.split('-')
    arch = name_parts[0]
    if arch == 'wrn':
        # Guard explicitly so a truncated name like 'wrn' or 'wrn-28' is
        # reported as a parse failure (ValueError) rather than IndexError.
        if len(name_parts) < 3:
            raise ValueError('Could not parse model name %s' % name)
        try:
            depth = int(name_parts[1])
            widen = int(name_parts[2])
        except ValueError as err:
            raise ValueError('Could not parse model name %s' % name) from err
        model = Wide_ResNet_Madry(
            depth=depth, num_classes=num_classes, widen_factor=widen)
    elif arch == 'small':
        # NOTE(review): num_classes is not passed here; confirm SmallCNN's
        # default output size matches what callers request.
        model = SmallCNN()
    elif arch == 'resnet':
        # NOTE(review): num_classes is not passed here either; confirm
        # ResNet18's default output size matches what callers request.
        model = ResNet18()
    else:
        raise ValueError('Could not parse model name %s' % name)
    if normalize_input:
        model = Sequential(NormalizeInput(), model)
    return model
class NormalizeInput(Module):
    """Standardise inputs per channel: ``(x - mean) / std``.

    ``mean`` and ``std`` are converted with ``torch.Tensor`` and reshaped to
    ``(1, C, 1, 1)``, where ``C`` is the number of values supplied. They are
    registered as module buffers named ``'mean'`` and ``'std'``, so they are
    part of ``state_dict`` and follow the module through ``.to()``.

    NOTE(review): the ``(1, C, 1, 1)`` shape broadcasts over dimension 1 of a
    4-D input, which presumes channel-first (N, C, H, W) batches; confirm
    against callers.
    """

    def __init__(self, mean=(0.4914, 0.4822, 0.4465),
                 std=(0.2023, 0.1994, 0.2010)):
        super().__init__()
        # Register 'mean' first, then 'std', each as a (1, C, 1, 1) buffer.
        for buffer_name, values in (('mean', mean), ('std', std)):
            self.register_buffer(
                buffer_name, torch.Tensor(values).reshape(1, -1, 1, 1))

    def forward(self, x):
        """Return ``x`` shifted by ``mean`` and divided by ``std``."""
        centred = x - self.mean
        return centred / self.std