-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_greedy.py
62 lines (51 loc) · 1.73 KB
/
main_greedy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import math
import numpy as np
import pandas as pd
from knn_robustness.utils import initialize_params
from knn_robustness.utils import initialize_data
from knn_robustness.knn import GreedyAttack
from knn_robustness.knn import SubsolverFactory
# Run the greedy attack against test instances and write the per-instance
# results ('detail.csv') and aggregate statistics ('summary.csv') into the
# configured result directory.
params = initialize_params('greedy')

X_train, y_train, X_test, y_test = initialize_data(params)

attack = GreedyAttack(
    X_train=X_train,
    y_train=y_train,
    n_neighbors=params.getint('n_neighbors'),
    subsolver=SubsolverFactory().create(params.get('subsolver')),
    n_far=params.getint('n_far'),
    max_trials=params.getint('max_trials'),
    min_trials=params.getint('min_trials')
)

# Loop-invariant configuration: read once instead of on every iteration.
n_evaluate = params.getint('n_evaluate')
result_dir = params.get('result_dir')
detail_path = os.path.join(result_dir, 'detail.csv')
summary_path = os.path.join(result_dir, 'summary.csv')

count = 0
success_notes = []
perturbation_norms = []
# Stays None when no instance is ever evaluated; the summary below must
# cope with that instead of failing on an unbound name.
details = None

for instance, label in zip(X_test, y_test):

    # Only instances whose prediction matches the true label are attacked.
    if attack.predict_individual(instance) != label:
        continue

    perturbation = attack(instance)

    # A None result is recorded as a failed attack with an infinite norm;
    # otherwise the norm of the returned perturbation is recorded.
    if perturbation is None:
        success = False
        perturbation_norm = math.inf
    else:
        success = True
        perturbation_norm = np.linalg.norm(perturbation)

    success_notes.append(success)
    perturbation_norms.append(perturbation_norm)

    # The detail file is rewritten after every evaluated instance --
    # presumably so partial results survive an interrupted run.
    details = pd.DataFrame({
        'success': success_notes,
        'perturbation': perturbation_norms
    })
    details.to_csv(detail_path)

    count += 1
    print(f'{count:03d} {success} {perturbation_norm:.7f}')
    if count >= n_evaluate:
        break

if count:
    # Statistics are taken over successful attacks only, so the infinite
    # norms recorded for failures never enter the mean or median.
    successful_norms = details['perturbation'][details['success']]
    success_rate = details['success'].sum() / count
    mean_norm = successful_norms.mean()
    median_norm = successful_norms.median()
else:
    # Nothing was evaluated (no test instance passed the prediction check):
    # report undefined statistics rather than raising NameError or
    # ZeroDivisionError.
    success_rate = math.nan
    mean_norm = math.nan
    median_norm = math.nan

summary = pd.DataFrame({
    'num': [count],
    'success_rate': [success_rate],
    'mean': [mean_norm],
    'median': [median_norm]
})
summary.to_csv(summary_path)
print(summary)