# test_custom.py
import warnings
warnings.filterwarnings('ignore')

import torch
import torchaudio
from torch import nn
from torchvision import transforms

from utils.constants import *
from data.custom_transforms import ExtractFFNNFeatures, ExtractMFCC, ToThreeChannels, ToTensor
# Model definitions for the checkpoints loaded via torch.load below
from models.definitions.cnn_model import ConvNeuralNetwork
from models.definitions.ffnn_model import FFNNNeuralNetwork


class CustomTest:
    def __init__(self, config):
        self.config = config
        self.y_true = []
        self.y_pred = []
        self.test_loss = 1
        self.accuracy = 0
        self.path = config['custom_test_path']
        self.name = config['model_name']
        self.transform_audio()
    def transform_audio(self):
        # Bring the clip into a universal format: target sample rate, mono,
        # and fixed-length chunks of SAMPLE_SIZE frames
        self.audio_sample, self.sample_rate = torchaudio.load(self.path)
        self.audio_sample = self.resample(self.audio_sample, self.sample_rate, SAMPLE_RATE)
        self.audio_sample = self.toMono(self.audio_sample)
        self.audio_samples = self.cutDown(self.audio_sample)
        # Wrap each chunk in the dict format the transforms expect; the audio has
        # already been resampled, so the rate here is SAMPLE_RATE, not the original
        samples = [{
            'audio': audio_sample,
            'sample_rate': SAMPLE_RATE,
            'input': [],
            'label': []
        } for audio_sample in self.audio_samples]
        # Pick the feature-extraction pipeline matching the model type
        sample_transform = None
        if self.config['model_name'] == SupportedModels.FFNN.name:
            sample_transform = transforms.Compose([
                ExtractFFNNFeatures(),
                ToTensor()
            ])
        elif self.config['model_name'] in (SupportedModels.CNN.name, SupportedModels.VGG.name):
            sample_transform = transforms.Compose([
                ExtractMFCC(),
                ToThreeChannels(),
                ToTensor()
            ])
        if sample_transform is None:
            raise ValueError(f"Unsupported model: {self.config['model_name']}")
        transformed_samples = [sample_transform(sample) for sample in samples]
        # Add a batch dimension so each chunk can be fed to the model directly
        self.sample_input = [torch.unsqueeze(transformed_sample['input'], dim=0) for transformed_sample in transformed_samples]
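
    # Note on the transform contract (an assumption inferred from the usage
    # above, not verified against data/custom_transforms): each transform takes
    # the sample dict {'audio', 'sample_rate', 'input', 'label'} and returns it
    # with 'input' filled, e.g. ExtractMFCC presumably derives MFCC features
    # from 'audio' and ToTensor converts 'input' into a torch.Tensor.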

    def startTest(self):
        print("Testing model")
        predicted_classes = []
        for val_fold in range(1, K_FOLD + 1):
            # Load the model trained with this fold held out for validation
            model_path = SAVED_MODEL_PATH + DATASET + '/' + self.config['model_name'] + "_fold" + str(val_fold) + '.pt'
            model = torch.load(model_path, map_location=torch.device(DEVICE))
            # Initialize the loss function
            loss_fn = nn.CrossEntropyLoss()
            # Each fold's model contributes one vote: its own per-chunk majority
            predicted_class = self.testLoop(model, loss_fn)
            predicted_classes.append(predicted_class)
        # Majority vote across the K fold models
        predicted_class = max(set(predicted_classes), key=predicted_classes.count)
        # Reverse lookup: class index -> label name
        label = next(x for x in URBAN_SOUND_8K_LABEL_MAPPING if URBAN_SOUND_8K_LABEL_MAPPING[x] == predicted_class)
        print("Classified as:", label)

    def testLoop(self, model, loss_fn):
        # Number of fixed-length chunks extracted from the input clip
        # (len(self.audio_sample) would be the channel count, not the chunk count)
        size = len(self.sample_input)
        model.eval()
        predicted_classes = []
        with torch.no_grad():
            for ind in range(size):
                X = self.sample_input[ind].to(DEVICE)
                pred = model(X)
                pred = pred.argmax(1).cpu()
                predicted_classes.append(pred[0].item())
        # Majority vote across the chunks of this clip
        predicted_class = max(set(predicted_classes), key=predicted_classes.count)
        return predicted_class

    def printAccuracy(self):
        print(f'Accuracy of model: {(self.accuracy / K_FOLD):>0.2f}%')

    def resample(self, audio_sample, sample_rate, target_sample_rate):
        # Resample only when the source rate differs from the target rate
        if sample_rate != target_sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
            audio_sample = resampler(audio_sample)
        return audio_sample

    def toMono(self, audio_sample):
        # Average the channels to collapse multi-channel audio to mono
        if audio_sample.shape[0] > 1:
            audio_sample = torch.mean(audio_sample, dim=0, keepdim=True)
        return audio_sample

    def cutDown(self, audio_sample):
        # Split the clip into consecutive SAMPLE_SIZE-frame chunks (at most 50),
        # stopping at the first trailing chunk shorter than SAMPLE_SIZE
        audio_samples = []
        for ind in range(50):
            sample = audio_sample[:, ind * SAMPLE_SIZE:(ind + 1) * SAMPLE_SIZE]
            if sample.shape[1] == SAMPLE_SIZE:
                audio_samples.append(sample)
            else:
                break
        return audio_samples

    def padRight(self, audio_sample):
        # Zero-pad the last dimension on the right so short clips reach SAMPLE_SIZE
        length = audio_sample.shape[1]
        if length < SAMPLE_SIZE:
            to_pad = SAMPLE_SIZE - length
            last_dim_padding = (0, to_pad)
            audio_sample = nn.functional.pad(audio_sample, last_dim_padding)
        return audio_sample
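

# A minimal usage sketch, assuming the script runs from the project root with
# the trained fold checkpoints in place. The clip path and the exact config
# layout are illustrative assumptions; only 'custom_test_path' and
# 'model_name' are actually read by CustomTest.
if __name__ == '__main__':
    config = {
        'model_name': SupportedModels.CNN.name,  # or FFNN / VGG
        'custom_test_path': 'path/to/clip.wav',  # hypothetical audio file
    }
    tester = CustomTest(config)  # loads, resamples, chunks, and featurizes the clip
    tester.startTest()           # majority vote across the K fold models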