-
Notifications
You must be signed in to change notification settings - Fork 9
/
data.py
152 lines (138 loc) · 6.55 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os
import random
from copy import deepcopy
import numpy as np
import operator
from utils import Triple
def corrupt_head_raw(quadruple, entityTotal):
    """Return a copy of *quadruple* with its head ``s`` replaced at random.

    The new head is drawn uniformly from ``range(entityTotal)`` and always
    differs from the current head. The result is NOT checked against known
    facts, so it may be a false negative.

    Args:
        quadruple: object exposing an assignable ``s`` attribute. It is
            deep-copied, so the caller's object is never mutated.
        entityTotal: number of entity ids; candidates are
            ``0 .. entityTotal - 1``.

    Returns:
        A deep copy of *quadruple* whose ``s`` differs from the original.

    Raises:
        ValueError: if ``entityTotal < 2``. With fewer than two candidates
            there may be no head different from the current one, and the
            rejection loop below would otherwise never terminate.
    """
    if entityTotal < 2:
        raise ValueError(
            "corrupt_head_raw needs entityTotal >= 2, got %r" % (entityTotal,)
        )
    newQuadruple = deepcopy(quadruple)
    oldHead = quadruple.s
    # Rejection sampling: one randrange call per attempt; for in-range heads
    # the expected number of attempts is entityTotal / (entityTotal - 1) <= 2.
    newHead = random.randrange(entityTotal)
    while newHead == oldHead:
        newHead = random.randrange(entityTotal)
    newQuadruple.s = newHead
    return newQuadruple
def corrupt_tail_raw(quadruple, entityTotal):
    """Return a copy of *quadruple* with its tail ``o`` replaced at random.

    The new tail is drawn uniformly from ``range(entityTotal)`` and always
    differs from the current tail. The result is NOT checked against known
    facts, so it may be a false negative.

    Args:
        quadruple: object exposing an assignable ``o`` attribute. It is
            deep-copied, so the caller's object is never mutated.
        entityTotal: number of entity ids; candidates are
            ``0 .. entityTotal - 1``.

    Returns:
        A deep copy of *quadruple* whose ``o`` differs from the original.

    Raises:
        ValueError: if ``entityTotal < 2``. With fewer than two candidates
            there may be no tail different from the current one, and the
            rejection loop below would otherwise never terminate.
    """
    if entityTotal < 2:
        raise ValueError(
            "corrupt_tail_raw needs entityTotal >= 2, got %r" % (entityTotal,)
        )
    newQuadruple = deepcopy(quadruple)
    oldTail = newQuadruple.o
    # Rejection sampling: one randrange call per attempt; for in-range tails
    # the expected number of attempts is entityTotal / (entityTotal - 1) <= 2.
    newTail = random.randrange(entityTotal)
    while newTail == oldTail:
        newTail = random.randrange(entityTotal)
    newQuadruple.o = newTail
    return newQuadruple
def corrupt_head_filter(quadruple, entityTotal, quadrupleDict):
    """Return a copy of *quadruple* with a random head not found in *quadrupleDict*.

    Candidate heads are drawn uniformly from ``range(entityTotal)`` and
    redrawn while ``(candidate, o, r)`` is a member of *quadrupleDict*. The
    timestamp ``t`` is not part of the lookup key. If every candidate forms a
    known key, this loops forever.

    Args:
        quadruple: object with ``s``, ``o`` and ``r`` attributes; deep-copied,
            never mutated.
        entityTotal: number of entity ids to sample from.
        quadrupleDict: container supporting ``in`` on ``(s, o, r)`` tuples.

    Returns:
        A deep copy of *quadruple* with the accepted head stored in ``s``.
    """
    corrupted = deepcopy(quadruple)
    tail, relation = corrupted.o, corrupted.r
    candidate = random.randrange(entityTotal)
    while (candidate, tail, relation) in quadrupleDict:
        candidate = random.randrange(entityTotal)
    corrupted.s = candidate
    return corrupted
def corrupt_tail_filter(quadruple, entityTotal, quadrupleDict):
    """Return a copy of *quadruple* with a random tail not found in *quadrupleDict*.

    Candidate tails are drawn uniformly from ``range(entityTotal)`` and
    redrawn while ``(s, candidate, r)`` is a member of *quadrupleDict*. The
    timestamp ``t`` is not part of the lookup key. If every candidate forms a
    known key, this loops forever.

    Args:
        quadruple: object with ``s``, ``o`` and ``r`` attributes; deep-copied,
            never mutated.
        entityTotal: number of entity ids to sample from.
        quadrupleDict: container supporting ``in`` on ``(s, o, r)`` tuples.

    Returns:
        A deep copy of *quadruple* with the accepted tail stored in ``o``.
    """
    corrupted = deepcopy(quadruple)
    head, relation = corrupted.s, corrupted.r
    candidate = random.randrange(entityTotal)
    while (head, candidate, relation) in quadrupleDict:
        candidate = random.randrange(entityTotal)
    corrupted.o = candidate
    return corrupted
def getBatchList(tripleList, batch_size):
    """Split *tripleList* into consecutive slices of at most *batch_size* items.

    Order is preserved. Every slice except possibly the last holds exactly
    ``batch_size`` items. No empty trailing slice is produced when the length
    is an exact multiple of ``batch_size``, and an empty input yields ``[]``.

    Args:
        tripleList: any sliceable sequence (e.g. a list of triples/quadruples).
        batch_size: maximum number of items per slice; must be >= 1.

    Returns:
        A list of slices of *tripleList*.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(
            "batch_size must be a positive integer, got %r" % (batch_size,)
        )
    return [
        tripleList[start:start + batch_size]
        for start in range(0, len(tripleList), batch_size)
    ]
def getFourElements(quadrupleList):
    """Split quadruples into four parallel lists of their fields.

    Returns:
        A 4-tuple ``(heads, tails, relations, times)`` of lists holding each
        quadruple's ``s``, ``o``, ``r`` and ``t`` attribute, in input order.
    """
    # One pass over the input per field, in the order s, o, r, t.
    return tuple(
        [getattr(item, field) for item in quadrupleList]
        for field in ("s", "o", "r", "t")
    )
def getThreeElements(tripleList):
    """Split triples into three parallel lists of their fields.

    Returns:
        A 3-tuple ``(heads, tails, relations)`` of lists holding each triple's
        ``s``, ``o`` and ``r`` attribute, in input order.
    """
    # One pass over the input per field, in the order s, o, r.
    return tuple(
        [getattr(item, field) for item in tripleList]
        for field in ("s", "o", "r")
    )
def getBatch_raw_all(quadrupleList, entityTotal, mult_num = 1):
    """Build positives and unfiltered negatives from every quadrupleList entry.

    Each negative corrupts either the head or the tail with probability 0.5
    (via ``corrupt_head_raw`` / ``corrupt_tail_raw``), without checking for
    false negatives. When ``mult_num > 1`` the full pass is repeated
    ``mult_num`` times. The negatives therefore number
    ``len(quadrupleList) * mult_num``, while the positives are returned once.

    Returns:
        ``(ps, po, pr, pt, ns, no, nr, nt)``: the ``s``/``o``/``r``/``t``
        field lists of the positives followed by those of the negatives.
    """
    def draw_round():
        # One corrupted copy per input quadruple, in input order.
        drawn = []
        for quadruple in quadrupleList:
            corrupt = corrupt_head_raw if random.random() < 0.5 else corrupt_tail_raw
            drawn.append(corrupt(quadruple, entityTotal))
        return drawn

    negatives = draw_round()
    if mult_num > 1:
        for _ in range(mult_num - 1):
            negatives += draw_round()
    pos_s, pos_o, pos_r, pos_t = getFourElements(quadrupleList)
    neg_s, neg_o, neg_r, neg_t = getFourElements(negatives)
    return pos_s, pos_o, pos_r, pos_t, neg_s, neg_o, neg_r, neg_t
def getBatch_filter_all(quadrupleList, entityTotal, quadrupleDict, mult_num = 1):
    """Build positives and filtered negatives from every quadrupleList entry.

    Each negative corrupts either the head or the tail with probability 0.5
    (via ``corrupt_head_filter`` / ``corrupt_tail_filter``), redrawing any
    candidate whose key is found in *quadrupleDict*. When ``mult_num > 1``
    the full pass is repeated ``mult_num`` times. The negatives therefore
    number ``len(quadrupleList) * mult_num``, while the positives are
    returned once.

    Returns:
        ``(ps, po, pr, pt, ns, no, nr, nt)``: the ``s``/``o``/``r``/``t``
        field lists of the positives followed by those of the negatives.
    """
    def draw_round():
        # One filtered corrupted copy per input quadruple, in input order.
        drawn = []
        for quadruple in quadrupleList:
            corrupt = corrupt_head_filter if random.random() < 0.5 else corrupt_tail_filter
            drawn.append(corrupt(quadruple, entityTotal, quadrupleDict))
        return drawn

    negatives = draw_round()
    if mult_num > 1:
        for _ in range(mult_num - 1):
            negatives += draw_round()
    pos_s, pos_o, pos_r, pos_t = getFourElements(quadrupleList)
    neg_s, neg_o, neg_r, neg_t = getFourElements(negatives)
    return pos_s, pos_o, pos_r, pos_t, neg_s, neg_o, neg_r, neg_t
def getBatch_raw_random(quadrupleList, batchSize, entityTotal):
    """Sample *batchSize* positives and build one unfiltered negative for each.

    Positives are drawn without replacement with ``random.sample`` (which
    raises ValueError if *batchSize* exceeds the population). Each negative
    corrupts the head or the tail with probability 0.5, without checking for
    false negatives.

    Returns:
        ``(ph, po, pr, pt, nh, no, nr, nt)``: the ``s``/``o``/``r``/``t``
        field lists of the sampled positives followed by those of the
        negatives.
    """
    positives = random.sample(quadrupleList, batchSize)
    negatives = []
    for quadruple in positives:
        corrupt = corrupt_head_raw if random.random() < 0.5 else corrupt_tail_raw
        negatives.append(corrupt(quadruple, entityTotal))
    pos_h, pos_o, pos_r, pos_t = getFourElements(positives)
    neg_h, neg_o, neg_r, neg_t = getFourElements(negatives)
    return pos_h, pos_o, pos_r, pos_t, neg_h, neg_o, neg_r, neg_t
def getBatch_filter_random(quadrupleList, batchSize, entityTotal, quadrupleDict):
    """Sample *batchSize* positives and build one filtered negative for each.

    Positives are drawn without replacement with ``random.sample`` (which
    raises ValueError if *batchSize* exceeds the population). Each negative
    corrupts the head or the tail with probability 0.5. *quadrupleDict* is
    forwarded to ``corrupt_head_filter`` / ``corrupt_tail_filter`` to reject
    known facts.

    Returns:
        ``(ph, po, pr, pt, nh, no, nr, nt)``: the ``s``/``o``/``r``/``t``
        field lists of the sampled positives followed by those of the
        negatives.
    """
    positives = random.sample(quadrupleList, batchSize)
    negatives = []
    for quadruple in positives:
        corrupt = corrupt_head_filter if random.random() < 0.5 else corrupt_tail_filter
        negatives.append(corrupt(quadruple, entityTotal, quadrupleDict))
    pos_h, pos_o, pos_r, pos_t = getFourElements(positives)
    neg_h, neg_o, neg_r, neg_t = getFourElements(negatives)
    return pos_h, pos_o, pos_r, pos_t, neg_h, neg_o, neg_r, neg_t
def getTimestampBatchList(quadrupleList):
    """Group runs of consecutive quadruples that share the same timestamp ``t``.

    Only adjacent equal timestamps are merged. If the same timestamp appears
    again after a different one, it starts a new batch, so callers presumably
    pass input sorted by time (verify at call sites). Timestamps of any type
    that supports ``==`` are accepted (ints, tuples, lists, ...). An empty
    input yields an empty list.

    Args:
        quadrupleList: iterable of objects exposing a ``t`` attribute.

    Returns:
        A list of batches. Each batch is a deep-copied list of quadruples,
        so mutating the result never affects the caller's objects.
    """
    from itertools import groupby

    return [
        deepcopy(list(run))
        for _, run in groupby(quadrupleList, key=lambda quadruple: quadruple.t)
    ]