---
output: bookdown::html_chapter
bibliography: references.bib
---
```{r data, include = FALSE}
chapter <- "modelling"
source("common.R")
```
# Modelling for visualisation {#cha:modelling}
## Introduction
Modelling is an essential tool for visualisation. There are two particularly strong connections between modelling and visualisation that I want to explore in this chapter: \index{Modelling}
* Using models as a tool to remove obvious patterns in your plots. This is
useful because strong patterns mask subtler effects. Often the strongest
effects are already known and expected, and removing them allows you to
see surprises more easily.
* Other times you have a lot of data, too much to show on a handful of plots.
Models can be a powerful tool for summarising data so that you get a higher
level view.
In this chapter, I'm going to focus on the use of linear models to achieve these goals. Linear models are a basic but powerful tool of statistics, and I recommend that everyone serious about visualisation learns at least the basics of how to use them. To this end, I highly recommend two books by Julian J. Faraway:
* Linear Models with R <http://amzn.com/1439887330>
* Extending the Linear Model with R <http://amzn.com/158488424X>
These books cover some of the theory of linear models, but are pragmatic and focussed on how to actually use linear models (and their extensions) in R. \index{Linear models}
There are many other modelling tools, which I don't have the space to show. If you understand how linear models can help improve your visualisations, you should be able to translate the basic idea to other families of models. This chapter just scratches the surface of what you can do, but hopefully it reinforces how visualisation can combine with modelling to help you build a powerful data analysis toolbox. For more ideas, check out @model-vis-paper.
In my opinion, mastering the combination of visualisation and modelling is key to being an effective data scientist. Unfortunately most books (like this one!) focus on either visualisation or modelling, but not both. There's a lot of interesting work to be done.
## Removing trend {#sub:trend}
So far our analysis of the diamonds data has been plagued by the powerful relationship between size and price. It makes it very difficult to see the impact of cut, colour and clarity because higher quality diamonds tend to be smaller, and hence cheaper. We can use a linear model to remove the effect of size on price. Instead of looking at the raw price, we can look at the relative price: how valuable is this diamond relative to the average diamond of the same size. \index{Removing trend}
To get started, we'll focus on diamonds of size two carats or less (96% of the dataset). This avoids some incidental problems that you can explore in the exercises, if you're interested. We'll also create two new variables: log price and log carat. These variables are useful because they produce a plot with a strong linear trend.
`r columns(1, 1/2, 0.75)`
```{r}
diamonds2 <- diamonds %>%
filter(carat <= 2) %>%
mutate(
lcarat = log(carat),
lprice = log(price)
)
diamonds2
ggplot(diamonds2, aes(lcarat, lprice)) +
geom_bin2d() +
geom_smooth(method = "lm", se = FALSE, size = 2)
```
In the graphic we used `geom_smooth()` to overlay the line of best fit on the data. We can replicate this outside of ggplot2 by fitting a linear model with `lm()`. This allows us to find the slope and intercept of the line: \indexf{lm} \indexf{coef}
```{r}
mod <- lm(lprice ~ lcarat, data = diamonds2)
coef(summary(mod))
```
If you're familiar with linear models, you might want to interpret those coefficients: $\log(price) = 8.5 + 1.7 \cdot \log(carat)$, which implies $price = 4900 \cdot carat ^ {1.7}$. But even if you don't understand the coefficients, the model is still useful: we can use it to subtract away the trend by looking at the residuals, the price of each diamond minus its predicted price based on weight alone. Geometrically, the residuals are the vertical distances between each point and the line of best fit. They tell us the price relative to the "average" diamond of that size. \indexf{resid}
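Before turning to the residuals, we can sanity-check that back-transformed form directly. This is a quick sketch of my own (it's not in the original text); it just exponentiates the fitted coefficients:
```{r}
# exp() of the intercept gives the multiplicative constant. It differs a little
# from the text's 4900, which exponentiates the rounded value 8.5.
exp(coef(mod)[["(Intercept)"]])
# The slope is the exponent on carat, roughly 1.7.
coef(mod)[["lcarat"]]
# Implied price of (say) a 0.5 carat diamond, back on the dollar scale:
exp(predict(mod, data.frame(lcarat = log(0.5))))
```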
```{r}
diamonds2 <- diamonds2 %>% mutate(rel_price = resid(mod))
ggplot(diamonds2, aes(carat, rel_price)) + geom_bin2d()
```
A relative price of zero means that the diamond was at the average price; positive means that it's more expensive than expected (based on its size), and negative means that it's cheaper than expected.
Interpreting the values precisely is a little tricky here because we've log-transformed price. The residuals give the absolute difference ($x - expected$), but here we have $\log(price) - \log(expected\ price)$, or equivalently $\log(price / expected\ price)$. If we "back-transform" to the original scale by applying the opposite transformation (`exp()`) we get $price / expected\ price$. This makes the values more interpretable, at the cost of the nice symmetry property of the logged values (i.e. both relatively cheaper and relatively more expensive diamonds have the same range). We can make a little table to help interpret the values:
```{r}
xgrid <- seq(-1.5, 1.5, length = 11)
data.frame(logx = xgrid, x = round(exp(xgrid), 2))
```
For example, a `rel_price` of 0.6 means that it's 1.8x the expected price (i.e. 80% more expensive than expected); a relative price of -0.3 means that it's 0.74x the expected price (i.e. 26% cheaper than expected). \index{Log!transform}
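(We can check those two examples directly; this snippet is mine, not from the text:)
```{r}
exp(0.6)   # ~1.82: about 80% more expensive than expected
exp(-0.3)  # ~0.74: about 26% cheaper than expected
```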
Let's use both price and relative price to see how color and cut affect the value of a diamond. We'll compute the average price and average relative price for each combination of colour and cut:
```{r}
colour_cut <- diamonds2 %>%
group_by(color, cut) %>%
summarise(
price = mean(price),
rel_price = mean(rel_price)
)
colour_cut
```
If we look at price, it's hard to see how the quality of the diamond affects the price. The lowest quality diamonds (fair cut with colour J) have the highest average value! This is because those diamonds also tend to be larger: size and quality are confounded.
```{r}
ggplot(colour_cut, aes(color, price)) +
geom_line(aes(group = cut), color = "grey80") +
geom_point(aes(colour = cut))
```
If, however, we plot the relative price, we see the pattern we expect: as the quality of the diamonds decreases, the relative price decreases. The worst quality diamond is 0.61x ($\exp(-0.49)$) the price of an "average" diamond.
```{r}
ggplot(colour_cut, aes(color, rel_price)) +
geom_line(aes(group = cut), color = "grey80") +
geom_point(aes(colour = cut))
```
This technique can be employed in a wide range of situations. Wherever you can explicitly model a strong pattern that you see in a plot, it's worthwhile to use a model to remove that strong pattern so that you can see what interesting trends remain.
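To make the recipe concrete, here's a minimal helper of my own (it doesn't appear in the text) that wraps the fit-then-residuals pattern so you can reuse it on other datasets:
```{r}
# Fit a model for the known, strong pattern, then replace the response with
# the residuals. na.exclude keeps the residuals aligned with the input rows.
remove_trend <- function(df, formula) {
  mod <- lm(formula, data = df, na.action = na.exclude)
  df$residual <- resid(mod)
  df
}
# For example, this reproduces rel_price (under a different column name):
remove_trend(diamonds2, lprice ~ lcarat)
```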
### Exercises
1. What happens if you repeat the above analysis with all diamonds? (Not just
all diamonds with two or fewer carats). What does the strange geometry of
`log(carat)` vs relative price represent? What does the diagonal line
without any points represent?
1. I made an unsupported assertion that lower-quality diamonds tend to
be larger. Support my claim with a plot.
1. Can you create a plot that simultaneously shows the effect of colour,
cut, and clarity on relative price? If there's too much information to
show on one plot, think about how you might create a sequence of plots
to convey the same message.
1. How do depth and table relate to the relative price?
## Texas housing data
We'll continue to explore the connection between modelling and visualisation with the `txhousing` dataset:
```{r}
txhousing
```
This data was collected by the Real Estate Center at Texas A&M University, <http://recenter.tamu.edu/Data/hs/>. The data contains information about 46 Texas cities, recording the number of house sales (`sales`), the total volume of sales (`volume`), the `average` and `median` sale prices, the number of houses listed for sale (`listings`) and the number of months of inventory (`inventory`). Data is recorded monthly from Jan 2000 to Apr 2015, giving 187 entries for each city. \index{Data!txhousing@\texttt{txhousing}}
We're going to explore how sales have varied over time for each city, as this shows some interesting trends and poses some challenges. Let's start with an overview: a time series of sales for each city: \index{Data!longitudinal}
`r columns(1, 1 / 2, 1)`
```{r}
ggplot(txhousing, aes(date, sales)) +
geom_line(aes(group = city), alpha = 1/2)
```
Two factors make it hard to see the long-term trend in this plot:
1. The range of sales varies over multiple orders of magnitude. The biggest
   city, Houston, averages over 4000 sales per month; the smallest city, San
   Marcos, averages only ~20 sales per month.
1. There is a strong seasonal trend: sales are much higher in the summer than
in the winter.
We can fix the first problem by plotting log sales:
```{r}
ggplot(txhousing, aes(date, log(sales))) +
geom_line(aes(group = city), alpha = 1/2)
```
We can fix the second problem using the same technique we used for removing the trend in the diamonds data: we'll fit a linear model and look at the residuals. This time we'll use a categorical predictor to remove the month effect. First we check that the technique works by applying it to a single city. It's always a good idea to start simple so that if something goes wrong you can more easily pinpoint the problem.
`r columns(2, 2 / 3)`
```{r}
abilene <- txhousing %>% filter(city == "Abilene")
ggplot(abilene, aes(date, log(sales))) +
geom_line()
mod <- lm(log(sales) ~ factor(month), data = abilene)
abilene$rel_sales <- resid(mod)
ggplot(abilene, aes(date, rel_sales)) +
geom_line()
```
We can apply this transformation to every city with `group_by()` and `mutate()`. Note the use of the `na.action = na.exclude` argument to `lm()`. Counterintuitively, this ensures that missing values in the input are matched with missing values in the output predictions and residuals. Without this argument, missing values are simply dropped, and the residuals don't line up with the inputs.
```{r}
txhousing <- txhousing %>%
group_by(city) %>%
mutate(rel_sales = resid(lm(log(sales) ~ factor(month), na.action = na.exclude)))
```
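To see why `na.exclude` matters, here's a tiny illustration of my own (it's not in the text) comparing it with the default behaviour:
```{r}
# With the default na.action, rows containing NAs are silently dropped, so the
# residual vector is shorter than the data; na.exclude pads the NAs back in.
df <- data.frame(x = c(1, 2, NA, 4, 5), y = c(1.1, 1.9, 3, 4.2, NA))
length(resid(lm(y ~ x, data = df)))                          # 3
length(resid(lm(y ~ x, data = df, na.action = na.exclude)))  # 5, aligned with df
```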
With this data in hand, we can re-plot our overview. Now that we have log-transformed the data and removed the strong seasonal effects, we can see there is a strong common pattern: a consistent increase from 2000-2007, a drop until 2010 (with quite some noise), and then a gradual rebound. To make that clearer, I've included a summary line that shows the mean relative sales across all cities.
`r columns(1, 3 / 5, 1)`
```{r}
ggplot(txhousing, aes(date, rel_sales)) +
geom_line(aes(group = city), alpha = 1/5) +
geom_line(stat = "summary", fun.y = "mean", colour = "red")
```
(Note that removing the seasonal effect also removed the intercept: we see the trend for each city relative to its average number of sales.)
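We can verify that with a quick check (my own aside, not in the text): because each city's regression includes an intercept, its residuals must average to essentially zero.
```{r}
# Mean relative sales per city; every value should be ~0, up to floating point.
txhousing %>%
  group_by(city) %>%
  summarise(mean_rel = mean(rel_sales, na.rm = TRUE))
```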
### Exercises
1. The final plot shows a lot of short-term noise in the overall trend. How
could you smooth this further to focus on long-term changes?
1. If you look closely (e.g. `+ xlim(2005, 2015)`) at the long-term trend
you'll notice a weird pattern in 2009-2011. It looks like there was a big
dip in 2010. Is this dip "real"? (i.e. can you spot it in the original data)
1. What other variables in the TX housing data show strong
seasonal effects? Does this technique help to remove them?
1. Not all the cities in this data set have complete time series.
Use your dplyr skills to figure out how much data each city
is missing. Display the results with a visualisation.
1. Replicate the computation that `stat_summary()` did with dplyr so
you can plot the data "by hand".
## Visualising models {#sub:modelvis}
The previous examples used the linear model just as a tool for removing trend: we fit the model and immediately threw it away. We didn't care about the model itself, just what it could do for us. But the models themselves contain useful information, and if we keep them around there are many new problems that we can solve:
* We might be interested in cities where the model didn't fit well:
a poorly fitting model suggests that there isn't much of a seasonal pattern,
which contradicts our implicit hypothesis that all cities share a similar
pattern.
* The coefficients themselves might be interesting. In this case, looking
at the coefficients will show us how the seasonal pattern varies between
cities.
* We may want to dive into the details of the model itself, and see exactly
what it says about each observation. For this data, it might help us find
  suspicious data points that might reflect data entry errors.
To take advantage of this data, we need to store the models. We can do this using a new dplyr verb: `do()`. It allows us to store the result of arbitrary computation in a column. Here we'll use it to store that linear model: \indexf{do}
```{r}
models <- txhousing %>%
group_by(city) %>%
do(mod = lm(
log(sales) ~ factor(month),
data = .,
na.action = na.exclude
))
models
```
There are two important things to note in this code:
* `do()` creates a new column called `mod`. This is a special type of column:
instead of containing an atomic vector (a logical, integer, numeric,
or character) like usual, it's a list. Lists are R's most flexible data
structure and can hold anything, including linear models.
* `.` is a special pronoun used by `do()`. It refers to the "current" data
frame. In this example, `do()` fits the model 46 times (once for each
city), each time replacing `.` with the data for one city. \indexc{.}
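To make the `.` pronoun concrete, the `do()` call above is conceptually similar to this hand-rolled loop (a sketch of my own; dplyr's internals differ):
```{r}
# Split the data into one data frame per city, then fit the same model to each.
by_city <- split(txhousing, txhousing$city)
fits <- lapply(by_city, function(d) {
  lm(log(sales) ~ factor(month), data = d, na.action = na.exclude)
})
length(fits)  # 46: one model per city
```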
If you're an experienced modeller, you might wonder why I didn't fit one model to all cities simultaneously. That's a great next step, but it's often useful to start off simple. Once you have a model that works for each city individually, you can figure out how to generalise it to fit all cities simultaneously.
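If you do want to try that, one possible pooled specification is sketched below (my own aside, not something the text fits): an interaction between city and month gives each city its own seasonal pattern within a single model.
```{r, eval = FALSE}
# One model for all cities at once; each city gets its own set of month effects.
mod_all <- lm(
  log(sales) ~ city * factor(month),
  data = txhousing,
  na.action = na.exclude
)
```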
To visualise these models, we'll turn them into tidy data frames. We'll do that with the __broom__ package by David Robinson. \index{broom} \index{Tidy models} \index{Model data}
```{r}
library(broom)
```
Broom provides three key verbs, each corresponding to one of the challenges outlined above:
* `glance()` extracts __model__-level summaries with one row of data for each
model. It contains summary statistics like the $R^2$ and degrees of freedom.
* `tidy()` extracts __coefficient__-level summaries with one row of data for
each coefficient in each model. It contains information about individual
coefficients like their estimate and standard error.
* `augment()` extracts __observation__-level summaries with one row of data for
each observation in each model. It includes variables like the residual and
influence metrics useful for diagnosing outliers.
We'll learn more about each of these functions in the following three sections.
## Model-level summaries
We'll begin by looking at how well the model fit to each city with `glance()`: \indexf{glance}
```{r}
model_sum <- models %>% glance(mod)
model_sum
```
This creates a data frame with one row for each city, and variables that either summarise complexity (e.g. `df`) or fit (e.g. `r.squared`, `p.value`, `AIC`). Since all the models we fit have the same complexity (12 terms: one for each month), we'll focus on the model fit summaries. $R^2$ is a reasonable place to start because it's well known. We can use a dot plot to see the variation across cities:
`r columns(1, 2 / 1)`
```{r}
ggplot(model_sum, aes(r.squared, reorder(city, r.squared))) +
geom_point()
```
It's hard to picture exactly what those values of $R^2$ mean, so it's helpful to pick out a few exemplars. The following code extracts and plots the three cities with the highest and lowest $R^2$:
`r columns(1, 1)`
```{r}
top3 <- c("Midland", "Irving", "Denton County")
bottom3 <- c("Brownsville", "Harlingen", "McAllen")
extreme <- txhousing %>% ungroup() %>%
filter(city %in% c(top3, bottom3), !is.na(sales)) %>%
mutate(city = factor(city, c(top3, bottom3)))
ggplot(extreme, aes(month, log(sales))) +
geom_line(aes(group = year)) +
facet_wrap(~city)
```
The cities with low $R^2$ have weaker seasonal patterns and more variation between years. The data for Harlingen seems particularly noisy.
### Exercises
1. Do your conclusions change if you use a different measurement of model fit?
Why/why not?
1. One possible hypothesis that explains why McAllen, Harlingen and Brownsville
have lower $R^2$ is that they're smaller towns so there are fewer sales and
more noise. Confirm or refute this hypothesis.
1. McAllen, Harlingen and Brownsville seem to have much more year-to-year
variation than Midland, Irving and Denton County. How does the model
change if you also include a linear trend for year? (i.e.
`log(sales) ~ factor(month) + year`).
1. Create a faceted plot that shows the seasonal patterns for all cities.
Order the facets by the $R^2$ for the city.
## Coefficient-level summaries
The model fit summaries suggest that there are some important differences in seasonality between the different cities. Let's dive into those differences by using `tidy()` to extract detail about each individual coefficient: \indexf{tidy}
```{r}
coefs <- models %>% tidy(mod)
coefs
```
We're more interested in the month effect, so we'll do a little extra tidying to keep only the month coefficients, and then extract the month value into a numeric variable:
```{r}
months <- coefs %>%
filter(grepl("factor", term)) %>%
tidyr::extract(term, "month", "(\\d+)", convert = TRUE)
months
```
This is a common pattern. You need to use your data tidying skills at many points in an analysis. Once you have the correct tidy dataset, creating the plot is usually easy. Here we'll put month on the x axis, estimate on the y-axis, and draw one line for each city. I've back-transformed (exponentiated) to make the coefficients more interpretable: these are now ratios of sales compared to January.
`r columns(1, 2/3)`
```{r}
ggplot(months, aes(month, exp(estimate))) +
geom_line(aes(group = city))
```
The pattern seems similar across the cities. The main difference is the strength of the seasonal effect. Let's pull that out and plot it:
`r columns(1, 3/2)`
```{r}
coef_sum <- months %>%
group_by(city) %>%
summarise(max = max(estimate))
ggplot(coef_sum, aes(max, reorder(city, max))) +
geom_point()
```
The cities with the strongest seasonal effect are College Station and San Marcos (both college towns) and Galveston and South Padre Island (beach cities). It makes sense that these cities would have very strong seasonal effects.
### Exercises
1. Pull out the three cities with highest and lowest seasonal effect. Plot
their coefficients.
1. How does the strength of the seasonal effect relate to the $R^2$ for the
   model? Answer with a plot.
```{r, echo = FALSE, eval = FALSE}
coef_sum %>%
left_join(model_sum %>% select(city, r.squared)) %>%
ggplot(aes(max, r.squared)) +
geom_point() +
geom_smooth(method = "lm", se = F)
```
1. You should be extra cautious when your results agree with your prior
beliefs. How can you confirm or refute my hypothesis about the causes
of strong seasonal patterns?
1. Group the diamonds data by cut, clarity and colour. Fit a linear model
`log(price) ~ log(carat)`. What does the intercept tell you? What does
the slope tell you? How do the slope and intercept vary across the
groups? Answer with a plot.
## Observation data
Observation-level data, which includes residual diagnostics, is most useful in the traditional model-fitting scenario. It's less useful in conjunction with visualisation, but we can still find interesting patterns, and it's particularly helpful when we want to track down individual values that seem odd. Extracting observation-level data is the job of the `augment()` function. This adds one row for each observation. It includes the variables used in the original model, and a number of common influence statistics (see `?augment.lm` for more details): \indexf{augment}
```{r, warning=FALSE}
obs_sum <- models %>% augment(mod)
obs_sum
```
For example, it might be interesting to look at the distribution of standardised residuals. (These are residuals standardised to have a variance of one in each model, making them more comparable.) We're looking for unusual values that might need deeper exploration:
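(An aside of mine, not in the text: as far as I know, broom's `.std.resid` corresponds to base R's `rstandard()`, which divides each residual by its estimated standard deviation. You can spot-check one model directly:)
```{r}
# Standardised residuals for a single city's model, computed with base R.
mod1 <- models$mod[[1]]
head(rstandard(mod1))
```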
`r columns(2, 2/3)`
```{r}
ggplot(obs_sum, aes(.std.resid)) +
geom_histogram(binwidth = 0.1)
ggplot(obs_sum, aes(abs(.std.resid))) +
geom_histogram(binwidth = 0.1)
```
A value of 2 seems like a reasonable threshold to explore individually:
```{r}
obs_sum %>%
filter(abs(.std.resid) > 2) %>%
group_by(city) %>%
summarise(n = n(), avg = mean(abs(.std.resid))) %>%
arrange(desc(n))
```
In a real analysis, you'd want to look into these cities in more detail.
### Exercises
1. A common diagnostic plot is fitted values (`.fitted`) vs. residuals
(`.resid`). Do you see any patterns? What if you include the city
or month on the same plot?
1. Create a time series of log(sales) for each city. Highlight points that have
a standardised residual of greater than 2.
## References