/
sim_utils.py
81 lines (69 loc) · 2.15 KB
/
sim_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import io
import numpy as np
import torch
def get_wordmap(textfile):
    """Load a space-separated text embedding file.

    Each data line is expected to look like ``word v1 v2 ... vd``. If the
    first line consists of exactly two whitespace-separated fields it is
    treated as a header and skipped.

    Args:
        textfile: Path to a UTF-8 encoded embedding file.

    Returns:
        A tuple ``(words, We)`` where ``words`` maps each word to its row
        index and ``We`` is a numpy array built from the parsed vectors.
        If a word occurs more than once, every occurrence gets its own row
        and ``words`` points at the last one. An empty file yields an empty
        dict and an empty array.

    Raises:
        ValueError: If a vector component cannot be parsed as a float.
    """
    words = {}
    We = []
    # Context manager guarantees the handle is closed even if parsing fails.
    with io.open(textfile, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    # Guard against an empty file before peeking at the header line.
    if lines and len(lines[0].split()) == 2:
        lines.pop(0)
    for line in lines:
        if not line.strip():
            # Blank lines carry no entry; skipping them avoids an IndexError.
            continue
        # Split on single spaces only (not arbitrary whitespace) so the word
        # is extracted exactly as before; trailing whitespace is removed
        # first so a space before the newline does not yield a bogus field.
        fields = line.rstrip().split(' ')
        word = fields[0]
        vector = [float(value) for value in fields[1:] if value]
        words[word] = len(We)
        We.append(vector)
    return words, np.array(We)
def get_minibatches_idx(n, minibatch_size, shuffle=False):
    """Partition the indices ``0 .. n-1`` into consecutive minibatches.

    Args:
        n: Number of items to index.
        minibatch_size: Number of indices in each full minibatch.
        shuffle: When true, the index order is permuted in place with
            ``np.random.shuffle`` before it is cut into batches.

    Returns:
        A ``zip`` of ``(batch_number, index_array)`` pairs. Index arrays are
        int32 slices of a single index vector; the last one may be shorter
        than ``minibatch_size`` when ``n`` is not an exact multiple.
    """
    order = np.arange(n, dtype="int32")
    if shuffle:
        np.random.shuffle(order)

    batches = []
    cursor = 0
    for _ in range(n // minibatch_size):
        stop = cursor + minibatch_size
        batches.append(order[cursor:stop])
        cursor = stop

    # Whatever did not fill a complete batch becomes one final, shorter batch.
    if cursor != n:
        batches.append(order[cursor:])

    return zip(range(len(batches)), batches)
def max_pool(x, lengths, gpu):
    """Take the element-wise maximum over the first ``lengths[i]`` rows of ``x[i]``.

    Args:
        x: Tensor indexed as ``x[i][t]``; its sizes along dimensions 0 and 2
            determine the output shape (presumably batch x time x features --
            confirm against callers).
        lengths: Per-item count of leading rows of ``x[i]`` to pool over.
        gpu: When ``>= 0`` the output buffer is moved to CUDA via ``.cuda()``.

    Returns:
        A float tensor of shape ``(x.size(0), x.size(2))``. Rows beyond
        ``len(lengths)`` are left at zero.
    """
    n_items, n_features = x.size(0), x.size(2)
    pooled = torch.FloatTensor(n_items, n_features).zero_()
    if gpu >= 0:
        pooled = pooled.cuda()
    for row, valid in enumerate(lengths):
        window = x[row][:valid]
        # torch.max along a dim returns (values, indices); keep the values.
        pooled[row] = window.max(0)[0]
    return pooled
def mean_pool(x, lengths, gpu):
    """Average the first ``lengths[i]`` rows of ``x[i]`` for every item.

    Args:
        x: Tensor indexed as ``x[i][t]``; its sizes along dimensions 0 and 2
            determine the output shape (presumably batch x time x features --
            confirm against callers).
        lengths: Per-item count of leading rows of ``x[i]`` to average.
        gpu: When ``>= 0`` the output buffer is moved to CUDA via ``.cuda()``.

    Returns:
        A float tensor of shape ``(x.size(0), x.size(2))``. Rows beyond
        ``len(lengths)`` are left at zero.
    """
    n_items, n_features = x.size(0), x.size(2)
    pooled = torch.FloatTensor(n_items, n_features).zero_()
    if gpu >= 0:
        pooled = pooled.cuda()
    for row, valid in enumerate(lengths):
        window = x[row][:valid]
        pooled[row] = window.mean(0)
    return pooled
def lookup(words, w):
    """Return the entry stored for the lower-cased form of ``w``.

    Args:
        words: Mapping from word to its stored value.
        w: Word to look up; it is lower-cased before the lookup.

    Returns:
        ``words[w.lower()]`` when that key exists, otherwise ``None``.
    """
    return words.get(w.lower())
class Example(object):
    """A single sentence together with the vocabulary indices of its tokens."""

    def __init__(self, sentence):
        """Store the stripped, lower-cased sentence with no indices yet.

        Args:
            sentence: Raw sentence text.
        """
        self.sentence = sentence.strip().lower()
        # Vocabulary indices, filled in by populate_embeddings().
        self.embeddings = []
        # Not set anywhere in this class; left as None.
        self.representation = None

    def populate_embeddings(self, words):
        """Append the vocabulary index of every known token to ``self.embeddings``.

        Tokens are obtained by splitting the lower-cased sentence on
        whitespace. Tokens that ``lookup`` cannot find are skipped. If the
        list is still empty afterwards, the index stored under the key
        ``'UUUNKKK'`` is appended instead.

        Args:
            words: Mapping from word to vocabulary index.

        Raises:
            KeyError: If no token is known and ``words`` has no
                ``'UUUNKKK'`` entry.
        """
        for token in self.sentence.lower().split():
            emb = lookup(words, token)
            # Compare against None explicitly: index 0 is a valid vocabulary
            # entry, and a plain truthiness test would silently drop it.
            if emb is not None:
                self.embeddings.append(emb)
        if not self.embeddings:
            self.embeddings.append(words['UUUNKKK'])