/
model.py
50 lines (37 loc) · 1.62 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool, global_add_pool, SAGPooling
class Encoder(nn.Module):
    """Single-layer graph encoder: one ``GCNConv`` followed by ``PReLU``.

    Args:
        in_channels: Size of each input node feature vector, forwarded to
            ``GCNConv`` as its input size.
        hidden_channels: Output size of the convolution; also the number of
            learnable slopes held by the ``PReLU`` (one per output channel).
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.hidden_channels = hidden_channels
        # Build order (conv, then activation) is kept so that parameter
        # registration order stays the same.
        self.conv = GCNConv(in_channels, hidden_channels)
        self.prelu = nn.PReLU(hidden_channels)

    def forward(self, x, edge_index):
        """Return the activated node embeddings.

        Args:
            x: Node features, passed straight to the convolution
                (presumably ``(num_nodes, in_channels)`` -- confirm with caller).
            edge_index: Graph connectivity, passed straight to the convolution.

        Returns:
            ``prelu(conv(x, edge_index))``.
        """
        return self.prelu(self.conv(x, edge_index))
class Pool(nn.Module):
    """Graph-level readout that collapses node features into one vector per graph.

    Four strategies are selectable per call through the ``type`` argument of
    :meth:`forward`: ``'mean_pool'``, ``'max_pool'``, ``'sum_pool'`` and
    ``'sag_pool'``.

    Args:
        in_channels: Node feature size; used to size ``SAGPooling`` and ``lin1``.
        ratio: Forwarded to ``SAGPooling`` as its second positional argument
            (defaults to ``1.0``).
    """

    def __init__(self, in_channels, ratio=1.0):
        super().__init__()
        self.sag_pool = SAGPooling(in_channels, ratio)
        # NOTE: ``lin1`` is not used by ``forward``. It is kept so the module's
        # parameter set (and therefore previously saved state dicts) is unchanged.
        self.lin1 = torch.nn.Linear(in_channels * 2, in_channels)

    def forward(self, x, edge, batch, type='mean_pool'):
        """Pool node features ``x`` into per-graph features.

        Args:
            x: Node feature tensor.
            edge: Edge index of the graph(s); only read by ``'sag_pool'``.
            batch: Node-to-graph assignment vector handed to the pooling
                functions.
            type: Pooling strategy name. The parameter name shadows the builtin
                but is kept because callers may pass it by keyword.

        Returns:
            The pooled tensor produced by the selected strategy.

        Raises:
            ValueError: If ``type`` is not one of the supported names.
                Previously this case silently returned ``None``.
        """
        if type == 'mean_pool':
            return global_mean_pool(x, batch)
        if type == 'max_pool':
            return global_max_pool(x, batch)
        if type == 'sum_pool':
            return global_add_pool(x, batch)
        if type == 'sag_pool':
            # SAGPooling returns a 6-tuple; only the pooled node features
            # (1st item) and the matching batch vector (4th item) are needed.
            pooled_x, _, _, pooled_batch, _, _ = self.sag_pool(x, edge, batch=batch)
            return global_mean_pool(pooled_x, pooled_batch)
        raise ValueError(
            f"Unknown pooling type {type!r}; expected one of "
            "'mean_pool', 'max_pool', 'sum_pool', 'sag_pool'."
        )
class Scorer(nn.Module):
    """Bilinear similarity scorer squashed to ``(0, 1)`` with a sigmoid.

    For inputs ``a`` and ``b`` the score is
    ``sigmoid(sum(a * (b @ W), dim=-1))`` where ``W`` is a learnable
    ``hidden_size x hidden_size`` matrix.

    Args:
        hidden_size: Size of the last dimension of both inputs.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # ``torch.empty`` only allocates; ``reset_parameters`` fills it in.
        # The earlier ``torch.Tensor(h, h)`` left the weight as uninitialised
        # memory, so scores started from arbitrary (possibly NaN/inf) values
        # and were not reproducible under a fixed seed.
        self.weight = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialise ``weight`` with Xavier/Glorot uniform values."""
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input1, input2):
        """Score ``input1`` against ``input2``.

        Args:
            input1: Tensor whose last dimension is ``hidden_size``.
            input2: Tensor whose last dimension is ``hidden_size``; it is
                projected by ``weight`` before the element-wise product.

        Returns:
            Sigmoid of the product summed over the last dimension, so the
            result has one fewer dimension than the broadcast inputs.
        """
        projected = torch.matmul(input2, self.weight)
        return torch.sigmoid(torch.sum(input1 * projected, dim=-1))