/
README.Rmd
356 lines (285 loc) · 14.9 KB
/
README.Rmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
---
title: "Implementing the Two-Stage Variational Autoencoder for Spatiotemporal Data"
author: "Samuel I. Berchuck"
date: "`r Sys.Date()`"
output: github_document
---
*This document provides supplementary code for the paper "Scalable Modeling of Spatiotemporal Data using the Variational Autoencoder: an Application in Glaucoma",* https://arxiv.org/abs/1908.09195.
*Also, see a clinical implementation of the variational autoencoder on bioRxiv by Berchuck et al (2019): "Estimating Rates of Progression and Predicting Future Visual Fields in Glaucoma Using a Deep Variational Autoencoder",* https://www.biorxiv.org/content/10.1101/652487v1.abstract.
## Loading Packages
We begin by loading the `keras` package and assigning the backend functions.
```{r, echo = FALSE}
###Start with a clean space
rm(list = ls())
```
```{r}
library(keras)
K <- keras::backend()
```
We can't include our visual field data from the manuscript, so we will use the MNIST digits data that comes with `keras`. We begin by loading the data and then creating training, validation, and test datasets.
```{r}
###Load MNIST data from keras
mnist <- dataset_mnist()
###Assign training/validation and test datasets + scale to (0, 1)
train_val_raw <- mnist$train$x / 255
test_raw <- mnist$test$x / 255
###Separate training and validation
train_raw <- train_val_raw[1:50000, , ]
val_raw <- train_val_raw[50001:60000, , ]
```
Before defining the variational autoencoder, we resize the MNIST images so that they are `12 x 12` dimensional, just like our visual fields. The resizing is performed through a simple spline interpolation based on the method from https://stackoverflow.com/questions/35786744/resizing-image-in-r. The function to resize is as follows.
```{r}
resize <- function(img, new_width = 12, new_height = 12) {
new_img <- apply(img, 2, function(y) {return(spline(y, n = new_height)$y)})
new_img <- t(apply(new_img, 1, function(y) {return(spline(y, n = new_width)$y)}))
new_img[new_img < 0] <- 0
return(new_img)
}
```
Now, we apply the function to the MNIST digits.
```{r}
###Data declarations
train_resized <- array(dim = c(dim(train_raw)[1], 12, 12))
val_resized <- array(dim = c(dim(val_raw)[1], 12, 12))
test_resized <- array(dim = c(dim(test_raw)[1], 12, 12))
###Resize the images
for (s in 1:dim(train_raw)[1]) train_resized[s, , ] <- resize(train_raw[s, , ])
for (s in 1:dim(val_raw)[1]) val_resized[s, , ] <- resize(val_raw[s, , ])
for (s in 1:dim(test_raw)[1]) test_resized[s, , ] <- resize(test_raw[s, , ])
###Resizing changed scale, so rescale
train_resized_scale <- train_resized / max(train_resized)
val_resized_scale <- val_resized / max(val_resized)
test_resized_scale <- test_resized / max(test_resized)
```
Finally, make the data have four dimensions as `keras` requires (sample size, height, width, depth).
```{r}
###Create data objects for keras model
x_train <- array(dim = c(dim(train_resized_scale)[1], 12, 12, 1))
x_val <- array(dim = c(dim(val_resized_scale)[1], 12, 12, 1))
x_test <- array(dim = c(dim(test_resized_scale)[1], 12, 12, 1))
###Fill keras data objects
x_train[ , , , 1] <- train_resized_scale
x_val[ , , , 1] <- val_resized_scale
x_test[ , , , 1] <- test_resized_scale
```
We obviously do not suggest down-grading the quality of an image, but only are doing this so that the model architecture can be the same as in our analysis for the `12 x 12` dimensional visual fields. We can verify that the resizing was done properly, by visualizing an example image before and after resizing. Clearly, the new images are of poor quality, but they serve the purpose of our example.
```{r, fig.align="center", fig.width = 5.5, fig.height = 3}
par(mfcol = c(1, 2))
plot(as.raster(train_raw[1, , ]))
plot(as.raster(train_resized_scale[1, , ]))
```
## Model Architecture
We begin by setting the model parameters that will be used throughout the model architecture. In `keras` the arguments must be integers. In this code, we use a latent space with two dimension, which was chosen for visualization purposes only. We also set the seed using the `keras` function `use_session_with_seed`.
```{r}
###Model parameters
batch_size <- 100L
original_dim <- 12L
latent_dim <- 2L
intermediate_dim <- 32L
final_dim <- 64L
epochs <- 50L
strides <- 2L
ks <- 3L
###Set keras seed
keras::use_session_with_seed(54)
```
We then start model building be specifying the layers of the encoder. The encoder of the variational autoencoder is comprised of two 2D convolutional layers, a reshaping layer and a fully-connected dense layer. The result of the encoder is transforming the original `12 x 12` dimensional images into an `8` dimensional latent vector.
```{r}
x <- layer_input(shape = c(original_dim, original_dim, 1))
C1 <- layer_conv_2d(x, filters = intermediate_dim, kernel_size = ks, strides = strides, activation = "relu")
C2 <- layer_conv_2d(C1, filters = final_dim, kernel_size = ks, strides = strides, activation = "relu")
R1 <- layer_flatten(C2)
F1 <- layer_dense(R1, units = latent_dim)
z <- F1 # define the latent vector space
###Define the keras sequential model
encoder <- keras_model(x, z)
```
Unlike the vanilla variational autoencoder we do not require a transformation step to sample the latent space. As such, we proceed to defining the decoder. The decoder begins with a fully-connected dense layer, followed by a reshaping layer, and then a sequence of deconvolutions, which up-sample until the original `12 x 12` dimension is reached. All layers use a `3 x 3` kernel size and a stride of `2`, and the activation is a rectified non-linear unit transformation, except for F1 and D3, which used the identity and sigmoid activations, respectively.
```{r}
###Instantiate keras layer for decoder to be used later
F2 <- layer_dense(units = ks * ks * intermediate_dim, activation = "relu")
R2 <- layer_reshape(target_shape = c(ks, ks, intermediate_dim))
D1 <- layer_conv_2d_transpose(filters = final_dim, kernel_size = ks, strides = strides, padding = "same", activation = "relu")
D2 <- layer_conv_2d_transpose(filters = intermediate_dim, kernel_size = ks, strides = strides, padding = "same", activation = "relu")
D3 <- layer_conv_2d_transpose(filters = 1, kernel_size = ks, strides = 1, padding = "same", activation = "sigmoid")
###Define sequential model for decoder
z_1 <- layer_input(shape = latent_dim)
f2_1 <- F2(z_1)
r2_1 <- R2(f2_1)
d1_1 <- D1(r2_1)
d2_1 <- D2(d1_1)
x_pred_decoder <- D3(d2_1)
decoder <- keras_model(z_1, x_pred_decoder)
```
Finally, we define a sequential model for the entire variational autoencoder.
```{r}
###Define sequential model for full variational autoencoder
f2 <- F2(z)
r2 <- R2(f2)
d1 <- D1(r2)
d2 <- D2(d1)
x_pred <- D3(d2)
###Full variational autoencoder
vae <- keras_model(x, x_pred)
print(vae)
```
## Define the Loss
The variational autoencoder in our paper uses the maximum mean discrepancy (MMD) in place of the Kulback-Liebler divergance, so we define custom loss functions using the `keras` backend functions. We begin by defining the MMD, which is used as the regularization loss.
```{r}
###Kernel function used in MMD calculation
compute_kernel <- function(x, y) {
x_size <- k_shape(x)[1]
y_size <- k_shape(y)[1]
dim <- k_shape(x)[2]
tiled_x <- k_cast(K$tile(k_reshape(x, k_stack(list(x_size, 1L, dim))), k_stack(list(1L, y_size, 1L))), "float32")
tiled_y <- k_cast(K$tile(k_reshape(y, k_stack(list(1L, y_size, dim))), k_stack(list(x_size, 1L, 1L))), "float32")
k_exp(-k_mean(k_square(tiled_x - tiled_y), axis = 3) / k_cast(dim, "float32"))
}
###Function to compute MMD
compute_mmd <- function(x, y, sigma_sqr = 1) {
x_kernel <- compute_kernel(x, x)
y_kernel <- compute_kernel(y, y)
xy_kernel <- compute_kernel(x, y)
k_mean(x_kernel) + k_mean(y_kernel) - k_cast(2L, "float32") * k_mean(xy_kernel)
}
```
Finally, we define our entire loss function, which combines the regularization loss using MMD and the reconstruction loss, which is a simple squared distance. We also, define a custom metric to be followed during model fit.
```{r}
###Complete loss function
vae_loss <- function(x, x_pred) {
dimz <- k_shape(z)[1]
true_samples <- k_random_normal(shape = c(dimz, latent_dim), dtype = "float32")
loss_reg <- compute_mmd(true_samples, z)
loss_rec <- k_mean(k_square(k_cast(x, "float32") - k_cast(x_pred, "float32")))
loss_reg + loss_rec
}
###Reconstruction loss to be plotted during fitting
metric_rec <- custom_metric("rec", function(x, x_pred) {
k_mean(k_square(k_cast(x, "float32") - k_cast(x_pred, "float32")))
})
```
## Fitting the Model
The model was trained using the Adam optimizer, an extension of stochastic gradient descent, using 50 epochs and a batch size of 100; we used a learning rate of 1e-4.
```{r}
###Optimization routine
compile(vae, optimizer = optimizer_adam(1e-4), loss = vae_loss, metrics = c(metric_rec))
```
Now we present code for training the model, where the training epoch with the minimal validation loss was chosen as optimal. The training is only performed if `RERUN` is set to `TRUE`, otherwise previously trained weights are loaded.
```{r}
###Model training
filepath <- "./Raw/vae.h5"
RERUN <- FALSE # change to FALSE to use saved model weights
if (RERUN) {
cp_callback <- callback_model_checkpoint(filepath = filepath, save_weights_only = FALSE, save_best_only = TRUE)
vae.fit <- fit(vae,
x_train, x_train,
shuffle = TRUE,
epochs = epochs,
batch_size = batch_size,
validation_data = list(x_val, x_val),
metrics = c("loss", "rec"),
callbacks = list(cp_callback),
view_metrics = TRUE)
save(vae.fit, file = "./Raw/fit.RData")
save_model_hdf5(encoder, filepath = "./Raw/encoder.h5")
save_model_hdf5(decoder, filepath = "./Raw/decoder.h5")
} else {
load("./Raw/fit.RData")
vae <- load_model_hdf5(filepath = "./Raw/vae.h5", custom_objects = c(vae_loss = vae_loss, rec = metric_rec))
encoder <- load_model_hdf5(filepath = "./Raw/encoder.h5", custom_objects = c(vae_loss = vae_loss, rec = metric_rec))
decoder <- load_model_hdf5(filepath = "./Raw/decoder.h5", custom_objects = c(vae_loss = vae_loss, rec = metric_rec))
}
```
Visualizing the overall and reconstruction loss is straightforward.
```{r, fig.align="center", fig.width = 5.5, fig.height = 3}
plot(vae.fit)
```
## Exploring the trained model
Once training has completed, we can use the `encoder`, `decoder`, and `vae` using the `predict` function. Here, we encode all of the images in the test dataset and then generate decoded images for each set of latent variables. We can see that the trained model is working adequetly, as a reconstructed image is shown to be similar to its original digit.
```{r, fig.align="center", fig.width = 5.5, fig.height = 3}
###Obtain latent space
latent <- predict(encoder, x_test)
###Generate an image
gen <- predict(decoder, latent)
###Plot an image and its decoded image
par(mfcol = c(1, 2))
plot(as.raster(gen[3, , , ]))
plot(as.raster(x_test[3, , , ]))
```
## Visualizing the latent space
Begin by loading necessary packages.
```{r}
library(reshape2)
library(ggplot2)
library(scales)
```
Since we are using two latent dimensions, we can visualize the latent space across digit type. This does not perform particularly well because we have down-sized the digits, however a pattern still forms.
```{r, fig.align="center", fig.width = 5.5, fig.height = 5.5}
###Extract test labels from MNIST data
Label <- as.factor(mnist$test$y)
test_out <- data.frame(latent, Label)
###Plot latent space
p <- ggplot(mapping = aes(latent[, 1], latent[, 2], color = Label))
p <- p + geom_point()
p <- p + xlab("Latent Dimension: 1") + ylab("Latent Dimension: 2")
print(p)
```
Now, we can visualize the latent space using generated visual fields to see the data manifold in two-dimensions. Again, the visualization is not terribly clear due to the down-sizing of the MNIST digits for the training process.
```{r, fig.align="center", fig.width = 5.5, fig.height = 5.5}
###Create a dataset for visualization that is 20 x 20 digits.
n <- 20
img_size <- 12
Min <- apply(latent, 2, min)
Max <- apply(latent, 2, max)
grid_x <- seq(Min[1], Max[1], length.out = n)
grid_y <- seq(Min[2], Max[2], length.out = n)
rows <- NULL
for (i in 1:length(grid_x)) {
column <- NULL
for (j in 1:length(grid_y)) {
z_sample <- matrix(c(grid_x[i], grid_y[j]), ncol = 2)
pred <- predict(decoder, z_sample)[ , , , 1]
column <- rbind(column, pred[dim(pred)[1]:1, ])
}
rows <- cbind(rows, column)
}
###Visualize two dimensional data manifold
data <- melt(t(rows))
p <- ggplot(data, aes(Var1, Var2, fill = value))
p <- p + geom_tile()
p <- p + scale_fill_gradientn(colours = c("black", "white"), limits = c(0, 1))
p <- p + coord_fixed()
p <- p + theme(panel.background = element_blank(),
panel.border = element_blank(),
panel.grid.major = element_blank(),
panel.grid.minor = element_blank(),
plot.background = element_blank(),
axis.text.x = element_blank(),
axis.text.y = element_blank(),
axis.ticks = element_blank())
p <- p + labs(fill = "") + ylab("Latent Dimension: 2") + xlab("Latent Dimension: 1")
print(p)
```
## A Two-Stage Technique for Longitudinal Data
In order to implement the two-stage method for spatiotemporal data, we only need to use the trained `decoder` function. The MNIST data does not inherently have a longitudinal component, however we can manufacture one. Assume that for a subject we observe a digit each year, with the first four being as follows.
```{r, fig.align="center", fig.width = 5.5, fig.height = 3}
###Define data for longitudinal example
Time <- 1:4
latent_pred <- matrix(c(0, -0.5, -1, -1.5, -2.5, -2, -1.5, -1), ncol = 2)
###Plot decoded images that we assume are truly observed annually
img_pred <- predict(decoder, latent_pred)
par(mfcol = c(1, 4))
for (y in 1:4) plot(as.raster(img_pred[y, , , 1]))
```
Now, we can apply the simple two-stage technique by applying regression to each dimension of the latent space. This allows us to obtain a predicted image at year five.
```{r, fig.align="center", fig.width = 5.5, fig.height = 4}
###Fit linear regression independently to each latent dimension
mods <- apply(latent_pred, 2, function(x) lm(x ~ Time))
###Predict each dimension at the fifth year
latent_future <- matrix(unlist(lapply(mods, function(x) predict(x, newdata = data.frame(Time = 5)))), nrow = 1)
###Obtain the predicted image and plot
img_future <- predict(decoder, latent_future)[1 , , , 1]
plot(as.raster(img_future))
```
While the MNIST digits do not have an inherent longitudinal nature this demonstrates the two-stage technique. When applied to spatiotemporal data, like visual fields, movement through the manifold corresponds to glaucoma disease progression, and is therefore clinically valuable.
Parts of this code were adapted from the following fabulous resources about variational autoencoders in `R`: https://keras.rstudio.com/articles/examples/variational_autoencoder.html and https://blogs.rstudio.com/tensorflow/posts/2018-10-22-mmd-vae/.