This repository has been archived by the owner on Aug 31, 2021. It is now read-only.
/
assemble_models.py
68 lines (57 loc) · 2.31 KB
/
assemble_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
"""This script is useful to assemble F2Fi models into a F2F model.
NB: assumes multiscale architecture_f2fi and parallel architecture,
etc. see modelsConfig below"""
import os
import torch
import sys
from mypython.logger import create_logger
from torch.nn.parameter import Parameter
from mytorch.implementation_utils import get_nb_parameters
#-------------------------------------------------------------------------------
from config import make_config
opt, configs, checkpoint = make_config()
# This script only knows how to stitch per-scale F2Fi models into the
# parallel, multiscale F2F architecture — refuse anything else up front.
assert opt['architecture'] == 'parallel'
assert opt['architecture_f2fi'] == 'multiscale'
#-------------------------------------------------------------------------------
# Start logging
logger = create_logger(os.path.join(opt['logs'], 'assemble_models.log'))
logger.info('============ Initialized logger ============')
# Pretrained per-scale checkpoint paths, one per feature-pyramid level,
# supplied through environment variables.
f2f_paths = {
    'f2f5': os.environ['F2F5_WEIGHTS'],
    'f2f4': os.environ['F2F4_WEIGHTS'],
    'f2f3': os.environ['F2F3_WEIGHTS'],
    'f2f2': os.environ['F2F2_WEIGHTS']
}
output_name = 'F2F_model.net'
#-------------------------------------------------------------------------------
# Create the assembled model with freshly-initialized weights; the per-scale
# weights are loaded into its submodules below.
modelsConfig = configs['models']
from models import F2F_multiscale as F2F
model = F2F(modelsConfig)
logger.info('This model has %d parameters.' % (get_nb_parameters(model)))
# Map each checkpoint key to the submodule it must be loaded into.
submodels = {
    'f2f5': model.f2f5,
    'f2f4': model.f2f4,
    'f2f3': model.f2f3,
    'f2f2': model.f2f2
}
for lvl, submodel in submodels.items():
    logger.info('%s has %d parameters.' % (lvl.upper(), get_nb_parameters(submodel)))
# Load each per-scale checkpoint into the matching submodule.
for lvl, path in f2f_paths.items():
    params = torch.load(path)
    # Checkpoint keys carry one extra leading component (presumably a
    # wrapper such as nn.DataParallel's 'module.' prefix — TODO confirm
    # against how the F2Fi models were saved); strip it so the keys match
    # the bare submodule's state dict.
    params = {'.'.join(k.split('.')[1:]): v for k, v in params.items()}
    submodels[lvl].load_state_dict(params)
#-------------------------------------------------------------------------------
# Save the assembled model's full state dict.
output_path = os.path.join(opt['save'], output_name)
torch.save(model.state_dict(), output_path)
logger.info('Model saved in %s' % output_path)