-
Notifications
You must be signed in to change notification settings - Fork 1
/
chart_helper.pyx
143 lines (117 loc) · 5.73 KB
/
chart_helper.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import numpy as np
cimport numpy as np
from numpy cimport ndarray
cimport cython
# Element type of every score chart handled in this module (32-bit float).
ctypedef np.float32_t DTYPE_t

# Cache filled by decode(): maps a gold object to its precomputed
# (oracle_label_chart, oracle_split_chart) pair so the oracle is only
# computed once per gold object. decode() adds an entry only when
# gold.nocache is falsy; no code in this module removes entries, so the
# cache grows with every distinct cached gold object.
ORACLE_PRECOMPUTED_TABLE = dict()
@cython.boundscheck(False)
def decode(int force_gold, int sentence_len, np.ndarray[DTYPE_t, ndim=3] label_scores_chart, int is_train, gold, label_vocab):
    """Find the best-scoring binary tree over a span/label score chart.

    Runs a CKY-style dynamic program over all spans ``(left, right)`` with
    ``0 <= left < right <= sentence_len``. Each span gets one label and,
    for spans longer than one token, one split point. The tree's score is
    the sum of its spans' label scores.

    Parameters
    ----------
    force_gold : int
        If nonzero, every span uses the oracle label and oracle split
        derived from ``gold`` instead of the argmax label / best split.
    sentence_len : int
        Number of tokens. Must be at least 1.
    label_scores_chart : ndarray of DTYPE_t, ndim=3
        ``label_scores_chart[left, right, label]`` is the score for giving
        span ``(left, right)`` that label. Its shape must be at least
        ``(sentence_len + 1, sentence_len + 1, 2)``, or
        ``(sentence_len + 1, sentence_len + 1, 1)`` when ``force_gold`` is
        set. The array is not modified (a copy is used internally).
    is_train : int
        If nonzero and ``force_gold`` is zero, decoding is augmented:
        1 is subtracted from each span's oracle-label score before the
        argmax, then 1 is added to the chosen label's score of every span.
    gold :
        Used only when ``is_train`` or ``force_gold`` is nonzero. Must be
        hashable (it is a dict key) and provide ``oracle_label(left, right)``,
        ``oracle_splits(left, right)`` and a ``nocache`` attribute.
    label_vocab :
        Used only when ``is_train`` or ``force_gold`` is nonzero. Must
        provide ``index(label)`` mapping an oracle label to its chart index.

    Returns
    -------
    tuple
        ``(score, included_i, included_j, included_label, augment_amount)``.
        ``score`` is the (possibly augmented) best tree score. The three
        integer arrays list the ``2 * sentence_len - 1`` tree nodes in
        pre-order (left child before right child). ``augment_amount`` is
        ``score`` minus the un-augmented chart score of those nodes, rounded.

    Raises
    ------
    ValueError
        If ``sentence_len < 1``, if ``label_scores_chart`` is too small for
        ``sentence_len``, or if an oracle label index falls outside the
        chart's label axis.
    """
    if sentence_len < 1:
        # An empty sentence has no tree. Fail clearly here instead of with the
        # opaque negative-dimension error np.empty() would raise further down.
        raise ValueError("sentence_len must be at least 1, got %d" % sentence_len)

    cdef DTYPE_t NEG_INF = -np.inf

    # Work on a copy: training-time augmentation edits scores in place, and
    # the caller's original chart is re-read at the end to compute the
    # un-augmented total.
    cdef np.ndarray[DTYPE_t, ndim=3] label_scores_chart_copy = label_scores_chart.copy()

    # Bounds checking is disabled for this function, so an undersized chart
    # would be read out of bounds rather than raising. When not forcing gold,
    # the full-sentence span reads label index 1 unconditionally, so two
    # label slots are required in that mode.
    cdef int min_labels = 1 if force_gold else 2
    if (label_scores_chart_copy.shape[0] < sentence_len + 1
            or label_scores_chart_copy.shape[1] < sentence_len + 1
            or label_scores_chart_copy.shape[2] < min_labels):
        raise ValueError(
            "label_scores_chart has shape %r; expected at least (%d, %d, %d)"
            % (np.shape(label_scores_chart_copy), sentence_len + 1, sentence_len + 1, min_labels))

    # value_chart[l, r]: best (possibly augmented) score of a subtree over (l, r).
    cdef np.ndarray[DTYPE_t, ndim=2] value_chart = np.zeros((sentence_len+1, sentence_len+1), dtype=np.float32)
    # split_idx_chart[l, r]: split point chosen for (l, r) (unused for length-1 spans).
    cdef np.ndarray[int, ndim=2] split_idx_chart = np.zeros((sentence_len+1, sentence_len+1), dtype=np.int32)
    # best_label_chart[l, r]: label index chosen for (l, r).
    cdef np.ndarray[int, ndim=2] best_label_chart = np.zeros((sentence_len+1, sentence_len+1), dtype=np.int32)

    cdef int length
    cdef int left
    cdef int right

    cdef int oracle_label_index
    cdef DTYPE_t label_score
    cdef int argmax_label_index

    cdef int best_split
    cdef int split_idx  # loop variable over candidate split points
    cdef DTYPE_t best_split_val  # best children score found so far
    cdef DTYPE_t candidate_val  # children score for the current split_idx

    cdef int label_index_iter

    cdef np.ndarray[int, ndim=2] oracle_label_chart
    cdef np.ndarray[int, ndim=2] oracle_split_chart
    if is_train or force_gold:
        if gold not in ORACLE_PRECOMPUTED_TABLE:
            oracle_label_chart = np.zeros((sentence_len+1, sentence_len+1), dtype=np.int32)
            oracle_split_chart = np.zeros((sentence_len+1, sentence_len+1), dtype=np.int32)
            for length in range(1, sentence_len + 1):
                for left in range(0, sentence_len + 1 - length):
                    right = left + length
                    oracle_label_chart[left, right] = label_vocab.index(gold.oracle_label(left, right))
                    if length == 1:
                        continue
                    # When several oracle splits exist, the smallest one is used.
                    oracle_splits = gold.oracle_splits(left, right)
                    oracle_split_chart[left, right] = min(oracle_splits)
            if not gold.nocache:
                ORACLE_PRECOMPUTED_TABLE[gold] = oracle_label_chart, oracle_split_chart
        else:
            oracle_label_chart, oracle_split_chart = ORACLE_PRECOMPUTED_TABLE[gold]
        # Oracle label indices are used to index the chart's label axis
        # without bounds checks, so validate them against that axis once.
        if oracle_label_chart.max() >= label_scores_chart_copy.shape[2]:
            raise ValueError(
                "oracle label index %d out of range for label_scores_chart with %d labels"
                % (oracle_label_chart.max(), label_scores_chart_copy.shape[2]))

    # Fill the charts bottom-up: all spans of a given length depend only on
    # shorter spans.
    for length in range(1, sentence_len + 1):
        for left in range(0, sentence_len + 1 - length):
            right = left + length
            if is_train or force_gold:
                oracle_label_index = oracle_label_chart[left, right]
            if force_gold:
                label_score = label_scores_chart_copy[left, right, oracle_label_index]
                best_label_chart[left, right] = oracle_label_index
            else:
                if is_train:
                    # Augmentation, part 1: penalize the oracle label by 1.
                    label_scores_chart_copy[left, right, oracle_label_index] -= 1
                # Hand-written argmax over the label axis (presumably to keep
                # this inner loop in typed C code). For the full-sentence span
                # the scan starts at 1, so label index 0 can never be chosen
                # there. NOTE(review): presumably label 0 is a null/empty
                # label that is not allowed at the root; confirm against
                # label_vocab.
                if length < sentence_len:
                    argmax_label_index = 0
                else:
                    argmax_label_index = 1
                label_score = label_scores_chart_copy[left, right, argmax_label_index]
                for label_index_iter in range(1, label_scores_chart_copy.shape[2]):
                    if label_scores_chart_copy[left, right, label_index_iter] > label_score:
                        argmax_label_index = label_index_iter
                        label_score = label_scores_chart_copy[left, right, label_index_iter]
                best_label_chart[left, right] = argmax_label_index
                if is_train:
                    # Augmentation, part 2: reward every span by 1. Net effect
                    # per span: +1 unless the chosen label is the oracle label.
                    label_score += 1
            if length == 1:
                value_chart[left, right] = label_score
                continue
            if force_gold:
                best_split = oracle_split_chart[left, right]
            else:
                # Pick the split maximizing the sum of the two children; ties
                # keep the leftmost split because the comparison is strict.
                best_split = left + 1
                best_split_val = NEG_INF
                for split_idx in range(left + 1, right):
                    candidate_val = value_chart[left, split_idx] + value_chart[split_idx, right]
                    if candidate_val > best_split_val:
                        best_split_val = candidate_val
                        best_split = split_idx
            value_chart[left, right] = label_score + value_chart[left, best_split] + value_chart[best_split, right]
            split_idx_chart[left, right] = best_split

    # Recover the tree by walking split_idx_chart from the root (0, sentence_len)
    # with an explicit stack. A full binary tree over sentence_len leaves has
    # exactly 2 * sentence_len - 1 nodes.
    cdef int num_tree_nodes = 2 * sentence_len - 1
    cdef np.ndarray[int, ndim=1] included_i = np.empty(num_tree_nodes, dtype=np.int32)
    cdef np.ndarray[int, ndim=1] included_j = np.empty(num_tree_nodes, dtype=np.int32)
    cdef np.ndarray[int, ndim=1] included_label = np.empty(num_tree_nodes, dtype=np.int32)
    cdef int idx = 0
    # The stack is 1-indexed (stack_idx == 0 means empty); it never holds more
    # entries than there are tree nodes, so this size is a safe upper bound.
    cdef int stack_idx = 1
    cdef np.ndarray[int, ndim=1] stack_i = np.empty(num_tree_nodes + 5, dtype=np.int32)
    cdef np.ndarray[int, ndim=1] stack_j = np.empty(num_tree_nodes + 5, dtype=np.int32)
    stack_i[1] = 0
    stack_j[1] = sentence_len
    cdef int i, j, k
    while stack_idx > 0:
        i = stack_i[stack_idx]
        j = stack_j[stack_idx]
        stack_idx -= 1
        included_i[idx] = i
        included_j[idx] = j
        included_label[idx] = best_label_chart[i, j]
        idx += 1
        if i + 1 < j:
            k = split_idx_chart[i, j]
            # Push the right child first so the left child is popped next
            # (pre-order, left before right).
            stack_idx += 1
            stack_i[stack_idx] = k
            stack_j[stack_idx] = j
            stack_idx += 1
            stack_i[stack_idx] = i
            stack_j[stack_idx] = k

    # Un-augmented score of the chosen tree, read from the caller's chart.
    cdef DTYPE_t running_total = 0.0
    for idx in range(num_tree_nodes):
        running_total += label_scores_chart[included_i[idx], included_j[idx], included_label[idx]]
    cdef DTYPE_t score = value_chart[0, sentence_len]
    # With augmentation this is the number of nodes whose label differs from
    # the oracle label; otherwise 0. round() absorbs float32 accumulation error.
    cdef DTYPE_t augment_amount = round(score - running_total)
    return score, included_i.astype(int), included_j.astype(int), included_label.astype(int), augment_amount