-
Notifications
You must be signed in to change notification settings - Fork 0
/
args.py
executable file
·89 lines (67 loc) · 2.54 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# encoding: utf-8
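'''
@author: ZiqiLiu
@file: args.py
@time: 2017/6/19 3:35 PM
@desc: Command-line parsing for model runs: selects the rnn or attention
       config, sets CUDA_VISIBLE_DEVICES, and applies config overrides.
'''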
import argparse
import os
from config import attention_config, rnn_config
def config_value_cast(config, key, value):
    """Convert a command-line string to the type of ``config.<key>``.

    The target type is taken from the attribute's current value, so the
    config object must already define ``key``.

    Args:
        config: Object whose attribute ``key`` supplies the target type.
        key: Attribute name on ``config``.
        value: Raw string value (from ``--override`` or a path flag).

    Returns:
        ``value`` converted to the attribute's type. Booleans are parsed
        leniently: 'false', 'f', 'no', 'n', 'off' and '0' (case-insensitive)
        mean False; any other string means True. If the attribute currently
        holds None there is no type to cast to, so the raw string is
        returned unchanged.
    """
    current = getattr(config, key)
    if current is None:
        # type(None)(value) would raise TypeError; keep the string as given.
        return value
    target_type = type(current)
    # Exact type check (not isinstance): bool(value) would be True for any
    # non-empty string, so booleans need explicit parsing.
    if target_type is bool:
        return value.lower() not in ['false', 'f', 'no', 'n', 'off', '0']
    # NOTE(review): container-typed attributes (list/tuple) would be split
    # into single characters by this cast.
    return target_type(value)
def get_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before; passing a list
            lets tests and other scripts reuse this parser.

    Returns:
        dict mapping each option's destination ('mode', 'model', 'ktq',
        'gpu', 'train_path', 'valid_path', 'noise_path', 'override') to its
        parsed value.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', help='train: train model, ' +
                        'valid: model validation, ',
                        default=None)
    parser.add_argument('--model', help='rnn or attention',
                        default=None)
    parser.add_argument('--ktq',
                        help='whether run in ktq', type=int,
                        default=0)
    parser.add_argument('-g', '--gpu', help='which gpu to use')
    parser.add_argument('--train_path', help='train data path',
                        default='/ssd/keyword/ctc_23w/train/')
    parser.add_argument('--valid_path', help='valid data path',
                        default='/ssd/keyword/ctc_23w/valid/')
    parser.add_argument('--noise_path', help='noise data path',
                        default='/ssd/keyword/ctc_23w/noise/')
    # Flat list of alternating key/value tokens, e.g. -o lr 0.01 epochs 5.
    parser.add_argument('-o', '--override', nargs='*', default=[],
                        help='Override configuration, with k-v pairs')
    return vars(parser.parse_args(argv))
flags = get_args()
model = flags['model']
# Pick the config factory matching --model.
if model == 'rnn':
    config = rnn_config.get_config()
elif model == 'attention':
    config = attention_config.get_config()
else:
    # ValueError subclasses Exception, so existing handlers still catch it.
    raise ValueError('model %s not defined!' % model)
# GPU pinning is skipped when --ktq is non-zero (presumably a managed
# environment that assigns devices itself -- TODO confirm).
if not flags['ktq'] and flags['gpu'] is not None:
    # Without the None guard, an omitted -g/--gpu would export the literal
    # string "None", which CUDA treats as an invalid device list, hiding
    # every GPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(flags['gpu'])
print(flags)
def parse_args():
    """Apply the parsed command-line flags to the module-level ``config``.

    Sets ``config.mode`` from ``--mode`` and copies the data-path flags
    (``--train_path``, ``--valid_path``, ``--noise_path``) onto the config.
    It then applies the ``-o/--override`` key/value pairs, casting each value
    to the type of the existing attribute via ``config_value_cast``. Keys
    the config does not define are reported with a warning and skipped.

    Returns:
        tuple ``(config, model)``: the updated config object and the model
        name given by ``--model``.

    Raises:
        ValueError: if ``--override`` received an odd number of tokens.
    """
    setattr(config, 'mode', flags['mode'])

    overrides = flags['override']
    if len(overrides) % 2 != 0:
        # Explicit raise rather than assert so the check survives python -O.
        raise ValueError('--override expects key/value pairs, got %d tokens: %r'
                         % (len(overrides), overrides))

    # Path flags always carry a value (argparse defaults), so apply them
    # before the explicit overrides; otherwise '-o train_path X' would be
    # silently clobbered by the --train_path default.
    for key in ('train_path', 'valid_path', 'noise_path'):
        if not hasattr(config, key):
            print("WARNING: Invalid override with attribute %s" % (key))
        else:
            setattr(config, key, config_value_cast(config, key, flags[key]))

    # Pair up the flat token list: [k0, v0, k1, v1, ...] -> (k0, v0), (k1, v1).
    for key, value in zip(overrides[::2], overrides[1::2]):
        if not hasattr(config, key):
            print("WARNING: Invalid override with attribute %s" % (key))
        else:
            setattr(config, key, config_value_cast(config, key, value))
    return config, model
if __name__ == '__main__':
    # Run directly: apply mode, data paths and overrides to the selected
    # config. The returned (config, model) pair is not used further here.
    _config, _model = parse_args()