train.py
import os
import builtins

import torch
import torch.distributed as dist
import torch.utils.data
import torch.utils.data.distributed

from train_dg import train_dg_seg_network
from train_dg_2d import train_dg_2d_seg_network


def train_worker(gpu, ngpus_per_node, config, args):
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:
        def print_pass(*args):
            pass
        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # dispatch to the dataset-specific training entry point
    if config.DATASET.NAME in ['optic']:
        train_dg_seg_network(gpu, ngpus_per_node, config, args)
    elif config.DATASET.NAME in ['rvs']:
        train_dg_2d_seg_network(gpu, ngpus_per_node, config, args)
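
A minimal launch sketch, assuming the usual torch.multiprocessing.spawn pattern for workers with this signature; the main() wrapper below and its args fields (world_size, gpu, multiprocessing_distributed, matching those read in train_worker) are assumptions for illustration, not part of train.py:

    # launch_sketch.py (hypothetical): spawn one train_worker process per local GPU
    import torch
    import torch.multiprocessing as mp

    from train import train_worker

    def main(config, args):
        ngpus_per_node = torch.cuda.device_count()
        if args.multiprocessing_distributed:
            # world_size becomes the total number of processes across all nodes
            args.world_size = ngpus_per_node * args.world_size
            # spawn calls train_worker(gpu, ngpus_per_node, config, args)
            # once for each gpu in [0, ngpus_per_node)
            mp.spawn(train_worker, nprocs=ngpus_per_node,
                     args=(ngpus_per_node, config, args))
        else:
            # single-process training on the requested GPU (or CPU if None)
            train_worker(args.gpu, ngpus_per_node, config, args)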