-
Notifications
You must be signed in to change notification settings - Fork 35
/
utils.py
32 lines (24 loc) · 935 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
## utils.py -- utility functions
##
## Copyright (C) 2017, Dongyu Meng <zbshfmmm@gmail.com>.
##
## This program is licenced under the BSD 2-Clause licence,
## contained in the LICENCE file in this directory.
import pickle
import os
import numpy as np
def prepare_data(dataset, idx):
    """
    Extract the test examples selected by an index.

    dataset: Full, working dataset exposing ``test_data`` and
             ``test_labels``. Such as MNIST().
    idx: Index (or indices) of the test examples that we care about.
    return: X, targets, Y -- the selected inputs, the selected label rows,
            and the argmax of each selected label row (presumably the class
            id of a one-hot label; confirm against the dataset class).
    """
    X = dataset.test_data[idx]
    # Index the labels once and reuse the result: array ("fancy") indexing
    # copies, so doing it twice only duplicated work.
    targets = dataset.test_labels[idx]
    # A batch of label rows is reduced along axis 1, exactly as before. A
    # single integer index yields one 1-D row, which only has axis 0; the
    # hard-coded axis=1 used to raise for that case.
    class_axis = 1 if np.ndim(targets) > 1 else 0
    return X, targets, np.argmax(targets, axis=class_axis)
def save_obj(obj, name, directory='./attack_data/'):
    """
    Pickle an object to ``<directory>/<name>.pkl``.

    obj: Any picklable object.
    name: File name, with or without the trailing ".pkl".
    directory: Folder to write into; it is created if it does not exist.
    """
    # Accept the name with or without the extension, the same way load_obj
    # does, so saving and loading under one name hit the same file.
    if name.endswith(".pkl"):
        name = name[:-4]
    # An empty directory means the current working directory, which always
    # exists; os.makedirs('') would raise.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(os.path.join(directory, name + '.pkl'), 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
def load_obj(name, directory='./attack_data/'):
    """
    Load a pickled object from ``<directory>/<name>.pkl``.

    name: File name, with or without the trailing ".pkl".
    directory: Folder that holds the pickle file.
    return: The unpickled object.
    """
    # Accept the name with or without the extension.
    if name.endswith(".pkl"):
        name = name[:-4]
    with open(os.path.join(directory, name + '.pkl'), 'rb') as f:
        # SECURITY: unpickling can execute arbitrary code. Only load files
        # that come from a trusted source.
        return pickle.load(f)