-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
74 lines (59 loc) · 1.97 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
'''
Training code
Written by:
Simo Ryu
'''
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
import torch.optim as optim
from eq_dataset import EquationsDataset
from models import SOP
def train():
    """Train the SOP sequence-to-sequence model on the equations dataset.

    Runs `epochs` passes over EquationsDataset with AdamW + cross-entropy,
    shows a running mean loss in the progress bar, prints one sample
    generation after each epoch, and finally pickles the whole model to
    "model.dat".

    Side effects: writes "model.dat" in the working directory; requires a
    CUDA device ("cuda:0").
    """
    device = torch.device("cuda:0")

    # Hyperparameters.
    epochs = 1
    batch_size = 64
    lr = 1e-4
    max_len = 128

    #chars = list("0987654321-+*()^xyz") (case with three variables)
    chars = list("0987654321-+*()^xy")
    # +2 reserves ids for two special tokens; id n_vocab - 1 is used below
    # as the decoder start token -- confirm the other against models.SOP.
    n_vocab = len(chars) + 2

    model = SOP(
        d_model=512,
        n_head=8,
        num_layers=6,
        n_vocab=n_vocab,
        max_len=max_len,
        chars=chars,
        device=device,
    )
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-10)

    dataset = EquationsDataset(max_len=max_len, chars=chars)
    # drop_last=True so the hard-coded batch_size in the start-token
    # concatenation below is always correct.
    dl = DataLoader(dataset, shuffle=True, batch_size=batch_size,
                    drop_last=True, num_workers=3)
    criterion = nn.CrossEntropyLoss()
    model.to(device)

    for epoch in range(1, epochs + 1):
        pbar = tqdm(dl)
        tot_loss = 0.0
        cnt = 0
        for (x, yin, yout) in pbar:
            x = x.to(device)
            # Prepend the start-of-sequence token (id n_vocab - 1) to the
            # decoder input, then cast the whole sequence to long.
            yin = torch.cat(
                [torch.ones(batch_size, 1) * (n_vocab - 1), yin], dim=1
            ).long()
            yin = yin.to(device)
            yout = yout.to(device)

            y_pred = model(x, yin)
            # Logits cover n_vocab - 1 classes -- presumably the start
            # token is excluded from valid outputs; confirm in models.SOP.
            loss = criterion(y_pred.view(-1, n_vocab - 1), yout.view(-1))

            # opt holds every model parameter, so opt.zero_grad() is
            # equivalent to (and more conventional than) model.zero_grad().
            opt.zero_grad()
            loss.backward()
            opt.step()

            tot_loss += loss.item()
            cnt += 1
            pbar.set_description(f"current loss : {tot_loss/cnt:.5f}")

        # Sanity-check generation on a fixed example; one valid answer is
        # (1-y^2)^2+(-y^2+y)^2.  (The unused `ans` variable was removed.)
        eq = "2*y^4-2*y^3-y^2+1"
        ral = model.toSOP(eq, gen_len=max_len - 1)
        print(f'Epoch {epoch} : Loss : {tot_loss/cnt :.5f}, Example : {ral[0]}')

    # NOTE(review): this pickles the entire module, which is fragile across
    # code refactors; model.state_dict() would be more robust, but the
    # format is kept for compatibility with existing loaders.
    torch.save(model, "model.dat")
# Entry point: run training only when executed as a script (not on import).
if __name__ == "__main__":
    train()