-
Notifications
You must be signed in to change notification settings - Fork 140
/
create_data.py
223 lines (200 loc) · 8.99 KB
/
create_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import argparse
import functools
import json
import os
import wave
from collections import Counter
from zhconv import convert
import numpy as np
import soundfile
from tqdm import tqdm
from data_utils.normalizer import FeatureNormalizer
from utils.utility import add_arguments, print_arguments, read_manifest, change_rate
# --- Command-line arguments --------------------------------------------------
# `add_arg` is the project helper `add_arguments` pre-bound to this parser;
# each call registers (name, type, default, help).  Help strings are runtime
# text shown to the user and are deliberately left in the original Chinese.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('annotation_path', str, 'dataset/annotation/', '标注文件的路径,如果annotation_path包含了test.txt,就全部使用test.txt的数据作为测试数据')  # annotation dir; a test.txt file becomes the test split
add_arg('manifest_prefix', str, 'dataset/', '训练数据清单,包括音频路径和标注信息')  # output dir for manifest.train / manifest.test / manifest.noise
add_arg('is_change_frame_rate', bool, True, '是否统一改变音频为16000Hz,这会消耗大量的时间')  # resample all audio to 16000Hz (slow)
add_arg('max_test_manifest', int, 10000, '最大的测试数据数量')  # cap on auto-held-out test samples
add_arg('count_threshold', int, 2, '字符计数的截断阈值,0为不做限制')  # drop vocab chars rarer than this (0 = keep all)
add_arg('vocab_path', str, 'dataset/zh_vocab.txt', '生成的数据字典文件')  # output vocabulary file
add_arg('num_workers', int, 8, '读取数据的线程数量')  # worker threads for feature normalization
add_arg('manifest_paths', str, 'dataset/manifest.train', '数据列表路径')  # manifest used for vocab and mean/std
add_arg('num_samples', int, 1000000, '用于计算均值和标准值得音频数量,当为-1使用全部数据')  # samples for mean/std (-1 = all)
add_arg('output_path', str, './dataset/mean_std.npz', '保存均值和标准值得numpy文件路径,后缀 (.npz).')  # mean/std output (.npz)
args = parser.parse_args()
# Create the training/testing data manifests.
def create_manifest(annotation_path, manifest_path_prefix):
    """Build ``manifest.train`` / ``manifest.test`` from annotation files.

    Each line of an annotation file is ``<audio_path>\\t<text>``.  Every line
    of ``test.txt`` goes to the test manifest; otherwise every
    ``interval``-th training sample is held out for testing (only when no
    dedicated test.txt data exists).  One JSON object per manifest line:
    ``{'audio_filepath', 'duration', 'text'}``.
    """
    data_list = []
    test_list = []
    durations_all = []
    duration_0_10 = 0
    duration_10_20 = 0
    duration_20 = 0
    # Walk every annotation file in the directory.
    for annotation_text in os.listdir(annotation_path):
        durations = []
        print('正在创建%s的数据列表,请等待 ...' % annotation_text)
        annotation_text_path = os.path.join(annotation_path, annotation_text)
        # Read the annotation file.
        with open(annotation_text_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        for line in tqdm(lines):
            audio_path = line.split('\t')[0]
            try:
                # Strip characters that are not kept in transcripts.
                text = is_ustr(line.split('\t')[1].replace('\n', '').replace('\r', ''))
                # Normalize everything to simplified Chinese.
                text = convert(text, 'zh-cn')
                # Optionally resample the audio to 16000Hz in place.
                if args.is_change_frame_rate:
                    change_rate(audio_path)
                # Measure the clip duration; the context manager closes the
                # handle promptly (the original leaked one fd per sample).
                with wave.open(audio_path, "rb") as f_wave:
                    duration = f_wave.getnframes() / f_wave.getframerate()
                if duration <= 10:
                    duration_0_10 += 1
                elif 10 < duration <= 20:
                    duration_10_20 += 1
                else:
                    duration_20 += 1
                durations.append(duration)
                d = json.dumps(
                    {
                        'audio_filepath': audio_path.replace('\\', '/'),
                        'duration': duration,
                        'text': text
                    },
                    ensure_ascii=False)
                if annotation_text == 'test.txt':
                    test_list.append(d)
                else:
                    data_list.append(d)
            except Exception as e:
                # Best-effort: report and skip unreadable/corrupt samples.
                print(e)
                continue
        durations_all.append(sum(durations))
        print("%s数据一共[%d]小时!" % (annotation_text, int(sum(durations) / 3600)))
    print("0-10秒的数量:%d,10-20秒的数量:%d,大于20秒的数量:%d" % (duration_0_10, duration_10_20, duration_20))
    # Write the manifests; `with` guarantees both files are closed even if a
    # write fails (the original left them open on error).
    with open(os.path.join(manifest_path_prefix, 'manifest.train'), 'w', encoding='utf-8') as f_train, \
            open(os.path.join(manifest_path_prefix, 'manifest.test'), 'w', encoding='utf-8') as f_test:
        for line in test_list:
            f_test.write(line + '\n')
        # Hold out every `interval`-th sample; widen the interval so the
        # held-out set never exceeds max_test_manifest.  (The original
        # repeated the literal 500 in the check.)
        interval = 500
        if len(data_list) / interval > args.max_test_manifest:
            interval = len(data_list) // args.max_test_manifest
        for i, line in enumerate(data_list):
            if i % interval == 0 and i != 0:
                # Only auto-hold-out when no dedicated test.txt data exists.
                if len(test_list) == 0:
                    f_test.write(line + '\n')
                else:
                    f_train.write(line + '\n')
            else:
                f_train.write(line + '\n')
    print("创建数量列表完成,全部数据一共[%d]小时!" % int(sum(durations_all) / 3600))
# Filter out characters that are not kept in transcripts.
def is_ustr(in_str):
    """Return ``in_str`` with every character rejected by ``is_uchar`` removed.

    Rejected characters are first mapped to spaces, then all whitespace is
    stripped, so pre-existing spaces are removed as well.
    """
    masked = ''.join(ch if is_uchar(ch) else ' ' for ch in in_str)
    return ''.join(masked.split())
# Decide whether a character is kept in the transcript.
def is_uchar(uchar):
    """Return True only for CJK unified ideographs (U+4E00..U+9FA5).

    The original spelled out separate digit / ASCII-letter / punctuation
    branches, but every one of them (and the final fallthrough) returned
    False, so they were dead weight — any non-CJK character is rejected.
    Collapsing them to a single range test preserves behavior exactly.
    """
    return '\u4e00' <= uchar <= '\u9fa5'
# Create the noise data manifest (noise clips are used for augmentation).
def create_noise(path='dataset/audio/noise', min_duration=30):
    """Write ``manifest.noise`` for every audio file under ``path``.

    Clips shorter than ``min_duration`` seconds are self-concatenated until
    long enough and re-written to disk at 16000Hz.  Noise clips carry an
    empty transcript.
    """
    if not os.path.exists(path):
        print('噪声音频文件为空,已跳过!')
        return
    json_lines = []
    print('正在创建噪声数据列表,路径:%s,请等待 ...' % path)
    for file in tqdm(os.listdir(path)):
        audio_path = os.path.join(path, file)
        try:
            # Noise is labeled with an empty string.
            text = ""
            # Optionally resample the audio to 16000Hz in place.
            if args.is_change_frame_rate:
                change_rate(audio_path)
            # Context manager closes the handle (the original leaked it).
            with wave.open(audio_path, "rb") as f_wave:
                duration = f_wave.getnframes() / f_wave.getframerate()
            # Tile short clips until they reach min_duration.
            if duration < min_duration:
                wav = soundfile.read(audio_path)[0]
                # Single hstack of n copies instead of n quadratic hstacks;
                # +2 reproduces the original copy count (seed + n appends).
                data = np.hstack([wav] * (int(min_duration / duration) + 2))
                # NOTE(review): 16000 assumes change_rate already resampled —
                # confirm for is_change_frame_rate=False inputs.
                soundfile.write(audio_path, data, samplerate=16000)
                with wave.open(audio_path, "rb") as f_wave:
                    duration = f_wave.getnframes() / f_wave.getframerate()
            json_lines.append(
                json.dumps(
                    {
                        'audio_filepath': audio_path.replace('\\', '/'),
                        'duration': duration,
                        'text': text
                    },
                    ensure_ascii=False))
        except Exception as e:
            # Best-effort: report and skip unreadable files instead of
            # swallowing the error silently as the original did.
            print(e)
            continue
    with open(os.path.join(args.manifest_prefix, 'manifest.noise'), 'w', encoding='utf-8') as f_noise:
        for json_line in json_lines:
            f_noise.write(json_line + '\n')
# Accumulate character frequencies from a manifest file.
def count_manifest(counter, manifest_path):
    """Update ``counter`` with every character of every 'text' field.

    Passing the whole string to ``Counter.update`` counts each character,
    exactly like a per-character loop.
    """
    for entry in read_manifest(manifest_path):
        counter.update(entry['text'])
# Compute the dataset feature mean and standard deviation.
def compute_mean_std(manifest_path, num_samples, output_path):
    """Sample ``num_samples`` utterances (-1 = all) from ``manifest_path``
    and save the feature mean/std to ``output_path`` (.npz)."""
    feature_normalizer = FeatureNormalizer(
        mean_std_filepath=None,
        manifest_path=manifest_path,
        num_samples=num_samples,
        num_workers=args.num_workers)
    # Persist the computed statistics.
    feature_normalizer.write_to_file(output_path)
    print('计算的均值和标准值已保存在 %s!' % output_path)
def main():
    """Run the whole pipeline: manifests, noise list, vocabulary, mean/std."""
    print_arguments(args)
    separator = '=' * 70
    print('开始生成数据列表...')
    create_manifest(annotation_path=args.annotation_path,
                    manifest_path_prefix=args.manifest_prefix)
    print(separator)
    print('开始生成噪声数据列表...')
    create_noise(path='dataset/audio/noise')
    print(separator)
    print('开始生成数据字典...')
    counter = Counter()
    # Collect every transcript character from the training manifest.
    count_manifest(counter, args.manifest_paths)
    # Assign one line per character, most frequent first.  most_common()
    # sorts by count descending (stable), matching the original sorted().
    with open(args.vocab_path, 'w', encoding='utf-8') as fout:
        fout.write('<blank>\t-1\n')
        for char, count in counter.most_common():
            # The list is sorted, so breaking at the threshold drops every
            # rarer character too.
            if count < args.count_threshold:
                break
            fout.write('%s\t%d\n' % (char, count))
    print('数据词汇表已生成完成,保存与:%s' % args.vocab_path)
    print(separator)
    print('开始抽取%s条数据计算均值和标准值...' % args.num_samples)
    compute_mean_std(args.manifest_paths, args.num_samples, args.output_path)
    print(separator)
# Script entry point: run the full data-preparation pipeline.
if __name__ == '__main__':
    main()