-
Notifications
You must be signed in to change notification settings - Fork 35
/
average_checkpoints_auto.py
executable file
·199 lines (147 loc) · 5.65 KB
/
average_checkpoints_auto.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
#!/usr/bin/env python
from __future__ import division
import onmt
import onmt.markdown
import torch
import argparse
import math
import numpy
import os, sys
from onmt.model_factory import build_model, build_language_model, build_classifier, optimize_model
from copy import deepcopy
from onmt.utils import checkpoint_paths, normalize_gradients
import glob
from onmt.constants import add_tokenidx
# Command-line interface.  NOTE: option strings use a single leading dash
# (project convention); argparse still accepts them.
parser = argparse.ArgumentParser(description='translate.py')
onmt.markdown.add_md_help_argument(parser)

parser.add_argument('-models', required=True,
                    help='Path to model .pt file')
parser.add_argument('-type', default='seq2seq',
                    help="""Type of models""")
parser.add_argument('-lm', action='store_true',
                    help='Language model (default is seq2seq model)')
parser.add_argument('-sort_by_date', action='store_true',
                    help='Sort the model files by date')
parser.add_argument('-output', default='model.averaged',
                    help="""Path to output averaged model""")
parser.add_argument('-gpu', type=int, default=-1,
                    help="Device to run on")
# BUG FIX: help text was a copy-paste of -gpu's ("Device to run on").
parser.add_argument('-top', type=int, default=10,
                    help="Number of top checkpoints to average")
parser.add_argument('-method', default='mean',
                    help="method to average: mean|gmean")
def custom_build_model(opt, dict, lm=False, type='seq2seq', constants=None):
    """Build (and optimize) a model of the requested kind.

    Args:
        opt: model options taken from a checkpoint's ``'opt'`` entry.
        dict: vocabulary dictionaries (checkpoint ``'dicts'`` entry).
            NOTE(review): the name shadows the builtin ``dict``; kept
            unchanged because it is part of the caller-visible signature.
        lm: if True, build a language model instead of a seq2seq model
            (only meaningful when ``type == 'seq2seq'``).
        type: model kind, ``'seq2seq'`` or ``'classifier'``.  Shadows the
            builtin ``type``; kept for backward compatibility.
        constants: token-index constants forwarded to ``build_model``.

    Returns:
        The constructed model, after ``optimize_model`` has been applied.

    Raises:
        NotImplementedError: for an unrecognized ``type``.  (Previously an
        unknown type fell through and crashed with ``UnboundLocalError``.)
    """
    if type == 'seq2seq':
        if not lm:
            model = build_model(opt, dict, False, constants)
        else:
            model = build_language_model(opt, dict)
    elif type == 'classifier':
        model = build_classifier(opt, dict)
    else:
        raise NotImplementedError("Unknown model type: %s" % type)

    optimize_model(model)

    return model
def main():
    """Average the parameters of the top checkpoints into one model.

    Collects checkpoint files from ``-models`` (sorted by the project's
    checkpoint naming scheme, or by mtime with ``-sort_by_date``), keeps
    the first ``-top`` of them, builds the architecture from the first
    checkpoint, then folds the remaining checkpoints in with either an
    arithmetic mean (``-method mean``) or a geometric mean (``gmean``),
    and saves the averaged model to ``-output``.
    """
    opt = parser.parse_args()

    opt.cuda = opt.gpu > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    path = opt.models

    if not opt.sort_by_date:
        existed_save_files = checkpoint_paths(path)
    else:
        existed_save_files = glob.glob(path + "/" + "*.pt")
        existed_save_files.sort(key=os.path.getmtime)
        print("\n".join(existed_save_files))

    # keep only the top checkpoints
    models = existed_save_files[:opt.top]

    n_models = len(models)

    # checkpoint for the main (first) model; map_location keeps tensors on CPU
    checkpoint = torch.load(models[0], map_location=lambda storage, loc: storage)

    # drop the optimizer state to save memory
    if 'optim' in checkpoint:
        del checkpoint['optim']

    main_checkpoint = checkpoint
    best_checkpoint = main_checkpoint

    model_opt = checkpoint['opt']
    dicts = checkpoint['dicts']

    # register token indices (derived from the first checkpoint's dicts)
    onmt.constants = add_tokenidx(model_opt, onmt.constants, dicts)
    constants = onmt.constants

    # only create the object; do not preload encoder/decoder states
    model_opt.enc_state_dict = None
    model_opt.dec_state_dict = None

    print(model_opt.layers)
    main_model = custom_build_model(model_opt, checkpoint['dicts'],
                                    lm=opt.lm, type=opt.type, constants=constants)

    print("Loading main model from %s ..." % models[0])
    try:
        main_model.load_state_dict(checkpoint['model'])
    except RuntimeError:
        # BUG FIX: the fallback used strict=True, which is the default and
        # therefore raised the exact same error again.  A non-strict load
        # lets checkpoints with extra/missing keys still be averaged.
        main_model.load_state_dict(checkpoint['model'], strict=False)

    if opt.cuda:
        main_model = main_model.cuda()

    for i in range(1, len(models)):
        model = models[i]

        checkpoint = torch.load(model, map_location=lambda storage, loc: storage)

        model_opt = checkpoint['opt']
        model_opt.enc_state_dict = None
        model_opt.dec_state_dict = None

        # delete optim information to save GPU memory
        if 'optim' in checkpoint:
            del checkpoint['optim']

        # pass the same constants as for the first model so all copies are
        # built identically (the original omitted this argument here)
        current_model = custom_build_model(model_opt, checkpoint['dicts'],
                                           lm=opt.lm, type=opt.type, constants=constants)
        current_model.eval()

        print("Loading model from %s ..." % models[i])
        try:
            current_model.load_state_dict(checkpoint['model'])
        except RuntimeError:
            # see above: non-strict fallback instead of a guaranteed re-raise
            current_model.load_state_dict(checkpoint['model'], strict=False)

        if opt.cuda:
            current_model = current_model.cuda()

        if opt.method == 'mean':
            # Sum the parameter values (divided by n_models below)
            for (main_param, param) in zip(main_model.parameters(), current_model.parameters()):
                main_param.data.add_(param.data)
        elif opt.method == 'gmean':
            # Accumulate the product (n-th root taken below)
            for (main_param, param) in zip(main_model.parameters(), current_model.parameters()):
                main_param.data.mul_(param.data)
        else:
            raise NotImplementedError

    # Normalizing
    if opt.method == 'mean':
        for main_param in main_model.parameters():
            main_param.data.div_(n_models)
    elif opt.method == 'gmean':
        for main_param in main_model.parameters():
            main_param.data.pow_(1. / n_models)

    # Saving: mirror the usual checkpoint layout with neutral epoch/iteration
    model_state_dict = main_model.state_dict()

    save_checkpoint = {
        'model': model_state_dict,
        'dicts': dicts,
        'opt': model_opt,
        'epoch': -1,
        'iteration': -1,
        'batchOrder': None,
        'optim': None
    }

    print("Saving averaged model to %s" % opt.output)

    torch.save(save_checkpoint, opt.output)


if __name__ == "__main__":
    main()