-
Notifications
You must be signed in to change notification settings - Fork 0
/
TEM_model.py
47 lines (41 loc) · 1.93 KB
/
TEM_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torch.optim as optim
from torch.autograd import Variable
import torch.nn as nn
class TEM(nn.Module):
    """Temporal Evaluation Module.

    Projects per-timestep 1024-d input features down to an embedding, runs a
    small temporal (1-D) convolution stack, and produces a 3-channel
    sigmoid-activated output per timestep.

    Input:  tensor of shape (batch, T, 1024).
    Output: tensor of shape (batch, 3, T), all values in (0, 1).
    """

    def __init__(self, embedsize=64, hiddensize=128):
        """Build the layers.

        Args:
            embedsize: width of the per-timestep linear embedding.
            hiddensize: channel count of the hidden temporal convolutions.
        """
        super(TEM, self).__init__()
        # Per-timestep projection: 1024 -> embedsize.
        self.linear = nn.Linear(in_features=1024, out_features=embedsize)
        # Temporal convs over the sequence axis; padding=1 with kernel 3
        # preserves the sequence length T.
        self.tmp_conv0 = nn.Conv1d(in_channels=embedsize, out_channels=hiddensize,
                                   kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(True)  # in-place ReLU, shared across layers
        self.tmp_conv1 = nn.Conv1d(in_channels=hiddensize, out_channels=hiddensize,
                                   kernel_size=3, stride=1, padding=1)
        # 1x1 conv head: hiddensize -> 3 output channels.
        self.tmp_conv2 = nn.Conv1d(in_channels=hiddensize, out_channels=3,
                                   kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, X_feature):
        """Run the module.

        Args:
            X_feature: float tensor of shape (batch, T, 1024).

        Returns:
            Tensor of shape (batch, 3, T) with sigmoid-activated values.
        """
        embedded = self.relu(self.linear(X_feature))
        # (batch, T, embed) -> (batch, embed, T): Conv1d expects channels
        # on dim 1 and the temporal axis last.
        hidden = embedded.permute(0, 2, 1)
        hidden = self.relu(self.tmp_conv0(hidden))
        hidden = self.relu(self.tmp_conv1(hidden))
        logits = self.tmp_conv2(hidden)
        # NOTE(review): the 0.1 factor flattens the sigmoid slope —
        # presumably to keep outputs away from saturation early in
        # training; confirm against the training recipe before changing.
        return self.sigmoid(0.1 * logits)