# Neural networks and deep learning
\index{classification!neural networks}
Neural networks (NN) can be considered to be nested additive (or even ensemble) models, where explanatory variables are combined and transformed through an activation function such as the logistic function. These transformed combinations are added recursively to yield class predictions. NNs are considered to be black-box models, but there is a growing demand for interpretability. Although interpretation is possible, it can be daunting to try to understand a complex model constructed to tackle a difficult classification task. Nevertheless, this is the motivation for explaining how to visualise NN models in this chapter.
In the simplest form, we might write the equation for a NN as
$$
\hat{y} = f(x) = a_0+\sum_{h=1}^{s}
w_{0h}\phi(a_h+\sum_{i=1}^{p} w_{ih}x_i)
$$ where $s$ indicates the number of nodes in the hidden (middle) layer, and $\phi$ is a choice of activation function. In a simple situation where $p=3$, $s=2$, and the output layer is linear, the model could be written as:
$$
\begin{aligned}
\hat{y} = a_0+ & w_{01}\phi(a_1+w_{11}x_1+w_{21}x_2+w_{31}x_3) +\\
& w_{02}\phi(a_2+w_{12}x_1+w_{22}x_2+w_{32}x_3)
\end{aligned}
$$ which is a combination of two (linear) models, each of which could be examined for its role in making predictions.
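To make the equation concrete, the following is a minimal sketch of the $p=3$, $s=2$ model in base R. The parameter values are hypothetical, chosen only for illustration, and the logistic function is assumed for $\phi$.

```{r}
# Logistic activation, one common choice for phi
phi_logistic <- function(z) 1 / (1 + exp(-z))

# Hypothetical parameter values, chosen only for illustration
nn_a0 <- 0.5                      # output constant a_0
nn_a <- c(0.1, -0.2)              # hidden constants a_1, a_2
nn_w_out <- c(1.5, -2.0)          # output weights w_01, w_02
nn_w_in <- matrix(c(0.3, -0.1,    # row i, column h holds w_ih
                    0.2,  0.4,
                   -0.5,  0.6),
                  nrow = 3, byrow = TRUE)

# Evaluate the model for one observation x = (x_1, x_2, x_3)
nn_predict <- function(x) {
  hidden <- phi_logistic(nn_a + as.vector(t(nn_w_in) %*% x))
  nn_a0 + sum(nn_w_out * hidden)
}

nn_x <- c(1, 2, 3)
nn_predict(nn_x)
```

Each element of `hidden` is the output of one of the two node models, so they can be inspected separately before they are combined.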
In practice, a model may have many nodes, several hidden layers, a variety of activation functions, and regularisation modifications. One should keep in mind that the principle of parsimony is important when applying NNs, because it is tempting to make an overly complex, and thus over-parameterised, construction. Fitting NNs is still problematic. One would hope that fitting produces a stable result, so that the same parameter estimates are returned whatever the starting seed. However, this is not the case, and different, sometimes radically different, results are routinely obtained after each attempted fit [@wickham2015].
For these examples we use the software `keras` [@keras] following the installation and tutorial details at <https://tensorflow.rstudio.com/tutorials/>. Because it is an interface to Python it can be tricky to install. If this is a problem, the example code should be possible to convert to use `nnet` [@VR02] or `neuralnet` [@neuralnet]. We will use the penguins data to illustrate the fitting, because it makes it easier to understand the procedures and the fit. However, using a NN here is like using a jackhammer instead of a trowel to plant a seedling: it is more complicated than necessary to build a good classification model for this data.
## Setting up the model
\index{classification!ANN architecture}
A first step is to decide how many nodes the NN architecture should have, and what activation function should be used. To make these decisions, ideally you already have some knowledge of the shapes of class clusters. For the penguins data, we have seen that it contains three elliptically shaped clusters of roughly the same size. This suggests two nodes in the hidden layer would be sufficient to separate three clusters (@fig-nn-diagram). Because the shapes of the clusters are convex, using a rectified linear activation ("relu") will also be sufficient. The model specification is as follows:
```{r eval=FALSE}
#| echo: true
library(keras)
tensorflow::set_random_seed(211)
# Define model
p_nn_model <- keras_model_sequential()
p_nn_model %>%
layer_dense(units = 2, activation = 'relu',
input_shape = 4) %>%
layer_dense(units = 3, activation = 'softmax')
p_nn_model %>% summary
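# The output layer already applies softmax, so the loss
# receives probabilities rather than logits
loss_fn <- loss_sparse_categorical_crossentropy(
  from_logits = FALSE)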
p_nn_model %>% compile(
optimizer = "adam",
loss = loss_fn,
metrics = c('accuracy')
)
```
Note that `tensorflow::set_random_seed(211)` sets the seed for the model fitting so that we can obtain the same result to discuss later. It needs to be set before the model is defined in the code. The model will also be saved in order to diagnose and make predictions.
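The two activation functions named in the specification, `'relu'` and `'softmax'`, can be written out in base R. This is a sketch of the standard definitions, not the `keras` implementation, and the input values are hypothetical.

```{r}
# Rectified linear unit: negative values are set to zero
relu_fn <- function(z) pmax(z, 0)

# Softmax: converts a vector of scores into probabilities
softmax_fn <- function(z) {
  e <- exp(z - max(z))   # subtract the maximum for numerical stability
  e / sum(e)
}

relu_fn(c(-2, 0, 1.5))

toy_scores <- c(2, 1, 0.1)   # hypothetical scores for three classes
softmax_fn(toy_scores)
```

The softmax output sums to one, which is why the final layer can be read as a set of predictive probabilities for the three species.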
![Network architecture for the model on the penguins data. The round nodes indicate original or transformed variables, and each arrow connecting these is represented as one of the weights $w_{ih}$ in the definition. The boxes indicate the additive constant entering the nodes, and the corresponding arrows represent the terms $a_h$.](images/nn-diagram.png){#fig-nn-diagram align="center"}
```{r eval=FALSE}
#| echo: false
# tidymodels approach does not allow extracting weights
library(keras)
library(tidymodels)
# Define model
p_nn_spec <-
mlp(hidden_units = 2,
activation = 'relu',
penalty = 0,
dropout = 0,
epochs = 500) %>%
# This model can be used for classification or regression, so set mode
set_mode("classification") %>%
set_engine("keras")
set.seed(821)
p_split <- penguins_sub %>%
select(bl:species) %>%
initial_split(prop = 2/3,
strata=species)
p_train <- training(p_split)
p_test <- testing(p_split)
set.seed(834)
p_nn_model <- p_nn_spec %>% fit(species~., p_train)
p_nn_pred <- bind_cols(
predict(p_nn_model, p_test),
predict(p_nn_model, p_test, type = "prob")
)
p_nn_wgts <- keras::get_weights(p_nn_model, trainable=TRUE)
```
## Checking the training/test split
\index{classification!training/test split}
Splitting the data into training and test sets is an essential way to protect against overfitting for most classifiers, but especially so for the copiously parameterised NNs. The model specified for the penguins data with only two nodes is unlikely to be overfitted, but it is nevertheless good practice to use a training set for building and a test set for evaluation.
`r ifelse(knitr::is_html_output(), '@fig-p-split-html', '@fig-p-split-pdf')` shows the tour being used to examine the split into training and test samples for the penguins data. Using random sampling, particularly stratified by group, should result in the two sets being very similar, as can be seen here. It does happen that several observations in the test set are on the extremes of their class cluster, so it could be that the model makes errors in the neighbourhoods of these points.
```{r echo=knitr::is_html_output()}
#| message: false
# Split the data into training and testing
library(ggthemes)
library(dplyr)
library(tidyr)
library(rsample)
library(ggbeeswarm)
library(tidymodels)
library(tourr)
load("data/penguins_sub.rda") # from mulgar book
set.seed(821)
p_split <- penguins_sub %>%
select(bl:species) %>%
initial_split(prop = 2/3,
strata=species)
p_train <- training(p_split)
p_test <- testing(p_split)
# Check training and test split
p_split_check <- bind_rows(
bind_cols(p_train, type = "train"),
bind_cols(p_test, type = "test")) %>%
mutate(type = factor(type))
```
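A quick numerical check can complement the visual one. The following sketch performs a stratified 2/3 split by hand on a hypothetical data set and compares the class proportions in the two samples. It illustrates the idea of stratification only; the internals of `initial_split()` may differ.

```{r}
set.seed(1)
# Hypothetical data with three classes of unequal size
toy <- data.frame(
  cls = factor(rep(c("a", "b", "c"), times = c(60, 30, 90))),
  x = rnorm(180)
)

# Sample 2/3 of the row indices within each class
toy_train_idx <- unlist(lapply(
  split(seq_len(nrow(toy)), toy$cls),
  function(idx) idx[sample.int(length(idx),
                               size = round(2/3 * length(idx)))]
))
toy_train <- toy[toy_train_idx, ]
toy_test <- toy[-toy_train_idx, ]

# Class proportions should be close in the two samples
rbind(train = prop.table(table(toy_train$cls)),
      test = prop.table(table(toy_test$cls)))
```

Matching proportions do not guarantee that the samples match in the data space, which is why the tour is still needed.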
```{r echo=knitr::is_html_output(), eval=FALSE}
#| code-fold: true
#| code-summary: "Code to run tours"
animate_xy(p_split_check[,1:4],
col=p_split_check$species,
pch=p_split_check$type,
shapeset=c(16,1))
animate_xy(p_split_check[,1:4],
guided_tour(lda_pp(p_split_check$species)),
col=p_split_check$species,
pch=p_split_check$type,
shapeset=c(16,1))
render_gif(p_split_check[,1:4],
grand_tour(),
display_xy(
col=p_split_check$species,
pch=p_split_check$type,
shapeset=c(16,1),
cex=1.5,
axes="bottomleft"),
gif_file="gifs/p_split.gif",
frames=500,
loop=FALSE
)
render_gif(p_split_check[,1:4],
guided_tour(lda_pp(p_split_check$species)),
display_xy(
col=p_split_check$species,
pch=p_split_check$type,
shapeset=c(16,1),
cex=1.5,
axes="bottomleft"),
gif_file="gifs/p_split_guided.gif",
frames=500,
loop=FALSE
)
```
::: {.content-visible when-format="html"}
::: {#fig-p-split-html layout-ncol="2"}
![Grand tour](gifs/p_split.gif){#fig-split-grand fig-alt="FIX ME" width="300"}
![Guided tour](gifs/p_split_guided.gif){#fig-split-guided fig-alt="FIX ME" width="300"}
Evaluating the training/test split, where we expect the two samples to roughly match. A few observations in the test set lie on the outer edges of the clusters, which will likely result in the model making errors in these regions, but overall the two samples match well.
:::
:::
::: {.content-visible when-format="pdf"}
::: {#fig-p-split-pdf layout-ncol="2"}
![Grand tour](images/p_split.png){fig-alt="FIX ME" width="220"}
![Guided tour](images/p_split_guided.png){fig-alt="FIX ME" width="220"}
Evaluating the training/test split, where we expect the two samples to roughly match. A few observations in the test set lie on the outer edges of the clusters, which will likely result in the model making errors in these regions, but overall the two samples match well.
:::
:::
## Fit the model
\index{classification!Fitting a NN}
The data needs to be specially formatted for the model fitted using `keras`. The explanatory variables need to be provided as a `matrix`, and the categorical response needs to be separate, and specified as a `numeric` variable, beginning with 0.
```{r}
# Data needs to be matrix, and response needs to be numeric
p_train_x <- p_train %>%
select(bl:bm) %>%
as.matrix()
p_train_y <- p_train %>% pull(species) %>% as.numeric()
p_train_y <- p_train_y-1 # Needs to be 0, 1, 2
p_test_x <- p_test %>%
select(bl:bm) %>%
as.matrix()
p_test_y <- p_test %>% pull(species) %>% as.numeric()
p_test_y <- p_test_y-1 # Needs to be 0, 1, 2
```
The specified model is reasonably simple: four input variables, two nodes in the hidden layer and a three-column binary matrix for output. This corresponds to 5+5+3+3+3=19 parameters.
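The parameter count can be reproduced with a small helper. The function `count_dense()` is hypothetical, written here only to show the arithmetic: each receiving node has one weight per incoming node plus one constant.

```{r}
# Weights plus one constant per receiving node
count_dense <- function(n_in, n_out) n_in * n_out + n_out

n_hidden_par <- count_dense(4, 2)   # (4 + 1) * 2
n_output_par <- count_dense(2, 3)   # (2 + 1) * 3
c(hidden = n_hidden_par, output = n_output_par,
  total = n_hidden_par + n_output_par)
```

The same helper shows how quickly the count grows when nodes or layers are added.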
```{r echo=FALSE}
#| message: false
library(keras)
p_nn_model <- load_model_tf("data/penguins_cnn")
p_nn_model
```
```{r eval=FALSE}
#| message: false
# Fit model
p_nn_fit <- p_nn_model %>% keras::fit(
x = p_train_x,
y = p_train_y,
epochs = 200,
verbose = 0
)
```
```{r eval=FALSE, echo=FALSE}
# Check
p_nn_model %>% evaluate(p_test_x, p_test_y, verbose = 0)
plot(p_nn_fit)
keras::get_weights(p_nn_model, trainable=TRUE)
```
Because we set the random number seed we will get the same fit each time the code provided here is run. However, if the model is re-fit without setting the seed, you will see that there is a surprising amount of variability in the fits. Setting `epochs = 200` usually helps to get a good fit. One expects `keras` to be reasonably stable, and so not to produce the huge array of fits observed in @wickham2015 using `nnet`. That this can happen with the simple model used here reinforces the notion that fitting NN models is fiddly, and great care needs to be taken to validate and diagnose the fit.
::: {.content-visible when-format="html"}
::: info
Fitting NN models is fiddly, and very different fitted models can result from restarts, parameter choices, and architecture.
:::
:::
::: {.content-visible when-format="pdf"}
```{=tex}
\infobox{Fitting NN models is fiddly, and very different fitted models can result from restarts, parameter choices, and architecture.
}
```
:::
```{r echo=knitr::is_html_output()}
#| code-fold: true
library(keras)
library(ggplot2)
library(colorspace)
# load fitted model
p_nn_model <- load_model_tf("data/penguins_cnn")
```
The fitted model that we have chosen as the final one has reasonably small loss and high accuracy. Plots of loss and accuracy across epochs, showing the change during fitting, can be made, but we don't show them here because they are generally not very interesting.
```{r}
p_nn_model %>% evaluate(p_test_x, p_test_y, verbose = 0)
```
The model object can be saved for later use with:
```{r eval=FALSE}
save_model_tf(p_nn_model, "data/penguins_cnn")
```
## Extracting model components
\index{classification!hidden layers}
::: {.content-visible when-format="html"}
::: info
View the individual node models to understand how they combine to produce the overall model.
:::
:::
::: {.content-visible when-format="pdf"}
```{=tex}
\infobox{View the individual node models to understand how they combine to produce the overall model.
}
```
:::
Because nodes in the hidden layers of NNs are themselves (relatively simple regression) models, it can be interesting to examine these to understand how the model is making its predictions. Although it's rarely easy, most software will allow the coefficients for the models at these nodes to be extracted. With the penguins NN model there are two nodes, so we can extract the coefficients and plot the resulting two linear combinations to examine the separation between classes.
```{r}
# Extract hidden layer model weights
p_nn_wgts <- keras::get_weights(p_nn_model, trainable=TRUE)
p_nn_wgts
```
The linear coefficients for the first node in the model are `r round(p_nn_wgts[[1]][,1], 2)`, and those for the second node are `r round(p_nn_wgts[[1]][,2], 2)`. We can use these, like we used the linear discriminants in LDA, to make a 2D view of the data where the model is separating the three species. The constants `r round(p_nn_wgts[[2]], 2)` are not important for this. They are only useful for drawing the location of the boundaries between classes produced by the model.
These two sets of model coefficients provide linear combinations of the original variables. Together, they define a plane on which the data is projected to view the classification produced by the model. Ideally, though, this plane should be defined using an orthonormal basis, otherwise the shape of the data distribution might be warped. So we orthonormalise this matrix before computing the data projection.
```{r}
# Orthonormalise the weights to make 2D projection
p_nn_wgts_on <- tourr::orthonormalise(p_nn_wgts[[1]])
p_nn_wgts_on
```
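To show what orthonormalising involves, here is a minimal Gram-Schmidt sketch in base R applied to a hypothetical $4\times 2$ weight matrix. It is an illustration of the idea under that assumption, not the `tourr` source code.

```{r}
# Hypothetical 4 x 2 matrix standing in for the hidden-layer weights
gs_w <- matrix(c(0.6, -0.2,
                 0.1,  0.9,
                -0.4,  0.3,
                 0.7,  0.5), ncol = 2, byrow = TRUE)

gram_schmidt <- function(m) {
  for (j in seq_len(ncol(m))) {
    # Remove the components along the columns already processed
    for (k in seq_len(j - 1)) {
      m[, j] <- m[, j] - sum(m[, j] * m[, k]) * m[, k]
    }
    m[, j] <- m[, j] / sqrt(sum(m[, j]^2))   # scale to unit length
  }
  m
}

gs_basis <- gram_schmidt(gs_w)
round(crossprod(gs_basis), 10)   # should be the 2 x 2 identity
```

The first column keeps its direction and is only rescaled, while the second is adjusted to be orthogonal to it, so the plane spanned by the two columns is unchanged.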
```{r echo=knitr::is_html_output()}
#| code-fold: false
#| label: fig-hidden-layer
#| fig-cap: "Plot of the data in the linear combinations from the two nodes in the hidden layer. The three species are clearly different, although with some overlap between all three. A main issue to notice is that there isn't a big gap between Gentoo and the other species, which we know is there based on our data exploration done in other chapters. This suggests this fitted model is sub-optimal."
#| fig-alt: FIXME
#| fig-width: 5
#| fig-height: 4
#| out-width: 80%
# Hidden layer
p_train_m <- p_train %>%
mutate(nn1 = as.matrix(p_train[,1:4]) %*%
matrix(p_nn_wgts_on[,1], ncol=1),
nn2 = as.matrix(p_train[,1:4]) %*%
matrix(p_nn_wgts_on[,2], ncol=1))
# Now add the test points on.
p_test_m <- p_test %>%
mutate(nn1 = as.matrix(p_test[,1:4]) %*%
matrix(p_nn_wgts_on[,1], ncol=1),
nn2 = as.matrix(p_test[,1:4]) %*%
matrix(p_nn_wgts_on[,2], ncol=1))
p_train_m <- p_train_m %>%
mutate(set = "train")
p_test_m <- p_test_m %>%
mutate(set = "test")
p_all_m <- bind_rows(p_train_m, p_test_m)
ggplot(p_all_m, aes(x=nn1, y=nn2,
colour=species, shape=set)) +
geom_point() +
scale_colour_discrete_divergingx(palette="Zissou 1") +
scale_shape_manual(values=c(16, 1)) +
theme_minimal() +
theme(aspect.ratio=1)
```
@fig-hidden-layer shows the data projected into the plane determined by the linear combinations from the two nodes in the hidden layer. Training and test sets are indicated by empty and solid circles, respectively. The three species are clearly different but there is some overlap or confusion for a few penguins. The most interesting aspect to learn is that there is no big gap between the Gentoo and other species, which we know exists in the data. The model has not found this gap, and thus is likely to confuse some Gentoo penguins with the other species, particularly with Adelie.
What we have shown here is a process to use the models at the nodes of the hidden layer to produce a reduced dimensional space where the classes are best separated, at least as determined by the model. The process will work in higher dimensions also.
When there are more nodes in the hidden layer than the number of original variables it means that the space is extended to achieve useful classifications that need more complicated non-linear boundaries. The extra nodes describe the non-linearity. @wickham2015 provides a good illustration of this in 2D. The process of examining each of the node models can be useful for understanding this non-linear separation, also in high dimensions.
## Examining predictive probabilities
\index{classification!predictive probabilities}
When the predictive probabilities are returned by a model, as is done by this NN, we can use a ternary diagram for three-class problems, or a high-dimensional simplex when there are more classes, to examine the strength of the classification. This is done in the same way that was used for the votes matrix from a random forest in @sec-votes.
```{r}
# Predict training and test set
p_train_pred <- p_nn_model %>%
predict(p_train_x, verbose = 0)
p_train_pred_cat <- levels(p_train$species)[
apply(p_train_pred, 1,
which.max)]
p_train_pred_cat <- factor(
p_train_pred_cat,
levels=levels(p_train$species))
table(p_train$species, p_train_pred_cat)
p_test_pred <- p_nn_model %>%
predict(p_test_x, verbose = 0)
p_test_pred_cat <- levels(p_test$species)[
apply(p_test_pred, 1,
which.max)]
p_test_pred_cat <- factor(
p_test_pred_cat,
levels=levels(p_test$species))
table(p_test$species, p_test_pred_cat)
```
```{r echo=FALSE, eval=FALSE}
# predict() causes the problem, use p_nn_model(p_test_x) instead
```
```{r echo=knitr::is_html_output()}
#| code-fold: true
# Set up the data to make the ternary diagram
# Join data sets
colnames(p_train_pred) <- c("Adelie", "Chinstrap", "Gentoo")
colnames(p_test_pred) <- c("Adelie", "Chinstrap", "Gentoo")
p_train_pred <- as_tibble(p_train_pred)
p_train_m <- p_train_m %>%
mutate(pspecies = p_train_pred_cat) %>%
bind_cols(p_train_pred) %>%
mutate(set = "train")
p_test_pred <- as_tibble(p_test_pred)
p_test_m <- p_test_m %>%
mutate(pspecies = p_test_pred_cat) %>%
bind_cols(p_test_pred) %>%
mutate(set = "test")
p_all_m <- bind_rows(p_train_m, p_test_m)
# Add simplex to make ternary
library(geozoo)
proj <- t(geozoo::f_helmert(3)[-1,])
p_nn_v_p <- as.matrix(p_all_m[,c("Adelie", "Chinstrap", "Gentoo")]) %*% proj
colnames(p_nn_v_p) <- c("x1", "x2")
p_nn_v_p <- p_nn_v_p %>%
as.data.frame() %>%
mutate(species = p_all_m$species,
set = p_all_m$set)
simp <- geozoo::simplex(p=2)
sp <- data.frame(cbind(simp$points), simp$points[c(2,3,1),])
colnames(sp) <- c("x1", "x2", "x3", "x4")
sp$species = sort(unique(penguins_sub$species))
```
```{r echo=knitr::is_html_output()}
#| code-fold: true
#| fig-width: 4
#| fig-height: 3
#| out-width: 70%
#| fig-cap: "Ternary diagram for the three groups of the predictive probabilities of both training and test sets. From what we already know about the penguins data this fit is not good. Both Chinstrap and Gentoo penguins are confused with Adelie, or at risk of it. Gentoo is very well-separated from the other two species when several variables are used, and this fitted model is blind to it. One useful finding is that there is little difference between training and test sets, so the model has not been over-fitted."
# Plot it
ggplot() +
geom_segment(data=sp, aes(x=x1, y=x2, xend=x3, yend=x4)) +
geom_text(data=sp, aes(x=x1, y=x2, label=species),
nudge_x=c(-0.1, 0.15, 0),
nudge_y=c(0.05, 0.05, -0.05)) +
geom_point(data=p_nn_v_p, aes(x=x1, y=x2,
colour=species,
shape=set),
size=2, alpha=0.5) +
scale_color_discrete_divergingx(palette="Zissou 1") +
scale_shape_manual(values=c(19, 1)) +
theme_map() +
theme(aspect.ratio=1, legend.position = "right")
```
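The ternary diagram relies on projecting three probabilities that sum to one into 2D. The sketch below does this in base R with a hand-written orthonormal basis for the plane orthogonal to $(1, 1, 1)$. The sign and ordering of the basis vectors are an assumption and may differ from those returned by `geozoo::f_helmert()`, and the probabilities are hypothetical.

```{r}
# Orthonormal basis for the plane orthogonal to (1, 1, 1);
# the sign and ordering conventions here are an assumption
tern_proj <- cbind(c(1, -1, 0) / sqrt(2),
                   c(1, 1, -2) / sqrt(6))

# Hypothetical predictive probabilities: each row sums to one
tern_p <- rbind(certain_1 = c(1, 0, 0),
                certain_2 = c(0, 1, 0),
                certain_3 = c(0, 0, 1),
                unsure = c(1, 1, 1) / 3)

tern_xy <- tern_p %*% tern_proj
tern_xy
dist(tern_xy[1:3, ])   # the three vertices are equally far apart
```

Observations predicted with certainty land on a vertex of the triangle, and an observation that the model cannot distinguish at all lands at the centre.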
::: {.content-visible when-format="html"}
::: info
If the training and test sets look similar when plotted in the model space then the model is not suffering from over-fitting.
:::
:::
::: {.content-visible when-format="pdf"}
```{=tex}
\infobox{If the training and test sets look similar when plotted in the model space then the model is not suffering from over-fitting.
}
```
:::
## Local explanations
\index{classification!local explanations} \index{classification!XAI}
It is important to be able to interpret or explain a model, even more so when the model is complex or black-box'y. A good resource for learning about the range of methods is @iml. Local explanations provide some information about variables that are important for making the prediction for a particular observation. The method that we use here is Shapley values, as computed using the `kernelshap` package [@kernelshap].
```{r eval=FALSE}
# Explanations
# https://www.r-bloggers.com/2022/08/kernel-shap/
library(kernelshap)
library(shapviz)
p_explain <- kernelshap(
p_nn_model,
p_train_x,
bg_X = p_train_x,
verbose = FALSE
)
p_exp_sv <- shapviz(p_explain)
save(p_exp_sv, file="data/p_exp_sv.rda")
```
A Shapley value for an observation indicates how a variable contributes to the model prediction for that observation, relative to other variables. It is an average, computed from the change in prediction over all combinations of presence or absence of the other variables. In the computation, for each combination, the prediction is computed by substituting absent variables with their average value, like one might do when imputing missing values.
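The description above can be sketched as an exact computation for a small number of variables. The model `sh_model()`, the observation and the baseline are all hypothetical, and absent variables are replaced by a single baseline value, following the description in the text. This is an illustration only, and `kernelshap` is assumed to differ in its details, for example in how it approximates the sum and uses background data.

```{r}
# Hypothetical model of three variables, including an interaction
sh_model <- function(x) 1 + 2 * x[1] - 3 * x[2] + 0.5 * x[1] * x[3]

shapley_exact <- function(f, x, baseline) {
  p <- length(x)
  # Prediction when only the variables in `present` take their
  # observed values; the rest are set to the baseline (e.g. the mean)
  value <- function(present) {
    z <- baseline
    z[present] <- x[present]
    f(z)
  }
  phi <- numeric(p)
  for (i in seq_len(p)) {
    others <- setdiff(seq_len(p), i)
    for (k in 0:(2^length(others) - 1)) {
      # Decode the integer k into a subset S of the other variables
      in_S <- (k %/% 2^(seq_along(others) - 1)) %% 2 == 1
      S <- others[in_S]
      # Shapley weight for a subset of this size
      w <- factorial(length(S)) * factorial(p - length(S) - 1) /
        factorial(p)
      phi[i] <- phi[i] + w * (value(c(S, i)) - value(S))
    }
  }
  phi
}

sh_x <- c(1, 2, 3)      # observation to explain
sh_base <- c(0, 1, 1)   # stand-in for the variable means
sh_phi <- shapley_exact(sh_model, sh_x, sh_base)
sh_phi
```

The values sum to the difference between the prediction for the observation and the prediction at the baseline, which is what allows them to be read as contributions of each variable.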
@fig-shapley-pcp shows the Shapley values for the Gentoo observations in the training set of the penguins data, as a parallel coordinate plot. The values for the single misclassified Gentoo penguin are coloured orange. Overall, the Shapley values don't vary much on `bl`, `bd` and `fl` but they do on `bm`. The effect of other variables seems to be important only for `bm`.
For the misclassified penguin, the effect of `bm` for all combinations of other variables leads to a decline in predicted value, thus less confidence in it being a Gentoo. In contrast, for this same penguin when considering the effect of `bl` the predicted value increases on average.
```{r echo=knitr::is_html_output()}
#| code-fold: true
load("data/p_exp_sv.rda")
p_exp_gentoo <- p_exp_sv$Class_3$S
p_exp_gentoo <- p_exp_gentoo %>%
as_tibble() %>%
mutate(species = p_train$species,
pspecies = p_train_pred_cat,
) %>%
mutate(error = ifelse(species == pspecies, 0, 1))
```
```{r echo=knitr::is_html_output()}
#| eval: false
#| code-fold: true
#| label: fig-shapley-dot
#| fig-width: 4
#| fig-height: 3
#| out-width: 80%
#| fig-cap: "SHAP values focused on Gentoo class, for each variable. The one misclassified penguin (orange) has a much lower value for body mass, suggesting that this variable is used differently for the prediction than for other penguins."
#| fig-alt: FIXME
p_exp_gentoo %>%
filter(species == "Gentoo") %>%
pivot_longer(bl:bm, names_to="var", values_to="shap") %>%
mutate(var = factor(var, levels=c("bl", "bd", "fl", "bm"))) %>%
ggplot(aes(x=var, y=shap, colour=factor(error))) +
geom_quasirandom(alpha=0.8) +
scale_colour_discrete_divergingx(palette="Geyser") +
#facet_wrap(~var) +
xlab("") + ylab("SHAP") +
theme_minimal() +
theme(legend.position = "none")
```
```{r echo=knitr::is_html_output()}
#| code-fold: true
#| label: fig-shapley-pcp
#| fig-width: 4
#| fig-height: 3
#| out-width: 80%
#| fig-cap: "SHAP values focused on Gentoo class, for each variable. The one misclassified penguin (orange) has a much lower value for body mass, suggesting that this variable is used differently for the prediction than for other penguins."
#| fig-alt: FIXME
library(ggpcp)
p_exp_gentoo %>%
filter(species == "Gentoo") %>%
pcp_select(1:4) %>%
ggplot(aes_pcp()) +
geom_pcp_axes() +
geom_pcp_boxes(fill="grey80") +
geom_pcp(aes(colour = factor(error)),
linewidth = 2, alpha=0.3) +
scale_colour_discrete_divergingx(palette="Geyser") +
xlab("") + ylab("SHAP") +
theme_minimal() +
theme(legend.position = "none")
```
If we examine the data [@fig-penguins-bl-bm-bd] the explanation makes some sense. The misclassified penguin has an unusually small value on `bm`. That the SHAP value for `bm` was quite different pointed to this being a potential issue with the model, particularly for this penguin. This penguin's prediction is negatively impacted by `bm` being in the model.
```{r echo=knitr::is_html_output()}
#| label: fig-penguins-bl-bm-bd
#| code-fold: true
#| fig-width: 6
#| fig-height: 6
#| out-width: 100%
#| fig-cap: "Plots of the data to help understand what the SHAP values indicate. The misclassified Gentoo penguin has an unusually low body mass value which makes it appear to be more like an Adelie penguin, particularly when considered in relation to its bill length."
#| fig-alt: FIXME
library(patchwork)
# Check position on bm
shap_proj <- p_exp_gentoo %>%
filter(species == "Gentoo", error == 1) %>%
select(bl:bm)
shap_proj <- as.matrix(shap_proj/sqrt(sum(shap_proj^2)))
p_exp_gentoo_proj <- p_exp_gentoo %>%
rename(shap_bl = bl,
shap_bd = bd,
shap_fl = fl,
shap_bm = bm) %>%
bind_cols(as_tibble(p_train_x)) %>%
mutate(shap1 = shap_proj[1]*bl+
shap_proj[2]*bd+
shap_proj[3]*fl+
shap_proj[4]*bm)
sp1 <- ggplot(p_exp_gentoo_proj, aes(x=bm, y=bl,
colour=species,
shape=factor(1-error))) +
geom_point(alpha=0.8) +
scale_colour_discrete_divergingx(palette="Zissou 1") +
scale_shape_manual("error", values=c(19, 1)) +
theme_minimal() +
theme(aspect.ratio=1, legend.position="bottom")
sp2 <- ggplot(p_exp_gentoo_proj, aes(x=bm, y=shap1,
colour=species,
shape=factor(1-error))) +
geom_point(alpha=0.8) +
scale_colour_discrete_divergingx(palette="Zissou 1") +
scale_shape_manual("error", values=c(19, 1)) +
ylab("SHAP") +
theme_minimal() +
theme(aspect.ratio=1, legend.position="bottom")
sp2 <- ggplot(p_exp_gentoo_proj, aes(x=shap1,
fill=species, colour=species)) +
geom_density(alpha=0.5) +
geom_vline(xintercept = p_exp_gentoo_proj$shap1[
p_exp_gentoo_proj$species=="Gentoo" &
p_exp_gentoo_proj$error==1], colour="black") +
scale_fill_discrete_divergingx(palette="Zissou 1") +
scale_colour_discrete_divergingx(palette="Zissou 1") +
theme_minimal() +
theme(aspect.ratio=1, legend.position="bottom")
sp2 <- ggplot(p_exp_gentoo_proj, aes(x=bm, y=bd,
colour=species,
shape=factor(1-error))) +
geom_point(alpha=0.8) +
scale_colour_discrete_divergingx(palette="Zissou 1") +
scale_shape_manual("error", values=c(19, 1)) +
theme_minimal() +
theme(aspect.ratio=1, legend.position="bottom")
sp1 + sp2 + plot_layout(ncol=2, guides = "collect") &
theme(legend.position="bottom",
legend.direction="vertical")
```
## Examining boundaries
<!-- Check against LDA, suspect that `bm` is used too much in CNN model.-->
@fig-penguins-lda-nn shows the boundaries for this NN model along with those of the LDA model.
```{r echo=knitr::is_html_output(), eval=FALSE}
#| label: fig-penguins-nn-boundaries
#| code-fold: true
# Generate grid over explanatory variables
p_grid <- tibble(
bl = runif(10000, min(penguins_sub$bl), max(penguins_sub$bl)),
bd = runif(10000, min(penguins_sub$bd), max(penguins_sub$bd)),
fl = runif(10000, min(penguins_sub$fl), max(penguins_sub$fl)),
bm = runif(10000, min(penguins_sub$bm), max(penguins_sub$bm))
)
# Predict grid
p_grid_pred <- p_nn_model %>%
predict(as.matrix(p_grid), verbose=0)
p_grid_pred_cat <- levels(p_train$species)[apply(p_grid_pred, 1, which.max)]
p_grid_pred_cat <- factor(p_grid_pred_cat,
levels=levels(p_train$species))
# Project into weights from the two nodes
p_grid_proj <- as.matrix(p_grid) %*% p_nn_wgts_on
colnames(p_grid_proj) <- c("nn1", "nn2")
p_grid_proj <- p_grid_proj %>%
as_tibble() %>%
mutate(species = p_grid_pred_cat)
# Plot
ggplot(p_grid_proj, aes(x=nn1, y=nn2,
colour=species)) +
geom_point(alpha=0.5) +
geom_point(data=p_all_m, aes(x=nn1,
y=nn2,
shape=species),
inherit.aes = FALSE) +
scale_colour_discrete_divergingx(palette="Zissou 1") +
scale_shape_manual(values=c(1, 2, 3)) +
theme_minimal() +
theme(aspect.ratio=1,
legend.position = "bottom",
legend.title = element_blank())
```
::: {.content-visible when-format="html"}
::: {#fig-penguins-lda-nn-html layout-ncol="2"}
![LDA model](gifs/penguins_lda_boundaries.gif){#fig-lda-boundary fig-alt="FIX ME" width="300"}
![NN model](gifs/penguins_nn_boundaries.gif){#fig-tree-boundary fig-alt="FIX ME" width="300"}
Comparison of the boundaries produced by the LDA (a) and the NN (b) model, using a slice tour.
:::
:::
::: {#fig-penguins-lda-nn layout-ncol="2"}
![LDA model](images/fig-lda-2D-boundaries-1.png){#fig-lda-boundary2 fig-alt="FIX ME" width="200"}
![NN model](images/penguins-nn-boundaries-1.png){#fig-nn-boundary fig-alt="FIX ME" width="290"}
Comparison of the boundaries produced by the LDA (a) and the NN (b) model, using a slice tour.
:::
\index{tour!slice}
## Application to a large dataset
To see how these methods apply when we have a large number of variables, observations and classes, we will look at a neural network that predicts the category for the fashion MNIST data. The code for designing and fitting the model follows the tutorial available from https://tensorflow.rstudio.com/tutorials/keras/classification where additional information can be found. Below we replicate only the steps needed to build the model from scratch. We also note that a similar investigation was presented in @li2020visualizing, with a focus on investigating the model at different epochs during the training. \index{data!fashion MNIST}
The first step is to download and prepare the data. Here we scale the observations to range between zero and one, and we define the label names.
```{r}
library(keras)
# download the data
fashion_mnist <- dataset_fashion_mnist()
# split into input variables and response
c(train_images, train_labels) %<-% fashion_mnist$train
c(test_images, test_labels) %<-% fashion_mnist$test
# for interpretation we also define the category names
class_names <- c('T-shirt/top',
'Trouser',
'Pullover',
'Dress',
'Coat',
'Sandal',
'Shirt',
'Sneaker',
'Bag',
'Ankle boot')
# rescaling pixel values from 0-255 to the range [0, 1]
train_images <- train_images / 255
test_images <- test_images / 255
```
In the next step we define the neural network and train the model. Because we have many observations, even a very simple structure returns a good model. And because this example is well known, we do not tune the model or check the validation accuracy here.
```{r eval=FALSE}
# defining the model
model_fashion_mnist <- keras_model_sequential()
model_fashion_mnist %>%
# flatten the image data into a long vector
layer_flatten(input_shape = c(28, 28)) %>%
# hidden layer with 128 units
layer_dense(units = 128, activation = 'relu') %>%
# output layer for 10 categories
layer_dense(units = 10, activation = 'softmax')
model_fashion_mnist %>% compile(
optimizer = 'adam',
loss = 'sparse_categorical_crossentropy',
metrics = c('accuracy')
)
# fitting the model, if we did not know the model yet we
# would add a validation split to diagnose the training
model_fashion_mnist %>% fit(train_images,
train_labels,
epochs = 5)
save_model_tf(model_fashion_mnist, "data/fashion_nn")
```
We have defined a flat neural network with a single hidden layer of 128 nodes. To investigate the model we can start by comparing the activations to the original input data distribution. Since both the input space and the space of activations are large, and of different dimensionality, we will first use principal component analysis. This simplifies the analysis, and in general we do not need the original pixel or hidden node information for the interpretation here. The comparison uses the test subset of the data.
```{r}
# get the fitted model
model_fashion_mnist <- load_model_tf("data/fashion_nn")
# observed response labels in the test set
test_tags <- factor(class_names[test_labels + 1],
levels = class_names)
# calculate activation for the hidden layer, this can be done
# within the keras framework
activations_model_fashion <- keras_model(
inputs = model_fashion_mnist$input,
outputs = model_fashion_mnist$layers[[2]]$output
)
activations_fashion <- predict(
activations_model_fashion,
test_images, verbose = 0)
# PCA for activations
activations_pca <- prcomp(activations_fashion)
activations_pc <- as.data.frame(activations_pca$x)
# PCA on the original data
# we first need to flatten the image input
test_images_flat <- test_images
dim(test_images_flat) <- c(nrow(test_images_flat), 784)
images_pca <- prcomp(as.data.frame(test_images_flat))
images_pc <- as.data.frame(images_pca$x)
```
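The activations extracted above are the output of the hidden layer: a linear combination of the flattened pixels for each node, followed by the ReLU function. The following is a minimal sketch of that computation on toy data, not the fitted model. The weight matrix `w` and bias vector `b` are random stand-ins; for the real model they would have 784 rows and 128 columns, and could be obtained from the fitted model, for example with `get_weights()`.

```r
# Toy stand-ins: 4 flattened "images" with 6 pixels, 3 hidden nodes.
# In the real model these would be 784 pixels and 128 nodes.
set.seed(1)
x <- matrix(runif(4 * 6), nrow = 4)   # observations in rows
w <- matrix(rnorm(6 * 3), nrow = 6)   # hypothetical weights
b <- rnorm(3)                         # hypothetical biases

# Linear combination for each node: add the bias to each column
z <- sweep(x %*% w, 2, b, "+")

# ReLU: negative values are set to zero, positive values are kept
toy_activations <- pmax(z, 0)

dim(toy_activations)
```

Because ReLU maps all negative values to zero, many observations share exact zeros on some nodes, which is one reason to expect a more structured distribution in the activations space than in the input space.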
```{r echo=knitr::is_html_output(), fig.format='png'}
#| code-fold: true
#| code-summary: "Code to make the PCA plots"
p2 <- ggplot(activations_pc,
aes(PC1, PC2, color = test_tags)) +
geom_point(size = 0.1) +
ggtitle("Activations") +
scale_color_discrete_qualitative(palette = "Dynamic") +
theme_bw() +
theme(legend.position = "none", aspect.ratio = 1)
p1 <- ggplot(images_pc,
aes(PC1, PC2, color = test_tags)) +
geom_point(size = 0.1) +
ggtitle("Input space") +
scale_color_discrete_qualitative(palette = "Dynamic") +
theme_bw() +
theme(legend.position = "none", aspect.ratio = 1)
legend_labels <- cowplot::get_legend(
p1 +
guides(color = guide_legend(nrow = 1)) +
theme(legend.position = "bottom",
legend.title = element_blank()) +
guides(color = guide_legend(override.aes = list(size = 1)))
)
# combine the two plots above a shared legend
cowplot::plot_grid(cowplot::plot_grid(p1, p2), legend_labels,
rel_heights = c(1, .3), nrow = 2)
```
Looking only at the first two principal components, we note some clear differences resulting from the transformation in the hidden layer. The observations seem to be more evenly spread in the input space, while in the activations space we notice grouping along specific directions. In particular, the category "Bag" appears to be most different from all other classes. After the non-linear transformation, the bags are clearly separated from the shoe categories in the activations space, while in the input space there is some overlap in the linear projection. To better identify differences between other groups we will use the tour on the first five principal components.
```{r echo=knitr::is_html_output(), eval=FALSE}
#| code-fold: true
#| code-summary: "Code to run tours"
animate_xy(images_pc[,1:5], col = test_tags,
cex=0.2, palette = "Dynamic")
animate_xy(activations_pc[,1:5], col = test_tags,
cex=0.2, palette = "Dynamic")
render_gif(images_pc[,1:5],
grand_tour(),
display_xy(
col=test_tags,
cex=0.2,
palette = "Dynamic",
axes="bottomleft"),
gif_file="gifs/fashion_images_gt.gif",
frames=500,
loop=FALSE
)
render_gif(activations_pc[,1:5],
grand_tour(),
display_xy(
col=test_tags,
cex=0.2,
palette = "Dynamic",
axes="bottomleft"),
gif_file="gifs/fashion_activations_gt.gif",
frames=500,
loop=FALSE
)
```
::: {.content-visible when-format="html"}
::: {#fig-fashion-gt-html layout-ncol="2"}
![Input space](gifs/fashion_images_gt.gif){#fig-fashion-input fig-alt="FIX ME" width="200"}
![Activations](gifs/fashion_activations_gt.gif){#fig-fashion-activation fig-alt="FIX ME" width="200"}
Comparison of the test observations in the first five principal components of the input space (left) and in the hidden layer activations (right). The activation function results in more clearly defined grouping of the different classes.
:::
:::
::: {.content-visible when-format="pdf"}
::: {layout-ncol="2"}
![Input space](images/fashion_images_gt_36.png){fig-alt="FIX ME" width="200"}
![Activations](images/fashion_activation_gt_126.png){fig-alt="FIX ME" width="200"}
Comparison of the test observations in the first five principal components of the input space (left) and in the hidden layer activations (right). The activation function results in more clearly defined grouping of the different classes.
:::
:::
As with the first two principal components we get a much more spread out distribution in the original space. Nevertheless we can see differences between the classes, and that some groups are varying along specific directions in that space. Overall the activations space shows tighter clusters as expected after including the ReLU activation function, but the picture is not as neat as the first two principal components would suggest. While certain groups appear very compact even in this larger subspace, others vary quite a bit within part of the space. For example we can clearly see the "Bag" observations as different from all other images, but also notice that there is a large variation within this class along certain directions.
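The tours use only the first five principal components, so it can be worth checking how much of the total variance these components retain. The sketch below shows the calculation on simulated data, which is an assumption made to keep the example self-contained; for the fashion MNIST analysis the same steps would be applied to `images_pca` and `activations_pca`.

```r
set.seed(2)
# Simulated stand-in: 200 observations on 20 variables,
# driven by 3 latent factors plus a small amount of noise
latent <- matrix(rnorm(200 * 3), ncol = 3)
loadings <- matrix(rnorm(3 * 20), nrow = 3)
toy <- latent %*% loadings +
  matrix(rnorm(200 * 20, sd = 0.1), ncol = 20)

toy_pca <- prcomp(toy)

# Proportion of total variance carried by each component
var_prop <- toy_pca$sdev^2 / sum(toy_pca$sdev^2)

# Cumulative proportion retained by the first five components
cum_var5 <- cumsum(var_prop)[5]
cum_var5
```

If the cumulative proportion is low, structure seen in the tour reflects only part of the variation, and patterns in the remaining components could be missed.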
Finally we will investigate the model performance through the misclassifications and the uncertainty between classes. We start with the error matrix for the test observations. To keep the printed error matrix compact we use the numeric labels, ordered as defined above.
```{r}
fashion_test_pred <- predict(model_fashion_mnist,
test_images, verbose = 0)
fashion_test_pred_cat <- levels(test_tags)[
apply(fashion_test_pred, 1,
which.max)]
predicted <- factor(
fashion_test_pred_cat,
levels=levels(test_tags)) %>%
as.numeric() - 1
observed <- as.numeric(test_tags) -1
table(observed, predicted)
```
Here the numeric labels correspond to: 0 = T-shirt/top, 1 = Trouser, 2 = Pullover, 3 = Dress, 4 = Coat, 5 = Sandal, 6 = Shirt, 7 = Sneaker, 8 = Bag, 9 = Ankle boot.
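An error matrix with ten classes can be hard to scan, so it may help to summarize it numerically. The following sketch uses a small made-up set of labels for three classes (not the fashion MNIST results) to compute the error rate for each observed class and to list the most frequent confusions. The same steps could be applied to `table(observed, predicted)`, assuming every class appears among both the observed and the predicted labels so that the table is square.

```r
# Toy labels for three classes (0, 1, 2), standing in for the
# `observed` and `predicted` vectors
obs  <- c(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2)
pred <- c(0, 0, 2, 0, 1, 1, 1, 1, 2, 0, 0, 2)
err_mat <- table(observed = obs, predicted = pred)

# Proportion of each observed class that is misclassified
class_error <- 1 - diag(err_mat) / rowSums(err_mat)

# Off-diagonal cells with non-zero counts, most frequent first
confusions <- as.data.frame(err_mat, stringsAsFactors = FALSE)
confusions <- confusions[confusions$observed != confusions$predicted &
                           confusions$Freq > 0, ]
confusions <- confusions[order(-confusions$Freq), ]

class_error
confusions
```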
From the error matrix we see that the model mainly confuses certain categories with each other, and within expected groups (e.g. different types of shoes can be confused with each other, or different types of shirts). We can further investigate this by visualizing the full probability matrix for the test observations, to see which categories the model is uncertain about.
```{r echo=knitr::is_html_output(), eval=FALSE}
#| code-fold: true
#| code-summary: "Code to visualize probabilities"
# getting the probabilities from the output layer
fashion_test_pred <- predict(model_fashion_mnist,
test_images, verbose = 0)
# this is the same code as was used in the RF chapter
proj <- t(geozoo::f_helmert(10)[-1,])
f_nn_v_p <- as.matrix(fashion_test_pred) %*% proj
colnames(f_nn_v_p) <- c("x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9")
f_nn_v_p <- f_nn_v_p %>%
as.data.frame() %>%
mutate(class = test_tags)
simp <- geozoo::simplex(p=9)
sp <- data.frame(simp$points)
colnames(sp) <- c("x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9")
sp$class = ""
f_nn_v_p_s <- bind_rows(sp, f_nn_v_p) %>%
mutate(class = ifelse(class %in% c("T-shirt/top",
"Pullover",
"Shirt",
"Coat"), class, "Other")) %>%
mutate(class = factor(class, levels=c("T-shirt/top",
"Pullover",
"Shirt",
"Coat",
"Other")))
animate_xy(f_nn_v_p_s[,1:9], col = f_nn_v_p_s$class,
axes = "off", cex=0.2,
edges = as.matrix(simp$edges),
edges.width = 0.05,
palette = "Viridis")
render_gif(f_nn_v_p_s[,1:9],
grand_tour(),
display_xy(
col=f_nn_v_p_s$class,
cex=0.2,
palette = "Viridis",
axes="off",
edges = as.matrix(simp$edges),
edges.width = 0.05),
gif_file="gifs/fashion_confusion_gt.gif",
frames=500,
loop=FALSE
)
```
::: {.content-visible when-format="html"}
::: {#fig-fashion-conf-gt-html}
![Class probabilities](gifs/fashion_confusion_gt.gif){#fig-fashion-confusion fig-alt="FIX ME" width="400"}
A tour of the predicted class probabilities for the fashion MNIST test observations, focusing on a subset of items. Often observations get confused between two of the classes, which appears as points falling along one of the edges. For example, some Shirts look more like T-shirts/tops, while others get confused with Coats. We can also notice that a subset of three other classes, not mapped to colors, is very separate from this group.
:::
:::
::: {.content-visible when-format="pdf"}
::: {#fig-fashion-confusion-split-pdf layout-ncol="3"}
![](images/fashion_confustion_gt_36.png){fig-alt="FIX ME" width="130"}
![](images/fashion_confusion_gt_58.png){fig-alt="FIX ME" width="130"}
![](images/fashion_confusion_gt_69.png){fig-alt="FIX ME" width="130"}
![](images/fashion_confusion_gt_161.png){fig-alt="FIX ME" width="130"}
![](images/fashion_confusion_gt_212.png){fig-alt="FIX ME" width="130"}
![](images/fashion_confusion_gt_333.png){fig-alt="FIX ME" width="130"}
A tour of the predicted class probabilities for the fashion MNIST test observations, focusing on a subset of items. Often observations get confused between two of the classes, which appears as points falling along one of the edges. For example, some Shirts look more like T-shirts/tops, while others get confused with Coats. We can also notice that a subset of three other classes, not mapped to colors, is very separate from this group.
:::
:::
The tour of the class probabilities shows that the model is often confused between two classes, which appears as points falling along one edge of the simplex. In particular, for the highlighted categories we can notice some interesting patterns, where pairs of classes get confused with each other. We also see some three-way confusions: these are observations that fall on a triangular face defined by three corners of the simplex, for example between Pullover, Shirt and Coat.
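To see why confusions appear along edges and faces, it helps to work through the projection on a small case. Predicted probabilities for $k$ classes sum to one, so they lie in a $(k-1)$-dimensional simplex, and an orthonormal basis orthogonal to the vector of ones maps them into $k-1$ coordinates without distorting distances. The sketch below does this for three classes. The helper `helmert_basis()` is a hypothetical base R construction written for this illustration; it is similar in spirit to `geozoo::f_helmert()` but is not guaranteed to match it in sign or ordering.

```r
# Hypothetical helper: orthonormal Helmert-type basis for the
# subspace orthogonal to the vector of ones, as a d x (d-1) matrix
helmert_basis <- function(d) {
  h <- matrix(0, nrow = d - 1, ncol = d)
  for (i in seq_len(d - 1)) {
    h[i, seq_len(i)] <- 1 / sqrt(i * (i + 1))
    h[i, i + 1] <- -i / sqrt(i * (i + 1))
  }
  t(h)
}

# Toy probability vectors for three classes (each row sums to 1)
probs <- rbind(
  certain_1 = c(1, 0, 0),
  certain_2 = c(0, 1, 0),
  certain_3 = c(0, 0, 1),
  two_way   = c(0.5, 0.5, 0),  # confused between classes 1 and 2
  three_way = c(1, 1, 1) / 3   # equally uncertain about all classes
)

# Project from 3-D probabilities to 2-D simplex coordinates
pts <- probs %*% helmert_basis(3)
round(pts, 3)
```

The three certain predictions land on the corners of an equilateral triangle, the two-way confusion falls at the midpoint of the edge joining its two classes, and the fully uncertain prediction falls at the center. With ten classes the same logic applies in nine dimensions, which is why the tour is needed to view it.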
For this data, using explainers like SHAP is not so interesting, since the individual pixel contributions to a prediction are typically not of interest. With image classification, a next step might be to investigate which part of the image is important for a prediction, which can be visualized as a heat map placed over the original image. This is especially interesting for difficult or misclassified images. This, however, is beyond the scope of this book.
```{=html}
<!--
This chapter will likely include:
- Models at nodes or epochs
- Predictions (like vote matrix) for training and test. It provides a visual guide to overfitting.
- Classification boundaries comparison with other methods
- Explainability and interpretability
(This paper https://distill.pub/2020/grand-tour/ has good examples)
NOTE: Results might vary with different knits
References keras/tensorflow book, tidymodels, interpretable machine learning, and removing the blindfold
Outline for chapter:
- Penguins data
- Setting up with the NN with keras
- Splitting training and test, checking
- Specifying model, choices of layers
- Making predictions
- Extracting weights
- Building your diagnostic data set
- Examine predictive probabilities with simplex
- Examining the nodes - like the discriminant space when its 2
- Misclassifications
-
Bushfires
- Model fitting when regularisation needed
- Overfitting
Sketches as an example
-->
```
## Exercises {.unnumbered}
1. The problem with the NN model fitted to the penguins is that the Gentoo are poorly classified, when they should be perfectly predictable due to the big gap between class clusters. Re-fit the NN to the penguins data to find a better model that perfectly predicts the Gentoo penguins. Support this by plotting the model (using the hidden layer), and the predictive probabilities as a ternary plot. Do the SHAP values also support that `bd` plays a stronger role in your best model? (`bd` is the main variable for distinguishing Gentoo from the other species, particularly when used with `fl` or `bl`.)
2. For the fashion MNIST data we have seen that certain categories are more likely to be confused with each other. Select a subset of the data including only the categories Ankle boot, Sneaker and Sandal and see if you can reproduce the analysis of the penguins data in this chapter with this subset.
3. Can you fit a neural network that can predict the class in the fake tree data? Because the data is noisy and we do not have that many observations, it can be easy to overfit the data. Once you find a setting that works, think about what aspects of the model might be interesting for visualization. What comparisons with a random forest model could be of interest?
4. The sketches data could also be considered a classic image classification problem, and we have seen that we can get reasonable accuracy with a random forest model. Because we have a smaller number of observations (compared to the fashion MNIST data), we need to be very careful not to overfit the training data when fitting a neural network. Try fitting a flat neural network (similar to what we did for the fashion MNIST data) and check the test accuracy of the model.
5. Challenge: try to design a more accurate neural network for the sketches data. Here you can investigate using a convolutional neural network in combination with data augmentation. In addition, using batch normalization should improve the model performance.
```{r}
#| eval: false
#| echo: false
library(mulgar)
library(dplyr)
data("sketches_train")
load("data/sketches_test_labelled.rda")