-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
99 lines (81 loc) · 4.69 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import pdb
import json
import pickle
import tensorflow as tf
import numpy as np
from model import BIDAF
from preprocess import Squad_Dataset
from config import get_args
from evaluator import *
class Trainer():
    """Run training epochs and per-epoch dev evaluation for a model in a TF session.

    The trainer reads ``epochs``, ``print_step``, ``save_dir`` and ``mode`` from
    ``config``; ``tr_zip_list``, ``dev_zip_list``, ``idx2word`` and ``dev_file``
    from ``data``; and the ``loss``, ``train_opt``, ``global_step``, ``prob1``,
    ``prob2``, ``arg_p1``, ``arg_p2`` fetches plus the six input placeholders
    from ``model``.
    """

    def __init__(self, config, data, model, saver, sess, exp_name, writer):
        """Store the collaborators and cache the model's training fetches.

        Args:
            config: Run configuration (see class docstring for attributes read).
            data: Dataset wrapper providing the train/dev batches and vocabulary.
            model: Built model exposing placeholders and fetchable tensors/ops.
            saver: Object whose ``save(sess, path)`` writes a checkpoint.
            sess: Session used for every ``run`` call.
            exp_name: Experiment name; used as the checkpoint file name.
            writer: Summary writer receiving the per-step training loss.
        """
        self.config = config
        self.data = data
        self.model = model
        self.sess = sess
        self.saver = saver
        self.exp_name = exp_name
        self.writer = writer
        self.loss = self.model.loss
        self.train_opt = self.model.train_opt
        self.global_step = self.model.global_step

    def train(self, ):
        """Train for ``config.epochs`` epochs, evaluating on the dev set after each.

        After every epoch a checkpoint is saved, the dev set is decoded into
        answer strings, and both the predictions and the scores returned by the
        module-level ``evaluate`` function are written as JSON under ``./out``.
        """
        # Local import: only this method needs it. Creating the directory up
        # front avoids losing an epoch of work to a missing output folder.
        import os
        os.makedirs('./out', exist_ok=True)

        for epoch in range(self.config.epochs):
            print(" -------------------- Epoch %d is ongoing -------------------- \n" % (epoch))
            for train_idx, batch in enumerate(self.data.tr_zip_list):
                ques, cont_mat, ques_char, cont_char, ans_start, ans_stop, _ = batch
                loss, _, _ = self.train_step(ques, cont_mat, ques_char, cont_char, ans_start, ans_stop)
                if (train_idx+1) % self.config.print_step == 0:
                    print('Epoch %d, train_step %d: loss %.4f \n' % (epoch, train_idx, loss))

            # NOTE(review): the checkpoint is written once per epoch, after the
            # training pass — confirm this matches the intended cadence.
            self.saver.save(self.sess, "%s/%s" % (self.config.save_dir, self.exp_name))
            print('Successfully saved model\n')

            pred_dict = self._predict_dev(epoch)

            # This is the module-level ``evaluate`` (not the method below);
            # presumably supplied by the ``evaluator`` star import.
            results = evaluate(self.data.dev_file, pred_dict)
            with open('./out/predictions_%d.json' % (epoch), 'w', encoding='utf-8') as fp:
                json.dump(pred_dict, fp, indent=2)
            with open('./out/results_%d.json' % (epoch), 'w', encoding='utf-8') as fp:
                json.dump(results, fp)
        # Test-mode inference is not implemented.

    def _predict_dev(self, epoch):
        """Decode every dev batch and return a ``{qa_id: answer_string}`` dict."""
        pred_dict = dict()
        for dev_idx, batch in enumerate(self.data.dev_zip_list):
            ques, cont_mat, ques_char, cont_char, ans_start, ans_stop, qa_id = batch
            _, _, arg_p1, arg_p2 = self.evaluate(ques, cont_mat, ques_char, cont_char, ans_start, ans_stop)
            starts = arg_p1.tolist()
            stops = arg_p2.tolist()
            for row, (start, stop) in enumerate(zip(starts, stops)):
                pred_dict[qa_id[row]] = self._decode_answer(cont_mat[row], start, stop)
            if (dev_idx+1) % self.config.print_step == 0:
                print('Epoch %d, dev_step %d \n' % (epoch, dev_idx))
        return pred_dict

    def _decode_answer(self, context_ids, start, stop):
        """Join the words for ``context_ids[start:stop + 1]`` with single spaces.

        An inverted span (``stop < start``) yields an empty string.
        """
        return ' '.join([self.data.idx2word[idx] for idx in context_ids[start:stop+1]])

    def evaluate(self, ques, cont_mat, ques_char, cont_char, ans_start, ans_stop):
        """Run one forward pass on a batch without updating the model.

        The training op is deliberately NOT fetched here: fetching it would
        apply a gradient step on evaluation data.

        Returns:
            Tuple ``(loss, global_step, arg_p1, arg_p2)`` as returned by the
            session for the corresponding model fetches.
        """
        feed_dict = self.create_feed_dict(ques, cont_mat, ques_char, cont_char, ans_start, ans_stop)
        loss, global_step, arg_p1, arg_p2 = self.sess.run(
            [self.loss, self.global_step, self.model.arg_p1, self.model.arg_p2],
            feed_dict=feed_dict)
        return loss, global_step, arg_p1, arg_p2

    def train_step(self, ques, cont_mat, ques_char, cont_char, ans_start, ans_stop):
        """Run one optimisation step on a batch and log its loss as a summary.

        Returns:
            Tuple ``(loss, prob1, prob2)`` from the model's ``loss``, ``prob1``
            and ``prob2`` fetches.
        """
        feed_dict = self.create_feed_dict(ques, cont_mat, ques_char, cont_char, ans_start, ans_stop)
        fetches = [self.loss, self.train_opt, self.global_step, self.model.prob1, self.model.prob2]
        loss, _, global_step, prob1, prob2 = self.sess.run(fetches, feed_dict=feed_dict)
        summary = tf.Summary(value=[tf.Summary.Value(tag="loss", simple_value=loss)])
        self.writer.add_summary(summary, global_step)
        return loss, prob1, prob2

    def create_feed_dict(self, ques, cont_mat, ques_char, cont_char, ans_start, ans_stop):
        """Map the batch arrays onto the model's input placeholders.

        Raises:
            NotImplementedError: If ``config.mode`` is ``'test'`` (no feed is
                defined for it).
            ValueError: If ``config.mode`` is anything other than ``'train'``,
                ``'dev'`` or ``'test'``.
        """
        mode = self.config.mode
        if mode == 'test':
            raise NotImplementedError("feed dict for mode 'test' is not implemented")
        if mode not in ('train', 'dev'):
            raise ValueError("unsupported mode: %r" % (mode,))
        # 'train' and 'dev' feed exactly the same placeholders.
        return {
            self.model.ques_word: ques,
            self.model.cont_word: cont_mat,
            self.model.ques_char: ques_char,
            self.model.cont_char: cont_char,
            self.model.answer_start: ans_start,
            self.model.answer_stop: ans_stop,
        }