/
launch.py
141 lines (121 loc) · 4.21 KB
/
launch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#!/usr/bin/env python3
"""
Helpers for launching jobs locally or in flow.
"""
import argparse
import os
import sys
import pprint
import PIL
from collections import defaultdict
from tabulate import tabulate
from typing import Tuple
import torch
from knn.utils.file_io import PathManager
from knn.utils import logging
from knn.utils.distributed import get_rank, get_world_size
def collect_torch_env() -> str:
    """Return a textual description of the installed PyTorch build.

    Uses ``torch.__config__.show()`` when that submodule can be imported;
    otherwise falls back to ``torch.utils.collect_env.get_pretty_env_info()``.
    """
    try:
        import torch.__config__ as torch_config

        summary = torch_config.show()
    except ImportError:
        # Fallback for older PyTorch releases that lack torch.__config__.
        from torch.utils.collect_env import get_pretty_env_info as pretty_env

        summary = pretty_env()
    return summary
def get_env_module() -> Tuple[str, str]:
    """Return the ``ENV_MODULE`` environment variable as a (name, value) pair.

    Returns:
        A 2-tuple ``("ENV_MODULE", value)`` where ``value`` is the variable's
        current value, or the placeholder ``"<not set>"`` when it is absent
        from the environment.
    """
    var_name = "ENV_MODULE"
    return var_name, os.environ.get(var_name, "<not set>")
def collect_env_info() -> str:
    """Build a human-readable report of the runtime environment.

    The report is a table rendered with ``tabulate`` that lists, in order:
    the Python version, the ``ENV_MODULE`` environment variable, the PyTorch
    version and debug-build flag, whether CUDA is available, and (when it is)
    the ``CUDA_VISIBLE_DEVICES`` value plus the visible GPUs grouped by
    device name, then the Pillow version and, if importable, the OpenCV
    version. The output of ``collect_torch_env()`` is appended after the
    table.

    Returns:
        The rendered table, a newline, and the PyTorch build description.
    """
    data = []
    data.append(("Python", sys.version.replace("\n", "")))
    data.append(get_env_module())
    data.append(("PyTorch", torch.__version__))
    data.append(("PyTorch Debug Build", torch.version.debug))

    has_cuda = torch.cuda.is_available()
    data.append(("CUDA available", has_cuda))
    if has_cuda:
        # CUDA can be available without CUDA_VISIBLE_DEVICES being exported,
        # so look it up defensively instead of indexing os.environ (which
        # would raise KeyError and abort the whole report).
        data.append(
            ("CUDA ID", os.environ.get("CUDA_VISIBLE_DEVICES", "<not set>"))
        )
        # Group device indices by device name so identical GPUs share a row.
        devices = defaultdict(list)
        for k in range(torch.cuda.device_count()):
            devices[torch.cuda.get_device_name(k)].append(str(k))
        for name, devids in devices.items():
            data.append(("GPU " + ",".join(devids), name))

    data.append(("Pillow", PIL.__version__))

    try:
        import cv2
    except ImportError:
        # OpenCV is optional; simply omit its row when it is not installed.
        pass
    else:
        data.append(("cv2", cv2.__version__))

    env_str = tabulate(data) + "\n"
    env_str += collect_torch_env()
    return env_str
def default_argument_parser():
    """Create the command-line argument parser.

    Options defined:
      * ``--config-file FILE``: path to a config file (default ``""``).
      * ``--resume``: resuming is ON by default; because the option uses
        ``store_false``, passing the flag turns resuming OFF.
      * ``--eval-only``: boolean flag, off by default.
      * ``--pretrain``: boolean flag, off by default.
      * ``--train-type``: free-form string (default ``""``).
      * ``--shard_id`` / ``--num_shards``: integers (defaults 0 and 1).
      * ``opts``: every remaining command-line token, collected as a list.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="kNN-revisited")
    parser.add_argument(
        "--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument(
        "--resume",
        # NOTE: store_false + default=True means resume is always enabled
        # unless this flag is given. The parsing behaviour is kept as-is for
        # existing launch scripts; only the help text was corrected so that
        # it describes what the flag actually does.
        action="store_false",
        help="do NOT attempt to resume from the checkpoint directory "
             "(resuming is enabled by default; passing this flag disables it)",
        default=True,  # always true for jobs in cluster
    )
    parser.add_argument(
        "--eval-only", action="store_true", help="perform evaluation only")
    parser.add_argument(
        "--pretrain",
        action="store_true",
        help="perform pretraining instead of downstream finetuning tasks",
    )
    parser.add_argument(
        "--train-type", default="", help="training types")
    # The --num-gpus and --init_method options were commented out in this
    # parser and remain disabled.

    # Integer defaults are given as ints rather than strings, so the result
    # does not depend on argparse converting string defaults through `type`.
    parser.add_argument(
        "--shard_id",
        default=0,
        type=int,
    )
    parser.add_argument(
        "--num_shards",
        default=1,
        type=int,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser
def train_setup(args, cfg) -> None:
    """Prepare the output directory and logger, then log the job context.

    Steps, in order:
      1. Create ``cfg.OUTPUT_DIR`` through ``PathManager`` when it is
         non-empty.
      2. Set up the ``"nearest_neighbors"`` logger.
      3. Log the process rank and world size, the environment report, the
         command-line arguments, the raw contents of ``args.config_file``
         (when that attribute exists and is not ``""``), and the
         pretty-printed config.
      4. Unless ``args.eval_only`` exists and is truthy, set
         ``torch.backends.cudnn.benchmark`` from ``cfg.CUDNN_BENCHMARK``.

    Args:
        args: parsed command-line namespace; ``config_file`` and
            ``eval_only`` are only read if present.
        cfg: config object; ``OUTPUT_DIR``, ``NUM_GPUS`` and
            ``CUDNN_BENCHMARK`` are read from it.
    """
    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        PathManager.mkdirs(output_dir)

    world_size = get_world_size()
    logger = logging.setup_logging(
        cfg.NUM_GPUS, world_size, output_dir, name="nearest_neighbors")

    # Log basic information about environment, cmdline arguments, and config
    rank = get_rank()
    logger.info(
        f"Rank of current process: {rank}. World size: {world_size}")
    logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))
    if hasattr(args, "config_file") and args.config_file != "":
        # Read inside a context manager so the file handle is closed
        # promptly instead of being left for the garbage collector.
        with PathManager.open(args.config_file, "r") as config_fh:
            config_text = config_fh.read()
        logger.info(
            "Contents of args.config_file={}:\n{}".format(
                args.config_file,
                config_text,
            )
        )

    # Show the config
    logger.info("Training with config:")
    logger.info(pprint.pformat(cfg))

    # cudnn benchmark has large overhead.
    # It shouldn't be used considering the small size of typical val set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK