-
Notifications
You must be signed in to change notification settings - Fork 1
/
anomaly_detection.py
141 lines (114 loc) · 5.16 KB
/
anomaly_detection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import sys
import torch as ch
import numpy as np
from PIL import Image
import torch.nn as nn
from datetime import datetime
import torch.nn.functional as F
from parsing import anomaly_parser
import torchvision.models as models
from utils.ad_utils import one_vs_all
from utils.data_utils import gen_transform
from utils.core_utils import( fix_seeds, stdout_logger, make_single_class_loader,
InputScaling, InputDescaling, InputDenormalize)
from models.alexnet import( AlexNet, AlexNetGen, AlexNetEncDec,
AlexNetEncDec_config, AlexNetComp, AlexNetComp_config)
# --- Experiment setup ---------------------------------------------------------
# Parse command-line arguments (parser built by parsing.anomaly_parser).
parser = anomaly_parser()
args = parser.parse_args()

# Derived arguments: a timestamped experiment name and its output folder
# (<out_dir>/<exp_name>/output).
# NOTE(review): exp_name has one-second resolution, so two runs started in the
# same second end up sharing one output folder.
args.exp_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
args.out_folder = os.path.join(args.out_dir, args.exp_name, "output")
# Create the output folder. exist_ok=True replaces the former
# "check exists, then create" sequence, which could race with another process
# creating the same folder in between and raise FileExistsError.
os.makedirs(args.out_folder, exist_ok=True)

# Fix the random seed (see utils.core_utils.fix_seeds).
# NOTE(review): a seed of 0 is falsy and is therefore skipped here -- confirm
# whether that is intended or whether `args.seed is not None` was meant.
if args.seed:
    fix_seeds(args.seed)

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = ch.device("cuda" if ch.cuda.is_available() else "cpu")

# Optionally replace sys.stdout with stdout_logger, handing it the path
# <stdout_dir>/<exp_name>.txt.
if args.stdout_logger:
    args.stdout_str = os.path.join(args.stdout_dir, args.exp_name + ".txt")
    sys.stdout = stdout_logger(stdout_str=args.stdout_str)
# --- Dataset settings, transforms and data loader -----------------------------
# Only ImageNet is supported; reject anything else up front.
if args.data != "ImageNet":
    raise ValueError("Undefined dataset. Check 'data' input argument.")
# ImageNet: 1000 classes plus the per-channel mean/std that the model
# constructors below receive.
args.num_classes = 1000
args.mean = ch.tensor([0.485, 0.456, 0.406])
args.std = ch.tensor([0.229, 0.224, 0.225])

# Build the test-time transform. gen_transform also returns norm_flag, which
# later selects how the autoencoder (de)normalizes its inputs.
transform_test, norm_flag = gen_transform(
    mode=args.transform_test,
    init_dim=args.transform_init_dim,
)

# Only the one-vs-all protocol is implemented.
if not args.one_vs_all:
    raise ValueError(
        "Undefined anomaly detection strategy. Check 'one_vs_all' argument."
    )
# Single-class loader. The test transform is used for both transform slots,
# and the first returned loader (presumably the training one) is discarded.
_, in_test_loader = make_single_class_loader(
    dataset=args.in_class_dataset,
    data=args.in_data_path,
    samples=args.in_samples,
    workers=args.num_workers,
    batch_size=args.batch_size,
    transform_train=transform_test,
    transform_test=transform_test,
    random_samples=args.random_samples,
)
# --- Autoencoder ----------------------------------------------------------------
# Build the encoder/decoder model selected by --generator_arch (only "alexnet").
if args.generator_arch == "alexnet":
    model = AlexNetEncDec(
        num_classes=args.num_classes,
        mean=args.mean,
        std=args.std,
        output_layer=args.output_layer,
        upsample_mode=args.upsample_mode,
        spectral_init=args.spectral_init,
    )
    # Set up the classifier/generator layers and load their checkpoints.
    AlexNetEncDec_config(
        classifier=model.classifier,
        generator=model.generator,
        load_classifier=args.load_classifier,
        load_generator=args.load_generator,
        output_layer=args.output_layer,
    )
    if norm_flag:
        # Replace the model's mean/std normalization by plain scaling.
        model.normalize = InputScaling()
        model.denormalize = InputDescaling()
    else:
        # Keep normalization; denormalize with the dataset mean/std.
        model.denormalize = InputDenormalize(new_mean=args.mean,
                                             new_std=args.std)
else:
    # Message typo fixed ("autoencoer" -> "autoencoder").
    raise ValueError(
        "Undefined autoencoder model, check 'generator_arch' input argument."
    )
# --- Comparator -----------------------------------------------------------------
# Only an AlexNet comparator is available.
if args.comparator_arch != "alexnet":
    raise ValueError(
        "Undefined comparator model, check 'comparator_arch' input argument."
    )
comparator = AlexNetComp(
    num_classes=args.num_classes,
    mean=args.mean,
    std=args.std,
    output_layer=args.output_layer,
)
# Configure the comparator's classifier and load its checkpoint. strict=False
# is forwarded to AlexNetComp_config (presumably non-strict state-dict loading).
# NOTE(review): the comparator is constructed with args.output_layer but
# configured with args.comparator_layer -- confirm this split is intended.
AlexNetComp_config(
    comparator=comparator.classifier,
    load_comparator=args.load_comparator,
    output_layer=args.comparator_layer,
    strict=False,
)
# --- Evaluation -----------------------------------------------------------------
# Move both networks to the selected device, then switch to eval mode.
# Only the classifier and generator submodules of `model` are put in eval
# mode; any other submodules of `model` keep whatever mode they were in.
model.to(device)
comparator.to(device)
model.classifier.eval()
model.generator.eval()
comparator.eval()

# Criteria forwarded to one_vs_all: L1 for pixel_crit, MSE for feat_crit,
# both with the user-selected reduction.
pixel_crit = nn.L1Loss(reduction=args.reduction)
feat_crit = nn.MSELoss(reduction=args.reduction)

# One-vs-all anomaly detection run.
one_vs_all(
    # data and networks
    in_test_loader=in_test_loader,
    model=model,
    comparator=comparator,
    comparator_layer=args.comparator_layer,
    act_reference=args.act_reference,
    # criteria and their weights
    pixel_crit=pixel_crit,
    feat_crit=feat_crit,
    pixel_loss_weight=args.pixel_loss_weight,
    feature_loss_weight=args.feature_loss_weight,
    # optimization settings
    variable=args.variable,
    randn_init=args.randn_init,
    optimizer=args.optimizer,
    step_size=args.step_size,
    sched_step=args.sched_step,
    sched_gamma=args.sched_gamma,
    iterations=args.iterations,
    # output and run limits
    out_folder=args.out_folder,
    export_output=args.export_output,
    batch_limit=args.batch_limit,
)
print("Output directory: ", args.out_folder)