-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
executable file
·62 lines (53 loc) · 2.69 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
from data_generator import *
import numpy as np
import time
def make_dir(path):
    """Create directory *path* if it does not already exist.

    Missing parent directories are created as well. A short notice is
    printed when the directory has to be created; nothing happens when
    *path* is already a directory.

    Raises:
        OSError: if the directory cannot be created, e.g. *path* exists
            but is a regular file.
    """
    if os.path.isdir(path):
        return
    print(' Create dir:{}...'.format(path))
    # makedirs also builds missing parents; exist_ok tolerates the directory
    # being created by someone else between the check above and this call.
    os.makedirs(path, exist_ok=True)
def gen_save_data(start=-2, end=-1, stride=2):
    """Generate train/test data for a range of dB levels and save it to disk.

    For every level ``db`` in ``range(start, end, stride)`` the function calls
    ``generateData`` (star-imported from ``data_generator``), reshapes the
    input arrays to ``(N, 1, 2, 12)`` and writes four files into
    ``./database_nonlinear/``::

        {db}_train_x.npy, {db}_train_y.npy, {db}_test_x.npy, {db}_test_y.npy

    The elapsed time per level is printed in minutes.

    Args:
        start: first dB level (inclusive).
        end: upper bound of the dB range (exclusive).
        stride: step between consecutive dB levels.

    Returns:
        None. Results are only written to disk.
    """
    # NOTE(review): the meaning of these generateData arguments is defined in
    # data_generator, which is not visible here -- confirm before changing.
    smoothing_len = 11
    ch_l = 10
    img_rows, img_cols = 2, 12
    out_dir = "./database_nonlinear/"

    for db in range(start, end, stride):
        print("situation is: {}dB".format(db))
        begin_db = time.time()

        x_train, y_train, x_test, y_test = generateData(
            10000000, 9000000, db, smoothing_len, ch_l, 'cnn')

        # Rebind the same names while converting so the raw generateData
        # outputs become collectable immediately instead of living next to
        # their array copies until the end of the iteration.
        x_train = np.asarray(x_train)
        y_train = np.asarray(y_train)
        x_test = np.asarray(x_test)
        y_test = np.asarray(y_test)

        # Each sample becomes a single-channel img_rows x img_cols "image".
        x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
        x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)

        make_dir(out_dir)
        # Write the arrays to disk, one .npy file per split.
        for suffix, array in (
            ("train_x", x_train),
            ("train_y", y_train),
            ("test_x", x_test),
            ("test_y", y_test),
        ):
            np.save("{}{}_{}.npy".format(out_dir, db, suffix), array)
        print("Data Saved in {}dB".format(db))

        # Drop the arrays before the next level is generated so two levels'
        # worth of data are never held in memory at the same time.
        del x_train, y_train, x_test, y_test, array

        end_db = time.time()
        print(db, 'db time', (end_db - begin_db) / 60, 'min')
if __name__ == "__main__":
    # Script entry point: generate and save data for 32, 34, ..., 50 dB and
    # report the total wall-clock time in minutes.
    run_started = time.time()
    gen_save_data(start=32, end=51, stride=2)
    run_finished = time.time()
    total_minutes = (run_finished - run_started) / 60
    print('total time:', total_minutes, 'min')