-
Notifications
You must be signed in to change notification settings - Fork 7
/
mit_utils.py
156 lines (134 loc) · 4.94 KB
/
mit_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 14 23:47:38 2019
@author: Winham
辅助函数
"""
import warnings
import numpy as np
from scipy.signal import resample
# import pywt
from sklearn.preprocessing import scale
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.utils.multiclass import unique_labels
import matplotlib.pyplot as plt
# ===========================================
warnings.filterwarnings("ignore")
import torch
import numpy as np
import time,os
from sklearn.metrics import f1_score
from torch import nn
def mkdirs(path):
    """Create directory *path* (including missing parents) if it does not exist.

    Uses ``exist_ok=True`` instead of a separate ``os.path.exists`` check so that
    concurrent callers creating the same directory cannot race each other into a
    ``FileExistsError``.

    :param path: directory path to create
    :raises FileExistsError: if *path* exists but is not a directory
    """
    os.makedirs(path, exist_ok=True)
def calc_f1(y_true, y_pre, threshold=0.5):
    """Macro-averaged F1 score of argmax predictions against integer labels.

    :param y_true: tensor of integer class labels; flattened to 1-D
    :param y_pre: tensor of per-class scores with classes on the last axis;
        the predicted label is the argmax over that axis
    :param threshold: unused; kept only for signature compatibility
    :return: macro-averaged F1 score as returned by ``sklearn.metrics.f1_score``
    """
    # ``reshape`` (unlike ``view``) also works for non-contiguous tensors.
    # Builtin ``int`` replaces ``np.int``, which was removed in NumPy 1.24.
    y_true = y_true.reshape(-1).detach().cpu().numpy().astype(int)
    scores = y_pre.detach().cpu().numpy()
    # Flatten predictions too so they line up with the flattened labels even
    # when there are extra leading dimensions (e.g. batch x seq x classes).
    y_pre = np.argmax(scores, axis=-1).reshape(-1)
    return f1_score(y_true, y_pre, average='macro')
def print_time_cost(since):
    """Format the wall-clock time elapsed since *since* as ``'<m>m<s>s\\n'``.

    :param since: start timestamp, as returned by ``time.time()``
    :return: string such as ``'2m5s\\n'`` (note the trailing newline)
    """
    # Round the total once, then split: rounding minutes and seconds
    # separately produced outputs such as "1m60s" for 119.6 s.
    total_seconds = int(round(time.time() - since))
    minutes, seconds = divmod(total_seconds, 60)
    return '{}m{}s\n'.format(minutes, seconds)
def adjust_learning_rate(optimizer, lr):
    """Overwrite the learning rate of every parameter group of *optimizer*.

    :param optimizer: object exposing ``param_groups`` whose items support
        ``item['lr'] = value`` (e.g. a ``torch.optim.Optimizer``)
    :param lr: new learning rate written into each group
    :return: *lr*, unchanged
    """
    groups = optimizer.param_groups
    for group in groups:
        group['lr'] = lr
    return lr
class WeightedMultilabel(nn.Module):
    """``BCEWithLogitsLoss`` weighted element-wise, then averaged.

    The unreduced BCE-with-logits loss is multiplied by ``weights`` (with
    standard broadcasting) and the mean over all elements is returned.
    """

    def __init__(self, weights: torch.Tensor):
        """
        :param weights: weights broadcast against the per-element loss
            (presumably one weight per class/label — confirm against callers)
        """
        super(WeightedMultilabel, self).__init__()
        # Attribute name kept as-is for compatibility with existing callers.
        self.cerition = nn.BCEWithLogitsLoss(reduction='none')
        # Registered as a buffer so ``.to()``, ``.cuda()`` and ``.double()`` on
        # the module move/cast the weights too; non-persistent so that
        # ``state_dict`` keys are unchanged from the previous plain attribute.
        self.register_buffer('weights', torch.as_tensor(weights), persistent=False)

    def forward(self, outputs, targets):
        """Return the weighted mean BCE-with-logits loss of *outputs* vs *targets*."""
        loss = self.cerition(outputs, targets)
        return (loss * self.weights).mean()
# =======================================
def sig_wt_filt(sig, wavelet='db6', level=9):
    """Filter a signal by discrete-wavelet decomposition and reconstruction.

    The signal is decomposed with ``pywt.wavedec`` (along the last axis). The
    approximation coefficients (``coeffs[0]``) and the two finest detail levels
    (``coeffs[-1]``, ``coeffs[-2]``) are zeroed, and the signal is rebuilt with
    ``pywt.waverec``.

    :param sig: input signal, 1-d array (filtering is along the last axis)
    :param wavelet: wavelet name passed to PyWavelets, default ``'db6'``
    :param level: decomposition level, default 9; PyWavelets warns when this
        exceeds the maximum useful level for the signal length
    :return: filtered signal with the same length as *sig* along the last axis
    :raises ImportError: if PyWavelets (``pywt``) is not installed
    """
    # The module-level ``import pywt`` is commented out in this file, so the
    # previous body raised NameError. Import lazily so the rest of the module
    # stays usable without PyWavelets.
    try:
        import pywt
    except ImportError as err:
        raise ImportError('sig_wt_filt requires PyWavelets (pip install PyWavelets)') from err
    coeffs = pywt.wavedec(sig, wavelet, level=level)
    coeffs[0] = np.zeros_like(coeffs[0])
    coeffs[-1] = np.zeros_like(coeffs[-1])
    coeffs[-2] = np.zeros_like(coeffs[-2])
    sig_filt = pywt.waverec(coeffs, wavelet)
    # waverec can return one extra sample (e.g. for odd-length input); trim so
    # the result matches the input length and can be written back in place.
    return sig_filt[..., :np.shape(sig)[-1]]
def multi_prep(sig, target_point_num=1280):
    """Preprocess a multi-row signal: resample, wavelet-filter, then z-score.

    :param sig: raw signal, 2-d array; resampling, filtering and
        standardisation are all applied along axis 1 (presumably one row per
        lead/channel — confirm against callers)
    :param target_point_num: number of samples per row after resampling, int
    :return: array of shape ``(sig.shape[0], target_point_num)``, resampled,
        wavelet-filtered row by row, and z-score standardised per row
    :raises ValueError: if *sig* is not 2-dimensional
    """
    # Explicit check rather than ``assert`` so validation still runs under
    # ``python -O`` (which strips assert statements).
    if np.ndim(sig) != 2:
        raise ValueError('Not for 1-D data.Use 2-D data.')
    sig = resample(sig, target_point_num, axis=1)
    for i in range(sig.shape[0]):
        # Trim defensively: wavelet reconstruction may return an extra sample,
        # which would not fit back into a row of length target_point_num.
        sig[i] = sig_wt_filt(sig[i])[:target_point_num]
    sig = scale(sig, axis=1)
    return sig
def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):
    """Compute, print and return the confusion matrix of *y_true* vs *y_pred*.

    Adapted from the scikit-learn example:
    https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py
    The plotting part of that example is currently disabled (commented out
    below), so ``title`` and ``cmap`` have no visible effect at present.

    :param y_true: ground-truth labels
    :param y_pred: predicted labels
    :param classes: class names (list or array); indexed by the label values
        present in *y_true*/*y_pred*, so labels are assumed to be integer
        indices into *classes*
    :param normalize: if True, divide each row by its sum; rows with no true
        samples become all zeros instead of NaN
    :param title: plot title (used only by the disabled plotting code)
    :param cmap: colormap (used only by the disabled plotting code)
    :return: confusion matrix, 2-d array (float when normalized)
    """
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'
    cm = confusion_matrix(y_true, y_pred)
    # np.asarray so a plain Python list of names can be fancy-indexed too.
    classes = np.asarray(classes)[unique_labels(y_true, y_pred)]
    if normalize:
        # A row sum of 0 (label only ever predicted, never true) gives 0/0;
        # map that NaN to 0, as sklearn's own normalized confusion_matrix does.
        with np.errstate(divide='ignore', invalid='ignore'):
            cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        cm = np.nan_to_num(cm)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)
    # fig, ax = plt.subplots()
    # # for i in range(5):
    # # cm[i,i] = 0
    # im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    # ax.figure.colorbar(im, ax=ax)
    # ax.set(xticks=np.arange(cm.shape[1]),
    # yticks=np.arange(cm.shape[0]),
    # xticklabels=classes, yticklabels=classes,
    # title=title,
    # ylabel='True label',
    # xlabel='Predicted label')
    #
    # plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
    # rotation_mode="anchor")
    #
    # fmt = '.2f' if normalize else 'd'
    # thresh = cm.max() / 2.
    # for i in range(cm.shape[0]):
    # for j in range(cm.shape[1]):
    # ax.text(j, i, format(cm[i, j], fmt),
    # ha="center", va="center",
    # color="white" if cm[i, j] > thresh else "black")
    # fig.tight_layout()
    return cm
def print_results(y_true, y_pred, target_names):
    """Print overall accuracy plus per-class Se and P+ to stdout.

    Se (sensitivity) is the diagonal entry divided by its confusion-matrix row
    sum (true-label count); P+ (positive predictivity) is the diagonal entry
    divided by its column sum (predicted-label count).

    NOTE(review): row/column k is paired with ``target_names[k]``, which is
    only correct when the labels are 0..K-1 and all appear in the data.

    :param y_true: expected labels, 1-d array
    :param y_pred: predicted labels, 1-d array
    :param target_names: name of each class
    :return: None; results are printed
    """
    acc = accuracy_score(y_true, y_pred)
    print('\n----- overall_accuracy: {0:f} -----'.format(acc))
    matrix = confusion_matrix(y_true, y_pred)
    n_names = len(target_names)
    for k in range(n_names):
        print(target_names[k] + ':')
        hits = matrix[k][k]
        sensitivity = hits / np.sum(matrix[k])
        predictivity = hits / np.sum(matrix[:, k])
        print(' Se = ' + str(sensitivity))
        print(' P+ = ' + str(predictivity))
    print('--------------------------------------')