/
Utilities.py
144 lines (105 loc) · 4.32 KB
/
Utilities.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch
import pandas as pd
from skimage.transform import warp
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import pdb
def parseLogFile(logFile):
    """Parse an experiment log file into a dictionary of run results.

    Every line of the log is tested against a fixed list of labels, in the
    order written below; the first label found anywhere in the line decides
    how the line is interpreted. The label text is removed and the remainder
    is stored, converted to int / float / bool where the field calls for it.
    Fields whose label never appears in the log keep the value ``None``.

    Args:
        logFile: Path of the log file to read.

    Returns:
        dict with the keys "Model Architecture", "Poison Index", "K Value",
        "Class Balance", "Replicate Imbalance", "True Positive",
        "True Negative", "False Positive", "False Negative",
        "Matthews Correlation Coefficient", "Train Accuracy",
        "Test Accuracy", "Poison Success on Target Image" and "Status".

    Raises:
        ValueError: If a numeric field cannot be converted by int()/float().
        IndexError: If a "|" separated metrics line has fewer than four parts.
    """
    data = {"Model Architecture": None,
            "Poison Index": None,
            "K Value": None,
            "Class Balance": None,
            "Replicate Imbalance": None,
            "True Positive": None,
            "True Negative": None,
            "False Positive": None,
            "False Negative": None,
            "Matthews Correlation Coefficient": None,
            "Train Accuracy": None,
            "Test Accuracy": None,
            "Poison Success on Target Image": None,
            "Status": None}
    # The context manager guarantees the handle is closed, even when a
    # conversion below raises part-way through the file.
    with open(logFile) as log:
        for item in log:
            item = item.strip("\n")
            if "Model Architecture" in item:
                data["Model Architecture"] = item.replace("Model Architecture: ", "")
            elif "Poison Index" in item:
                data["Poison Index"] = int(item.replace("Poison Index: ", ""))
            elif "K Value" in item:
                data["K Value"] = int(item.replace("K Value: ", ""))
            elif "classBalance" in item:
                # The log label is "classBalance" although the result key is
                # "Class Balance".
                data["Class Balance"] = item.replace("classBalance: ", "")
            elif "Replicate Imbalance" in item:
                # Only the exact text "True" maps to True.
                data["Replicate Imbalance"] = item.replace("Replicate Imbalance: ", "") == "True"
            elif "|" in item:
                # Confusion-matrix counts share one line, separated by "|",
                # in the order TP | TN | FP | FN.
                metrics = item.split("|")
                data["True Positive"] = int(metrics[0].replace("True Positive: ", ""))
                data["True Negative"] = int(metrics[1].replace("True Negative: ", ""))
                data["False Positive"] = int(metrics[2].replace("False Positive: ", ""))
                data["False Negative"] = int(metrics[3].replace("False Negative: ", ""))
            elif "Train Accuracy" in item:
                data["Train Accuracy"] = float(item.replace("Train Accuracy: ", ""))
            elif "Test Accuracy" in item:
                data["Test Accuracy"] = float(item.replace("Test Accuracy: ", ""))
            elif "Poison Success on Target Image" in item:
                data["Poison Success on Target Image"] = (
                    item.replace("Poison Success on Target Image: ", "") == "True"
                )
            elif "Matthews Correlation Coefficient" in item:
                data["Matthews Correlation Coefficient"] = float(
                    item.replace("Matthews Correlation Coefficient: ", "")
                )
            elif "Status" in item:
                data["Status"] = item.replace("Status: ", "")
    return data
def writeLog(logFile, line):
    """Append one line of text to a log file.

    The file is opened in append mode (and created if it does not exist),
    ``line`` is written followed by a newline, and the file is closed again.

    Args:
        logFile: Path of the log file to append to.
        line: Text to write; the trailing newline is added here.
    """
    # ``with`` closes the handle even if the write itself raises.
    with open(logFile, "a") as log:
        log.write(line + "\n")
def clearLog(logFile):
    """Empty the log file at ``logFile``, creating it if it does not exist.

    Args:
        logFile: Path of the log file to truncate.
    """
    # Opening in "w" mode truncates the file; nothing needs to be written.
    with open(logFile, "w"):
        pass
def parsePoisonIndex(fileLocation):
    """Read a file holding one integer per line and return the integers.

    Blank (whitespace-only) lines, such as a trailing empty line at the end
    of the file, are skipped instead of aborting the parse.

    Args:
        fileLocation: Path of the index file.

    Returns:
        list[int]: The parsed values, in file order.

    Raises:
        ValueError: If a non-blank line is not a valid integer.
    """
    # ``with`` closes the file once the list is built; int() itself tolerates
    # the trailing newline and surrounding whitespace.
    with open(fileLocation) as indexFile:
        return [int(line) for line in indexFile if line.strip()]
def parseTargetIndex(fileLocation, poisonIndex):
    """Look up the image location recorded for a given index.

    Each non-blank line of the file is expected to hold exactly two
    whitespace-separated fields: an image location followed by an integer ID.
    The location from the first line whose ID equals ``poisonIndex`` is
    returned.

    Args:
        fileLocation: Path of the file that maps image locations to IDs.
        poisonIndex: Integer ID to search for.

    Returns:
        The matching image location as a string, or ``None`` when no line
        carries that ID.

    Raises:
        ValueError: If a non-blank line does not split into exactly two
            fields, or its second field is not an integer.
    """
    # ``with`` closes the file on every exit path, including the early return.
    with open(fileLocation) as targetFile:
        for line in targetFile:
            if not line.strip():
                # Skip blank lines (e.g. a trailing empty line).
                continue
            imgLocation, targetID = line.split()
            if int(targetID) == poisonIndex:
                return imgLocation
    return None
def testTarget(model, device, target, targetClass):
model.eval()
classification = model(target.unsqueeze(0).to(device))
index = torch.argmax(classification).item()
if index == targetClass:
return True
return False
def featureExtraction(model, device, testData):
    """Collect penultimate-layer feature vectors for every sample in a loader.

    The model is put in eval mode and, with gradients disabled, each batch is
    moved to ``device`` and passed through the network's ``penultimate``
    method. The per-sample outputs, IDs and file names are gathered into
    three flat lists that stay aligned by position.

    Args:
        model: Network exposing ``penultimate(img)``, either directly or via
            a wrapper that holds the network in ``.module`` (as
            ``torch.nn.DataParallel`` does).
        device: Device each image batch is moved to.
        testData: Iterable yielding ``(img, ID, fileName)`` batches.

    Returns:
        tuple: ``(allFeatureVector, allID, allFileNames)`` — three lists with
        one entry per sample, in iteration order.
    """
    model.eval()
    # Unwrap a ``.module`` container when present; otherwise use the model as is.
    network = getattr(model, "module", model)
    allFeatureVector = []
    allID = []
    allFileNames = []
    with torch.no_grad():
        for img, ID, fileName in tqdm(testData):
            img = img.to(device)
            featureVector = network.penultimate(img)
            # extend() appends in place; rebuilding the lists with ``+`` on
            # every batch would copy everything collected so far.
            allFeatureVector.extend(featureVector)
            allID.extend(ID)
            allFileNames.extend(fileName)
    return allFeatureVector, allID, allFileNames
def classificationAccuracy(model, device, testData):
    """Compute the fraction of samples the model classifies correctly.

    The model is put in eval mode and, with gradients disabled, each batch is
    moved to ``device``; the predicted class is the index of the largest
    value along dimension 1 of the model output.

    Args:
        model: Callable torch module producing per-class scores.
        device: Device the images and labels are moved to.
        testData: Iterable yielding ``(img, ID)`` batches, where ``ID`` holds
            the true class indices.

    Returns:
        float: ``correct / total`` over all batches, or ``0.0`` when
        ``testData`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for img, ID in tqdm(testData):
            img = img.to(device)
            ID = ID.to(device)
            predictedClassification = model(img)
            # torch.max(..., dim=1) returns (values, indices); keep the indices.
            predictedClass = torch.max(predictedClassification, dim=1)[1]
            correct = correct + torch.sum(predictedClass == ID)
            total = total + predictedClassification.shape[0]
    if total == 0:
        # No samples: ``correct`` is still the plain int 0 (no ``.item()``)
        # and the division would be by zero.
        return 0.0
    # int() accepts both a plain int and a one-element tensor.
    return int(correct) / float(total)