/
io_utils.py
180 lines (165 loc) · 6.62 KB
/
io_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import numpy as np
import os
import glob
import argparse
import backbone
# Map from the --model CLI string to the corresponding backbone constructor.
model_dict = {
    "Conv4": backbone.Conv4,
    "Conv4S": backbone.Conv4S,
    "Conv6": backbone.Conv6,
    "ResNet10": backbone.ResNet10,
    "ResNet18": backbone.ResNet18,
    "ResNet34": backbone.ResNet34,
    "ResNet50": backbone.ResNet50,
    "ResNet101": backbone.ResNet101,
}
def parse_args(script):
    """Parse command-line options for the few-shot classification scripts.

    Args:
        script: one of "train", "save_features" or "test"; selects which
            script-specific options are registered on top of the shared ones.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        ValueError: if ``script`` is not a known script name.
    """
    p = argparse.ArgumentParser(description="few-shot script %s" % (script))

    # Options shared by every script.
    p.add_argument("--seed", default=0, type=int, help="Seed for Numpy and pyTorch. Default: 0 (None)")
    p.add_argument("--dataset", default="CUB", help="CUB/miniImagenet/cross/omniglot/cross_char")
    # ResNet50 and ResNet101 are not used in the paper.
    p.add_argument("--model", default="Conv4", help="model: Conv{4|6} / ResNet{10|18|34|50|101}")
    # relationnet_softmax replaces the L2 norm with softmax to expedite training;
    # maml_approx uses a first-order approximation of the gradient for efficiency.
    p.add_argument("--method", default="baseline", help="baseline/baseline++/protonet/matchingnet/relationnet{_softmax}/maml{_approx}")
    # baseline and baseline++ ignore --train_n_way, and use the three options
    # below only during finetuning.
    p.add_argument("--train_n_way", default=5, type=int, help="class num to classify for training")
    p.add_argument("--test_n_way", default=5, type=int, help="class num to classify for testing (validation) ")
    p.add_argument("--n_shot", default=5, type=int, help="number of labeled data in each class, same as n_support")
    # Still required for save_features.py and test.py to find the model path correctly.
    p.add_argument("--train_aug", action="store_true", help="perform data augmentation or not during training ")

    if script == "train":
        # Make --num_classes larger than the maximum label value in the base classes.
        p.add_argument("--num_classes", default=200, type=int, help="total number of classes in softmax, only used in baseline")
        p.add_argument("--save_freq", default=50, type=int, help="Save frequency")
        p.add_argument("--start_epoch", default=0, type=int, help="Starting epoch")
        # For meta-learning methods each epoch contains 100 episodes; the
        # default epoch count is dataset dependent — see train.py.
        p.add_argument("--stop_epoch", default=-1, type=int, help="Stopping epoch")
        p.add_argument("--resume", action="store_true", help="continue from previous trained model with largest epoch")
        # --warmup is never used in the paper.
        p.add_argument("--warmup", action="store_true", help="continue from baseline, neglected if resume is true")
    elif script == "save_features":
        # Defaults to novel, but base/val class accuracy can be tested as well.
        p.add_argument("--split", default="novel", help="base/val/novel")
        p.add_argument("--save_iter", default=-1, type=int, help="save feature from the model trained in x epoch, use the best model if x is -1")
    elif script == "test":
        # Defaults to novel, but base/val class accuracy can be tested as well.
        p.add_argument("--split", default="novel", help="base/val/novel")
        p.add_argument("--save_iter", default=-1, type=int, help="saved feature from the model trained in x epoch, use the best model if x is -1")
        p.add_argument("--adaptation", action="store_true", help="further adaptation in test time or not")
        p.add_argument("--repeat", default=5, type=int, help="Repeat the test N times with different seeds and take the mean. The seeds range is [seed, seed+repeat]")
    else:
        raise ValueError("Unknown script")

    return p.parse_args()
def parse_args_regression(script):
    """Parse command-line options for the few-shot regression scripts.

    Args:
        script: one of "train_regression" or "test_regression"; selects which
            script-specific options are registered on top of the shared ones.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        ValueError: if ``script`` is not a known script name.
    """
    parser = argparse.ArgumentParser(description="few-shot script %s" % (script))
    # Options shared by both regression scripts.
    parser.add_argument(
        "--seed",
        default=0,
        type=int,
        help="Seed for Numpy and pyTorch. Default: 0 (None)",
    )
    parser.add_argument("--model", default="Conv3", help="model: Conv{3} / MLP{2}")
    parser.add_argument("--method", default="gpnet", help="gpnet / transfer")
    parser.add_argument("--dataset", default="QMUL", help="QMUL / sines")
    parser.add_argument(
        "--spectral",
        action="store_true",
        help="Use a spectral covariance kernel function",
    )
    if script == "train_regression":
        parser.add_argument("--start_epoch", default=0, type=int, help="Starting epoch")
        parser.add_argument(
            "--stop_epoch", default=100, type=int, help="Stopping epoch"
        )  # for meta-learning methods, each epoch contains 100 episodes. The default epoch number is dataset dependent. See train.py
        parser.add_argument(
            "--resume",
            action="store_true",
            help="continue from previous trained model with largest epoch",
        )
    elif script == "test_regression":
        parser.add_argument(
            "--n_support",
            default=5,
            type=int,
            help="Number of points on trajectory to be given as support points",
        )
        parser.add_argument(
            "--n_test_epochs", default=10, type=int, help="How many test people?"
        )
    else:
        # Mirror parse_args(): fail loudly on an unknown script name instead of
        # silently parsing only the shared options.
        raise ValueError("Unknown script")
    return parser.parse_args()
def get_assigned_file(checkpoint_dir, num):
    """Return the path of the checkpoint saved at epoch ``num`` in ``checkpoint_dir``."""
    return os.path.join(checkpoint_dir, "%d.pth" % num)
def get_resume_file(checkpoint_dir):
    """Return the ``.pth`` checkpoint with the largest epoch number.

    Args:
        checkpoint_dir: directory containing ``<epoch>.pth`` checkpoints.

    Returns:
        Path of the highest-epoch checkpoint, or None if the directory holds
        no numbered checkpoint. ``best_model.pth`` is ignored — its basename
        is not an epoch number.
    """
    filelist = glob.glob(os.path.join(checkpoint_dir, "*.pth"))
    # Filter out best_model.pth BEFORE the emptiness check: previously a
    # directory containing only best_model.pth passed the early check and
    # then crashed np.max on an empty array.
    filelist = [x for x in filelist if os.path.basename(x) != "best_model.pth"]
    if len(filelist) == 0:
        return None
    epochs = np.array([int(os.path.splitext(os.path.basename(x))[0]) for x in filelist])
    max_epoch = np.max(epochs)
    resume_file = os.path.join(checkpoint_dir, "{:d}.pth".format(max_epoch))
    return resume_file
def get_best_file(checkpoint_dir):
    """Return ``best_model.pth`` from ``checkpoint_dir`` if present.

    Falls back to the latest numbered checkpoint (via get_resume_file)
    when no best_model.pth file exists.
    """
    candidate = os.path.join(checkpoint_dir, "best_model.pth")
    if not os.path.isfile(candidate):
        return get_resume_file(checkpoint_dir)
    return candidate