# mnist_ss_vae.py
# Semi-supervised VAE on MNIST: training, reconstruction, and latent-space visualisation.
import argparse
import os
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable  # no-op wrapper in PyTorch >= 0.4; kept as in the original code
import modules.ss_vae_conv as ss_vae
import visualize as vis
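# NOTE: this script assumes the following interface for ss_vae.SS_VAE, inferred
# from the calls below (the actual definition in modules/ss_vae_conv.py is the
# authority):
#   vae(x)                            -> (reconstruction, z_params, pi)
#   vae.loss(x, y, out, z_params, pi) -> scalar training loss
#   vae.encoder(x)                    -> (z_params, pi)
#   vae.reparam_z(z_params)           -> sampled latent z (presumably the standard
#                                        reparameterisation z = mu + sigma * eps,
#                                        eps ~ N(0, I))
#   vae.sample(z, labels)             -> decoded images for (z, one-hot labels)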
# Parameters
data_dir = 'data/MNIST'
learning_rate = 0.001
image_size = torch.Size([1, 28, 28])
im_sz = int(np.prod(image_size))
num_classes = 10
# Helpers
def icdf(v):
    ''' inverse CDF (quantile function) of the standard normal: Phi^{-1}(v) '''
    return torch.erfinv(2 * torch.Tensor([float(v)]) - 1) * np.sqrt(2)
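# e.g. icdf(0.975) ~= tensor([1.9600]): the familiar 97.5% standard-normal quantile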
def horzArr(a):
    ''' reshapes a (num_images) x 1 x (a) x (b) array into an (a) x (b * num_images)
        array, i.e. lays the images out side by side in a single row '''
    a = np.squeeze(a)               # drop the singleton channel dimension
    a = np.split(a, np.size(a, 0))  # separate the individual images
    a = np.concatenate(a, 2)        # stack them together horizontally
    a = np.squeeze(a)               # drop the leftover singleton dimension
    return a
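# e.g. a batch of three 1x28x28 MNIST digits (shape 3x1x28x28) becomes a
# single 28x84 array with the digits laid out side by side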
def train(vae, data_loader, fixed_x, fixed_y):
    optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)
    iter_per_epoch = len(data_loader)
    L_vec = []  # store losses for plotting later
    for epoch in range(args.epochs):
        for batch_idx, (images, labels) in enumerate(data_loader):
            images = Variable(images.view(-1, im_sz).to(DEVICE))
            labels = Variable(labels.to(DEVICE))
            out, z_params, pi = vae(images)  # forward pass
            L = vae.loss(images, labels, out, z_params, pi)  # compute loss
            optimizer.zero_grad()
            L.backward()  # backward pass
            optimizer.step()  # parameter update
            L_vec.append(L.item())
            if batch_idx % 100 == 0:
                print("Epoch [%d/%d], Step [%d/%d], Total Loss: %.4f"
                      % (epoch, args.epochs - 1, batch_idx, iter_per_epoch, L.item()))
        # visualise progress at the end of each epoch:
        reconst_images, _, _ = vae(fixed_x)  # another forward pass on fixed inputs
        reconst_images = reconst_images.view(reconst_images.size(0), *image_size)  # reshape
        torchvision.utils.save_image(reconst_images.data.cpu(),
                                     os.path.join(args.res, 'reconst_images_%d.png' % epoch))
    torch.save(vae.state_dict(), args.save)  # save model to disk
    plt.plot(L_vec)
    plt.savefig(os.path.join(args.res, 'loss.png'))
    plt.show()
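# Note: the exact form of vae.loss lives in modules/ss_vae_conv.py. For a
# semi-supervised VAE in the style of Kingma et al. (2014) it would combine a
# reconstruction term, a KL penalty on z_params, and a classification term on
# pi -- an educated guess from the call signature above, not a statement about
# the actual module.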
## TRAIN
def main():
    z_sz = args.z_sz
    # Load Data
    dataset = torchvision.datasets.MNIST(root=data_dir,
                                         train=True,
                                         transform=transforms.ToTensor(),
                                         download=True)
    # Data loader
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=args.batch_sz,
                                              shuffle=True)
    # Fixed batch of images and labels for visualisation
    data_iter = iter(data_loader)
    fixed_x_save, fixed_y_save = next(data_iter)
    fixed_x = fixed_x_save.view(fixed_x_save.size(0), im_sz)
    fixed_x = Variable(fixed_x.to(DEVICE))
    fixed_y = Variable(fixed_y_save.to(DEVICE))
    torchvision.utils.save_image(fixed_x_save, os.path.join(args.res, 'real_images.png'))
    vae = ss_vae.SS_VAE(img_size=image_size, device=DEVICE, z_sz=z_sz, batch_size=args.batch_sz)
    if args.load is None:
        train(vae, data_loader, fixed_x, fixed_y)
    else:
        vae.load_state_dict(torch.load(args.load, map_location=lambda storage, loc: storage))
    # Save reconstructed image
    reconst_images, _, _ = vae(fixed_x)
    reconst_images = reconst_images.view(-1, *image_size)
    torchvision.utils.save_image(reconst_images.data.cpu(), os.path.join(args.res, 'reconst_images.png'))
    # we want to visualise what happens when we take a batch example, which has
    # label y, and sample from the SS-VAE with each label y' in [0, 9].
    # l is a vector [y, 0, 1, 2, ..., 9]
    l = np.arange(-1, num_classes)
    l[0] = fixed_y_save[0].item()
    b = np.eye(num_classes)[l]  # 11 x 10 matrix of one-hot vectors built from l
    labels = torch.from_numpy(b).float().to(DEVICE)
    labels = Variable(labels)
    # keep res on DEVICE so rows of `out` can be assigned into it directly
    res = torch.zeros(num_classes + 1, num_classes + 1, im_sz, device=DEVICE)
    for k in range(num_classes + 1):
        x = fixed_x[k].view(1, -1)         # take a batch example
        x = x.expand(num_classes + 1, -1)  # duplicate along rows
        z_params, pi = vae.encoder(x)
        z = vae.reparam_z(z_params)  # sample a latent z for x
        out = vae.sample(z, labels)  # fix z, and vary labels
        out[0, :] = fixed_x[k, :]    # leftmost image in each grid row is the original input
        res[k] = out
    res = res.view(-1, *image_size)
    torchvision.utils.save_image(res.data.cpu(), os.path.join(args.res, 'cvae.png'), nrow=num_classes + 1)
    if args.walk:
        # set up the visualizer -- see how changing z alters the images over all labels y.
        # vis.Visualizer is expected to call f with a length-z_sz numpy vector and
        # render the 2-D array f returns (inferred from how f is built; see visualize.py).
        f = lambda z: horzArr(vae.sample(
            # to get z: convert to torch, replicate 11 times
            Variable(torch.from_numpy(z).float().to(DEVICE)).view(1, -1).expand(num_classes + 1, -1),
            # format as 11 28x28 images arranged horizontally
            labels).view(-1, *image_size).data.cpu().numpy())  # .cpu() so this also works on CUDA
        v = vis.Visualizer(f, z_sz=z_sz)
        v.visualize()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-load', help='path of model to load')
    parser.add_argument('-walk', action='store_true', help='displays GUI to walk embedding space')
    parser.add_argument('-save', help='path of model to save')
    parser.add_argument('-res', help='path to save figures')
    parser.add_argument("-batch_sz", type=int, help="how many in batch", default=100)
    parser.add_argument("-z_sz", type=int, help="latent size", default=20)
    parser.add_argument("-epochs", type=int, help="how many epochs", default=10)
    parser.add_argument("-device", help="specify device to run on", default="cpu")
    args = parser.parse_args()
    DEVICE = torch.device('cpu')
    if args.device == 'cuda' and torch.cuda.is_available():
        DEVICE = torch.device('cuda')
    main()
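# Example invocations (paths are placeholders; the -res directory must exist):
#   train from scratch: python mnist_ss_vae.py -save ckpt.pt -res results
#   load and explore:   python mnist_ss_vae.py -load ckpt.pt -res results -walk -device cuda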