-
Notifications
You must be signed in to change notification settings - Fork 1
/
clozestyle_bert.py
158 lines (127 loc) · 4.94 KB
/
clozestyle_bert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# -*- coding: utf-8 -*-
# http://pytorch.org/
# NOTE(review): this file is an exported Colab/Jupyter notebook; lines that
# start with '!' are IPython shell magics and are NOT valid plain Python.
from os.path import exists
# NOTE(review): wheel.pep425tags was removed from modern `wheel` releases —
# this import only works with old pinned versions; confirm the environment.
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
# Build a pip platform tag (e.g. "cp36-cp36m") used by old Colab torch installs.
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
# Probe the installed CUDA runtime version (e.g. "cu90") via ldconfig.
cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
# Use the GPU build only if an NVIDIA device node is present, else CPU.
accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'
import torch
# Install the pre-`transformers` BERT implementation used later in the script.
!pip install pytorch-pretrained-bert
from google.colab import files
# Interactively upload the data file (read below as /content/sample3.txt).
uploaded = files.upload()
!ls
!pwd
"""# Bert Pytorch- Question and Answer (Edited)
# WORD2VEC gensim
"""
!pip3 install --upgrade gensim
import gensim.downloader as api
info = api.info()
model = api.load("word2vec-google-news-300")
similarity = model.similarity('java','software')
print(similarity)
print(type(similarity))
print(similarity.item())
print(type(similarity.item()))
import torch
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM
import gensim.downloader as api
import time
import sys

# --- Load the cloze-style question data --------------------------------------
# The input file holds 20 context sentences followed by a 21st line of the
# form: "<cloze sentence containing XXXXX>\t...\t...\t<choice1|choice2|...>".
#desk = input("Please input the location of the data file: "'\n')
#filename = input("Please input the name of the data file: "'\n')
# Use a context manager so the file handle is closed (the original leaked it).
with open('/content/sample3.txt', "r", encoding='UTF-8') as f:
    # NOTE(review): 'list' shadows the builtin; the name is kept unchanged
    # because the rest of the script references it.
    list = [line.strip() for line in f.readlines()[0:21]]

start_time = time.time()

# Line 21: split off the tab-separated fields, then replace the XXXXX blank
# with '_' (which the mask-marking loop below looks for).
list[20] = list[20].split('\t', 3)
list[20][0] = list[20][0].replace("XXXXX", "_")

# Context passage = sentences 1-20, each followed by a single space.
pre_text = ' '.join(list[:20]) + ' '
target_text = list[20][0]            # the cloze sentence with '_' as the blank
choices = list[20][3].split("|")     # candidate answers, '|'-separated
print(pre_text,'\n',target_text,'\n',choices,'\n')
# Load pre-trained model with masked language model head.
bert_version = 'bert-large-uncased'
model = BertForMaskedLM.from_pretrained(bert_version)

# Text to score: 20 context sentences followed by the cloze sentence.
text = pre_text + target_text

# Crude guard against over-long inputs: if the raw text is too long, keep only
# sentences 11-20 as context.
# NOTE(review): this counts characters, not BERT tokens — the 2000 threshold
# is a heuristic for the model's sequence limit; confirm against real data.
if len(text) > 2000:
    pre_text = ' '.join(list[10:20]) + ' '
    text = pre_text + target_text
    print('After decreasing the sentences... ''\n')

tokenizer = BertTokenizer.from_pretrained(bert_version)
tokenized_text = tokenizer.tokenize(text)

# Replace every '_' placeholder with BERT's [MASK] token, recording positions.
mask_positions = []
for idx, token in enumerate(tokenized_text):
    if token == '_':
        tokenized_text[idx] = '[MASK]'
        mask_positions.append(idx)
# Predict the masked words from left to right, feeding each prediction back
# into the token sequence before predicting the next mask.
model.eval()  # inference mode: disables dropout, so outputs are deterministic
predicted_token = ''
for mask_pos in mask_positions:
    # Re-encode the (possibly updated) token sequence for this mask.
    token_ids = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([token_ids])
    # Score the vocabulary at this position. A RuntimeError here means the
    # input exceeded BERT's maximum sequence length.
    # (The original re-ran the model in an else-branch; that duplicate forward
    # pass is removed — the result is identical under eval().)
    try:
        predictions = model(tokens_tensor)[0, mask_pos]
    except RuntimeError:
        sys.exit('Oops! Sorry for your input 1-20 articles are too long. Try to decrease your sentences.')
    predicted_index = torch.argmax(predictions).item()
    predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
    print('predicted_token:''\n',predicted_token)
    # Substitute the prediction so later masks see it as context.
    tokenized_text[mask_pos] = predicted_token

# Mark every filled-in position as _token_ for display.
for mask_pos in mask_positions:
    tokenized_text[mask_pos] = "_" + tokenized_text[mask_pos] + "_"

# Re-join word pieces and strip the context so only the cloze sentence remains.
result = ' '.join(tokenized_text).replace(' ##', '').replace(pre_text,'')
# NOTE(review): 'mask_pos' here is the LAST position from the loop above; with
# one blank per question this is fine, but with several blanks every '_' gets
# replaced by the final prediction — confirm inputs contain a single blank.
result21 = target_text.replace('_',tokenized_text[mask_pos])
print('After predicting: ''\n', result21)
# --- Map BERT's predicted token onto one of the candidate answers ------------
answer_number = -1
if predicted_token in choices:
    # Exact match: the prediction is one of the candidates.
    # (Hoisted out of the original range(10) loop — the test was loop-invariant.)
    answer_number = choices.index(predicted_token)
else:
    # Fallback: pick the candidate most similar to the prediction via word2vec.
    # 'model' was rebound to BERT above, so the embedding model must be
    # (re)loaded here — this download/load is expensive.
    info = api.info()
    model = api.load("word2vec-google-news-300")
    similarity_list = []
    # Generalized from the original hard-coded range(10): iterate the actual
    # choices, so fewer than 10 candidates no longer raises IndexError.
    for choice in choices:
        try:
            # .item(): convert the numpy scalar to a plain float.
            # (The original computed the similarity twice in a try/else pair;
            # the duplicate call is removed.)
            similarity = model.similarity(predicted_token, choice).item()
        except KeyError:
            # Out-of-vocabulary word: assign a tiny score so it never wins.
            similarity = 0.0000000001
        similarity_list.append(similarity)
    print('similarity_list: ''\n',similarity_list)
    most_similar = max(similarity_list)
    most_similar_index = similarity_list.index(most_similar)
    print('Most similar answer is: ''\n',most_similar_index+1,' ',choices[most_similar_index],' ',most_similar)
    answer_number = most_similar_index
    predicted_token = choices[most_similar_index]

# Report the 1-based answer number and the chosen token, plus elapsed time.
print('The Answer is: ''\n', answer_number+1,' ', predicted_token)
print("--- %s seconds ---" % (time.time() - start_time))