/
main.py
102 lines (84 loc) · 4.03 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from tqdm import tqdm
import torch
import torch.utils
import torch.distributions
from torch.utils.data import DataLoader
import torch.nn.functional as F
from dataset import PiDataset
from model import VariationalAutoencoder
from utils import *
class PiGenerator:
    """Train a VAE on pi-digit data and generate images from fixed latent noise.

    The same latent sample ``self.z`` is reused for every generation call so
    images produced at different epochs are directly comparable.
    """

    def __init__(self, model, latent_dim, epochs, result_path, batch_size, seq_len):
        # NOTE: no base class, so the former super().__init__() call was a
        # no-op ritual (copied from nn.Module-style code) and has been removed.
        self.model = model
        self.latent_dim = latent_dim
        self.epochs = epochs
        self.result_path = result_path
        # Fixed latent noise, sampled once. `device` is expected to come from
        # `from utils import *` — TODO confirm it is defined there.
        self.z = torch.randn([batch_size, seq_len, latent_dim]).float().to(device)
        num_param = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f'#params: {num_param}')

    def loss_function(self, x, x_hat, mean, log_var):
        """Return the two VAE loss terms for one batch.

        Args:
            x: target tensor.
            x_hat: reconstruction produced by the decoder.
            mean: latent Gaussian means.
            log_var: latent Gaussian log-variances.

        Returns:
            (recons_loss, kld_loss): mean-squared reconstruction error and
            the (summed, unnormalized) KL divergence to the unit Gaussian.
        """
        recons_loss = F.mse_loss(x_hat, x)
        kld_loss = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
        return recons_loss, kld_loss

    def train(self, data, gen_every_epochs):
        """Train the VAE, annealing the loss from pure reconstruction toward KLD.

        Args:
            data: iterable of batches (e.g. a DataLoader).
            gen_every_epochs: if truthy, emit a generated image after each epoch.

        Returns:
            (total_loss, total_kl, total_rc): per-batch histories of the
            weighted loss and its two components.
        """
        opt = torch.optim.Adam(self.model.parameters())
        total_loss, total_kl, total_rc, beta = [], [], [], 0
        pbar = tqdm(range(self.epochs), desc='Epoch: ')
        for epoch in pbar:
            # Reset to train mode once per epoch; generator() below switches
            # the model to eval mode at the end of each epoch.
            self.model.train()
            for _, x in enumerate(data):
                x = x.float().to(device)
                opt.zero_grad()
                x_hat, mean, log_var = self.model(x)  # x is already float
                recons_loss, kld_loss = self.loss_function(x.squeeze(), x_hat.squeeze(), mean, log_var)
                # beta-annealing: epoch 0 trains reconstruction only, the KLD
                # weight then grows linearly toward 1 over the run.
                loss = (1 - beta) * recons_loss + beta * kld_loss
                loss.backward()
                opt.step()
                cur_total_loss = loss.item()
                cur_kl_loss = kld_loss.item()
                cur_rc_loss = recons_loss.item()
                total_loss.append(cur_total_loss)
                total_kl.append(cur_kl_loss)
                total_rc.append(cur_rc_loss)
                pbar.set_description('Total Loss: {}| KLD Loss: {}| RC Loss: {}'.format(round(cur_total_loss, 2), round(cur_kl_loss, 2), round(cur_rc_loss, 2)), refresh=True)
            if gen_every_epochs:
                self.generator(epoch)
            beta += 1 / self.epochs
        return total_loss, total_kl, total_rc

    def generator(self, cur_epoch):
        """Decode the fixed latent sample and write the resulting image.

        Args:
            cur_epoch: epoch index used to label the saved image.
        """
        self.model.eval()
        with torch.no_grad():
            outputs = self.model.decoder(self.z).cpu().numpy()
        # Channels 0-1 are xy coordinates, 2-4 are RGB — rescaled by utils helpers.
        x_coor, y_coor = xy_rescaling(xy_coor=outputs[:, 0:2, :])
        r_value, g_value, b_value = rgb_rescaling(rgb_values=outputs[:, 2:5, :])
        generate_img(x_coor, y_coor, r_value, g_value, b_value, self.result_path, cur_epoch)
if __name__ == '__main__':
    latent_dim, epochs, batch_size, device, result_path, gen_every_epochs, num_workers, retrain, num_head = get_config()

    # Load the data and wrap it in a loader (order preserved: shuffle=False).
    data = PiDataset()
    pi_dataloader = DataLoader(dataset=data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Build the VAE and the training/generation driver.
    model = VariationalAutoencoder(input_dim=1, latent_dim=latent_dim, num_head=num_head).to(device)
    print(f'model: {model}')
    pi_generator = PiGenerator(model=model, latent_dim=latent_dim, epochs=epochs,
                               result_path=result_path, batch_size=batch_size,
                               seq_len=data.get_seq_len())

    if retrain:
        # Train the VAE and plot its loss curves.
        total_loss, total_kl, total_rc = pi_generator.train(data=pi_dataloader, gen_every_epochs=gen_every_epochs)
        draw_loss_curve(total_num_epoch=epochs, total_loss=total_loss, total_kl=total_kl,
                        total_rc=total_rc, result_path=result_path)
        # Save the model weights for later generation-only runs.
        torch.save(model.state_dict(), './vae.pth')
    else:
        # map_location makes a checkpoint saved on CUDA loadable on a
        # CPU-only host (plain torch.load would raise in that case).
        model.load_state_dict(torch.load('./vae.pth', map_location=device))

    # Generate the final image from the fixed latent sample.
    pi_generator.generator(cur_epoch=epochs)