-
Notifications
You must be signed in to change notification settings - Fork 3
/
settings.py
executable file
·81 lines (60 loc) · 1.52 KB
/
settings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Hyperparameter and training settings for all the datasets
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-data', nargs=1, type=str, default=['mnist'])
parser.add_argument('-mode', nargs=1, type=str, default=['test'])
parser.add_argument('-model_file', nargs=1, type=str, default=['saved_models/MNIST/MNIST.pth'])
parser.add_argument('-expl', nargs=1, type=bool, default=[False])
args = parser.parse_args()
data_name = args.data[0]
mode = args.mode[0]
model_file = args.model_file[0]
expl = args.expl[0]
# Root directory for the dataset files.
data_path = 'Data/'

# Coefficients keyed by loss-term name (presumably per-term loss weights —
# confirm against the training code). All are currently 1.
coefs = dict(
    crs_ent=1,
    recon=1,
    kl=1,
    ortho=1,
)
# Dataset-specific model and training hyperparameters. Exactly one branch
# applies, and every branch defines the same set of module-level names:
# img_size, latent, num_prototypes, num_classes, batch_size, lr,
# num_train_epochs.
if data_name == "mnist":
    img_size = 28
    latent = 256
    num_prototypes = 50
    num_classes = 10
    batch_size = 128
    lr = 1e-3
    num_train_epochs = 10
elif data_name == "fmnist":
    img_size = 28
    latent = 256
    num_prototypes = 100
    num_classes = 10
    batch_size = 128
    lr = 1e-3
    num_train_epochs = 10
elif data_name == "cifar10":
    img_size = 32
    latent = 512
    num_prototypes = 100
    num_classes = 10
    batch_size = 128
    lr = 1e-3
    num_train_epochs = 36
elif data_name == "svhn":
    img_size = 32
    latent = 512
    num_prototypes = 50
    num_classes = 10
    batch_size = 64
    lr = 1e-3
    num_train_epochs = 36
elif data_name == "quickdraw":
    img_size = 28
    latent = 512
    num_prototypes = 100
    num_classes = 10
    batch_size = 128
    lr = 1e-3
    num_train_epochs = 10
    # QuickDraw data is read from its own sub-directory of data_path.
    data_path = data_path + 'quickdraw/'
else:
    # Fail fast. Without this branch the names above stay undefined for an
    # unknown dataset, and the mistake only shows up later as a NameError.
    raise ValueError(
        "Unknown dataset %r; expected one of "
        "'mnist', 'fmnist', 'cifar10', 'svhn', 'quickdraw'" % (data_name,)
    )