-
Notifications
You must be signed in to change notification settings - Fork 0
/
get_model_tc.py
108 lines (106 loc) · 4.25 KB
/
get_model_tc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from keras.models import load_model
# from semeval.datasets import pre_deal
import pre_deal
import pre_deal_bert
from bert_serving.client import BertClient
from nltk import tokenize
import os
from itertools import groupby
# from semeval.datasets import KMP
import KMP
import numpy as np
import random
# list_tc = ['Appeal_to_Authority', 'Appeal_to_fear-prejudice', 'Bandwagon,Reductio_ad_hitlerum',
# # 'Black-and-White_Fallacy',
# # 'Causal_Oversimplification', 'Doubt', 'Exaggeration,Minimisation', 'Flag-Waving', 'Loaded_Language',
# # 'Name_Calling,Labeling', 'Repetition', 'Slogans', 'Thought-terminating_Cliches',
# # 'Whataboutism,Straw_Men,Red_Herring']
# Inference script: load the trained Keras model, run it over the pre-computed
# test vectors and print the arg-max class index for every prediction.

# Referenced only by the commented-out span-extraction code further down,
# where it bounds the per-article position loop.
max_len = 1000
# Alternative input path (disabled): encode the raw text through a
# bert-as-service server instead of loading pre-computed vectors from disk.
# bc = BertClient(ip='222.19.197.230', port=5555, port_out=5556, check_version=False)
model = load_model('softmax_mode_1.h5')
# Raw test texts; kept because the disabled BERT path and the disabled
# post-processing below both read `test_text`.
test_text = pre_deal_bert.get_test_textVector()
# test_vectors = bc.encode(test_text)
test_vectors = np.load("glove_test_300d.npy")
# np.save("test_case_512.npy", test_vectors)

# Model output per class ("gailv" is presumably pinyin for "probability";
# the file name suggests a softmax output layer — confirm against the model).
test_predict_gailv = model.predict(test_vectors)
print("--------------show---------------")
# Reuse the output computed above rather than calling model.predict() a
# second time: the forward pass is the expensive step and yields the same
# array, so only the arg-max over the last axis is needed here.
test_pred = test_predict_gailv.argmax(-1)
print("test_pred")
print(test_pred)
print(test_pred.shape)
# test_pred_jieguo = []
# for i in range(0, 75):
# test_pred_jieguo.append([])
#
# for i in range(0, 75):
# for j in range(0, max_len):
# if test_predict_gailv[i][j][1] > 0.17:
# test_pred_jieguo[i].append(1)
# else:
# test_pred_jieguo[i].append(0)
#
# texts_token = []
# for i in range(0, len(test_text)):
# texts_token.append(tokenize.word_tokenize(test_text[i]))
# # end = 0
# filename = 'a.txt'
# f = open(filename, 'w', encoding='utf-8')
# labels = []
# list_labels = os.listdir("dev-articles")
# labels_tag = {} # stores the tokenization data for each article
# for i in range(0, len(list_labels)):
# labels_tag[list_labels[i][7:16]] = []
#
# # Get the word-index ranges for the test set
# for j in range(0, 75):
# text_index = []
# for i in range(0, max_len):
# if (test_pred_jieguo[j][i] == 1):
# text_index.append(i)
# fun = lambda x: x[1] - x[0]
# for k, g in groupby(enumerate(text_index), fun):
# l1 = [j for i, j in g] # list of consecutive indices
# if len(l1) > 1:
# scop = str(min(l1)) + '-' + str(max(l1)) # join the bounds of the consecutive range with "-"
# else:
# scop = l1[0]
# labels_tag[list_labels[j][7:16]].append(min(l1))
# labels_tag[list_labels[j][7:16]].append(max(l1))
# print("----------------------")
# if (min(l1) == max(l1) and min(l1) > 400):
# li = texts_token[j][min(l1):]
# print(texts_token[j][min(l1):])
# else:
# li = texts_token[j][min(l1):max(l1)]
# print(texts_token[j][min(l1):max(l1)])
#
# list2 = [str(i) for i in li] # use a list comprehension to convert every element of the list to str
# list3 = ' '.join(list2) # join the list elements into one string, separated by spaces
# if (list3 == ''):
# pass
# else:
# a = KMP.KMP_algorithm(test_text[j], list3)
# if (a == -1):
# list_gai = str(texts_token[j][min(l1)])
# a = KMP.KMP_algorithm(test_text[j], list_gai) # start position
# print(list_labels[j][7:16])
# print("值为:" + str(a))
# b = a + len(list3)
# print("结束值为:" + str(b))
# # str_1 = random.choice(list_tc)
# f.write(list_labels[j][7:16] + '\t' + str(a) + '\t' + str(b) + '\n')
# else:
# print(list_labels[j][7:16])
# print("值为:" + str(a))
# b = a + len(list3)
# print("结束值为:" + str(b))
# # str_1 = random.choice(list_tc)
# f.write(list_labels[j][7:16] + '\t' + str(a) + '\t' + str(b) + '\n')
# # a = KMP.KMP_algorithm(test_text[j],list3)
# # print("值为:"+a)
# print(list3)
# print(str(min(l1)))
# print(str(max(l1)))
# print("连续数字范围:{}".format(scop))
#
# f.close()