-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
68 lines (59 loc) · 1.52 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""training script."""
from argparse import ArgumentParser
from types import SimpleNamespace
import yaml
import wandb
from detect_to_track.models import DetectTrackModule
from detect_to_track.data.imagenet import setup_vid_datasets
from detect_to_track.trainer import DetectTrackTrainer
# The module docstring is the CLI description; ``prog`` (the first positional
# parameter of ArgumentParser) is left to default to the script name.
parser = ArgumentParser(description=__doc__)
parser.add_argument("-c", "--cfg", default="cfg/default.yaml", help="path to cfg file")
args = parser.parse_args()

# Parse the YAML config once and reuse the result for both wandb and the
# attribute-style namespace used below; the context manager closes the file.
# NOTE: yaml.FullLoader can build some Python-specific types (e.g. tuples), so
# only point --cfg at config files you trust.
with open(args.cfg, encoding="utf-8") as f:
    cfg_dict = yaml.load(f, Loader=yaml.FullLoader)

# An empty file loads as None and a top-level list/scalar cannot be splatted
# into SimpleNamespace; fail with a CLI error before starting a wandb run.
if not isinstance(cfg_dict, dict):
    parser.error(f"config file {args.cfg!r} must contain a YAML mapping")

# Record the raw config mapping with the wandb run.
wandb.init(config=cfg_dict)
# Expose config keys as attributes (cfg.BATCH_SIZE, cfg.K, ...).
cfg = SimpleNamespace(**cfg_dict)
# Build the detect-and-track network from the config. Every argument is
# positional, so the order below must stay in lockstep with the signature of
# DetectTrackModule (defined in detect_to_track.models).
model = DetectTrackModule(
    cfg.BACKBONE_ARCH, cfg.FIRST_TRAINABLE_STAGE,
    # Count of (area, aspect ratio) pairs -- presumably the number of anchors
    # per feature-map location; confirm against DetectTrackModule.
    len(cfg.ANCHOR_AREAS) * len(cfg.ANCHOR_ASPECT_RATIOS),
    cfg.N_CLASSES,
    # NOTE(review): cfg.K is passed twice, on either side of cfg.D_MAX.
    # Presumably two distinct parameters that share one config value --
    # verify against DetectTrackModule's signature.
    cfg.K, cfg.D_MAX, cfg.K,
)
# Create the three data managers returned by setup_vid_datasets (from
# detect_to_track.data.imagenet), unpacked in its return order:
# training, validation, and "rep" (meaning not visible here -- TODO confirm).
# Arguments are positional; keep their order matched to the function.
trn_manager, val_manager, rep_manager = setup_vid_datasets(
    cfg.DATA_ROOT, cfg.VID_PARTITION_SIZES,
    cfg.TRN_SIZE, cfg.VAL_SIZE, cfg.REP_SIZE,
    cfg.P_DET, cfg.A,
)
# Assemble the trainer from the model, the three data managers and the config,
# then start training. All arguments are positional: keep this order identical
# to DetectTrackTrainer's signature (detect_to_track.trainer).
trainer = DetectTrackTrainer(
    model,
    trn_manager, val_manager, rep_manager,
    cfg.BATCH_SIZE,
    # input / feature-map geometry and anchor configuration
    cfg.INPUT_SHAPE, cfg.FM_STRIDE,
    cfg.ANCHOR_AREAS, cfg.ANCHOR_ASPECT_RATIOS,
    cfg.ENCODER_IOU_THRESH, cfg.ENCODER_IOU_MARGIN,
    # TRAIN_* RoI settings
    cfg.TRAIN_ROI_CONF_THRESH, cfg.TRAIN_MAX_ROIS, cfg.TRAIN_NMS_IOU_THRESH,
    # loss / optimisation settings
    cfg.ALPHA, cfg.GAMMA, cfg.COEFS,
    cfg.SGD_KWARGS, cfg.PATIENCE,
    # EVAL_* settings
    cfg.EVAL_ROI_CONF_THRESH, cfg.EVAL_MAX_ROIS,
    cfg.EVAL_NMS_IOU_THRESH, cfg.EVAL_RCNN_CONF_THRESH,
)
trainer.run()