-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
64 lines (52 loc) · 1.75 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import yaml
import torch
import logging
import os
import argparse
import shutil
import model_architecture
import train
import predict
# Parse command-line arguments.
# NOTE: this runs at import time; main() reads the module-level `args`.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--path_yaml',
    default='params.yaml',
    help='Path to the YAML parameter file (default: params.yaml).',
)
args = parser.parse_args()

# Set up logging. Without an explicit configuration the root logger stays at
# WARNING, so the logger.info(...) calls in this file would be dropped.
# basicConfig() does nothing if the root logger already has handlers.
# TODO: save log in output folder
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(name)s %(levelname)s: %(message)s',
)
logger = logging.getLogger('main')
def main():
    """Run training or inference as configured by the YAML parameter file.

    Reads the parameter file named by the module-level ``args.path_yaml``,
    creates the output folder ``params['path_folder_out']`` and copies the
    parameter file into it as ``params.yaml``, builds the model on the
    selected device (loading weights when ``params['model']['path_weights']``
    is set), and then calls ``train.train`` or ``predict.predict`` depending
    on ``params['mode']``. The value returned by those calls is discarded.

    Raises:
        FileNotFoundError: if ``params['model']['path_weights']`` is set but
            does not point to an existing file.
        ValueError: if ``params['mode']`` is neither ``'train'`` nor
            ``'predict'``.
    """
    # load params (PyYAML assumes UTF-8 when no BOM is present, so be explicit
    # instead of relying on the platform default encoding)
    with open(args.path_yaml, 'r', encoding='utf-8') as stream:
        params = yaml.safe_load(stream)

    # create output folder and keep a copy of the parameters next to the output
    path_folder_out = params['path_folder_out']
    os.makedirs(path_folder_out, exist_ok=True)
    try:
        shutil.copyfile(args.path_yaml, os.path.join(path_folder_out, 'params.yaml'))
    except shutil.SameFileError:
        # the parameter file already is <output folder>/params.yaml
        pass

    # choose device for loading
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        print("WARNING !!! Using CPU ")
        device = torch.device("cpu")
    logger.info('Working on device=%s', device)

    # load model
    num_classes = 8
    model = model_architecture.MyUNet(num_classes, device, params).to(device)
    path_weights = params['model']['path_weights']
    if path_weights:
        # explicit raise instead of assert: asserts are removed under `python -O`
        if not os.path.isfile(path_weights):
            raise FileNotFoundError(
                "path_weights does not exist as a file: {}".format(path_weights)
            )
        # NOTE(review): torch.load unpickles the file; only load weights from
        # a trusted source.
        model.load_state_dict(torch.load(path_weights, map_location=device))

    # execute training or inference
    mode = params['mode']
    if mode == 'train':
        train.train(model, device, params)
    elif mode == 'predict':
        predict.predict(model, device, params)
    else:
        # fail loudly instead of reporting "Finished" without doing anything
        raise ValueError(
            "Unknown mode {!r}: expected 'train' or 'predict'".format(mode)
        )
    logger.info("=== Finished")
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()