-
Notifications
You must be signed in to change notification settings - Fork 15
/
training.py
65 lines (51 loc) · 2.13 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import numpy as np
import editdistance
import matplotlib.pyplot as plt
import tqdm
def train(model, optimizer, train_loader, state):
    """Run one pass over ``train_loader``, updating ``model`` with ``optimizer``.

    Args:
        model: Object exposing ``train()``, a ``training`` attribute and
            ``loss(batch)``. ``loss`` must return a 4-tuple whose first item
            is the loss tensor (it is used via ``.item()`` and
            ``.backward()``); the other three items are ignored here.
        optimizer: Object exposing ``zero_grad()`` and ``step()``.
        train_loader: Iterable of batches; each batch is passed unchanged to
            ``model.loss``.
        state: Iterable of exactly three items
            ``(epoch, n_epochs, train_steps)``. ``epoch`` and ``n_epochs``
            are only used for the progress-bar label; ``train_steps`` is
            unpacked but not used.

    Returns:
        The tuple ``(model, optimizer)``.
    """
    epoch, n_epochs, _train_steps = state
    losses = []

    model.train()
    progress = tqdm.tqdm(train_loader)
    # The label does not change between batches, so set it once up front.
    progress.set_description(
        "Epoch {:.0f}/{:.0f} (train={})".format(epoch, n_epochs, model.training)
    )

    # Iterating over the tqdm wrapper advances the bar by itself; an extra
    # manual update() per batch would count every batch twice.
    for batch in progress:
        loss, _, _, _ = model.loss(batch)
        loss_value = loss.item()
        losses.append(loss_value)

        # Reset gradients, back-propagate, then apply the update.
        optimizer.zero_grad()
        loss.backward()
        # Optional gradient clipping (disabled):
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2)
        optimizer.step()

        progress.set_postfix(
            loss='{:05.3f}'.format(loss_value),
            avg_loss='{:05.3f}'.format(np.mean(losses)),
        )
    return model, optimizer
def evaluate(model, eval_loader):
    """Evaluate ``model`` on ``eval_loader`` and print the averaged metrics.

    The model is switched to evaluation mode and every batch is processed
    under ``torch.no_grad()``. For each batch the metric is
    ``100 * editdistance.eval(argmax(logits, -1), labels) / len(logits)``.

    NOTE(review): the metric is shown under the label "acc", but it is an
    edit-distance rate (0 when predictions equal the labels), so lower is
    better -- confirm the intended naming before relying on it.

    Args:
        model: Object exposing ``eval()``, a ``training`` attribute and
            ``loss(batch)`` returning ``(loss, logits, labels, alignments)``
            where the first three are tensors.
        eval_loader: Iterable of batches; each batch is passed unchanged to
            ``model.loss``.

    Returns:
        None. The mean loss and mean metric are printed to stdout.
    """
    losses = []
    accs = []

    model.eval()
    progress = tqdm.tqdm(eval_loader)
    # The label does not change between batches, so set it once up front.
    progress.set_description(" Evaluating... (train={})".format(model.training))

    with torch.no_grad():
        # Iterating over the tqdm wrapper advances the bar by itself; an
        # extra manual update() per batch would count every batch twice.
        for batch in progress:
            loss, logits, labels, alignments = model.loss(batch)
            preds = logits.detach().cpu().numpy()
            targets = labels.detach().cpu().numpy()

            # Edit distance between the arg-max predictions and the labels,
            # scaled by the length of the first axis of the logits.
            acc = 100 * editdistance.eval(np.argmax(preds, -1), targets) / len(preds)
            losses.append(loss.item())
            accs.append(acc)

            progress.set_postfix(
                avg_acc='{:05.3f}'.format(np.mean(accs)),
                avg_loss='{:05.3f}'.format(np.mean(losses)),
            )

            # To visualise the alignment weights of a batch, uncomment:
            # align = alignments.detach().cpu().numpy()[:, :, 0]
            # fig, ax = plt.subplots(1, 1)
            # ax.pcolormesh(align)
            # fig.savefig("data/att.png")

    print(" End of evaluation : loss {:05.3f} , acc {:03.1f}".format(np.mean(losses), np.mean(accs)))
# Script entry point.
# NOTE(review): train() as defined in this file takes four arguments
# (model, optimizer, train_loader, state), so this two-argument call would
# raise TypeError if executed -- TODO confirm the intended invocation.
if __name__ == '__main__':
    train(1, 50)