Make augment.merMod() more consistent with predict.merMod() when using newdata #141

Open · mccarthy-m-g opened this issue on May 30, 2023 · 2 comments
Labels: infelicity (less than a bug, but worth fixing)

Comments

@mccarthy-m-g

This is related to #125, but I felt like it deserved its own issue.

The behaviour (and documentation) of augment.merMod() when making predictions on new data could use some love. The current behaviour is inconsistent with predict.merMod() and leads to unexpected results that can be misleading or unclear.

Here's a reprex covering some of the issues I found. I think the function needs a rewrite that handles augmenting the original data used to fit the model separately from making predictions on new data. That might mean dropping its dependence on broom::augment_columns() given some of that function's behaviour, or at least adding error checking for the cases where it should fail.

Regarding documentation: nowhere is it documented that you can pass the re.form argument to augment.merMod()/broom::augment_columns(). I tried it on a whim while trying to make predictions, and it happened to (partially) work.
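Until augment.merMod() handles this, one workaround is to skip augment() for new data and bind the output of predict.merMod() to `newdata` directly, so the result inherits predict()'s conventions (one row per row of `newdata`, and an error when a grouping variable is missing unless `re.form = NA`). This is a minimal sketch; `augment_newdata()` is a hypothetical helper, not part of broom or broom.mixed.

```r
library(lme4)

# Hypothetical helper (not part of broom/broom.mixed): attach predictions for
# `newdata` using predict.merMod(), so the output follows predict()'s rules.
augment_newdata <- function(model, newdata, re.form = NULL) {
  out <- as.data.frame(newdata)
  # Predictions as requested: conditional on the random effects by default
  # (re.form = NULL), or population-level when re.form = NA.
  out$.fitted <- unname(predict(model, newdata = newdata, re.form = re.form))
  # Population-level (fixed-effects-only) predictions, analogous to .fixed.
  out$.fixed <- unname(predict(model, newdata = newdata, re.form = NA))
  out
}

lmm1 <- lmer(Reaction ~ Days + (Days | Subject), sleepstudy)

# Conditional on Subject: 8 rows, one per row of newdata.
cond <- augment_newdata(lmm1, expand.grid(Days = 0:3, Subject = c(308, 310)))
# Fixed effects only: 4 rows, and .fitted equals .fixed.
pop <- augment_newdata(lmm1, data.frame(Days = 0:3), re.form = NA)
```

Leaving Subject out of `newdata` without `re.form = NA` makes this helper stop with predict()'s own error, which is the behaviour argued for below.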

library(tibble)
library(lme4)
#> Loading required package: Matrix
library(broom)
library(broom.mixed)

lmm1 <- lmer(Reaction ~ Days + (Days | Subject), sleepstudy)

# When you just want to augment the data used to fit the model everything is
# good and the results are what you'd expect. However, things go wrong once you
# want to make predictions with new data.
augment(lmm1)
#> # A tibble: 180 × 14
#>    Reaction  Days Subject .fitted  .resid   .hat .cooksd .fixed   .mu .offset
#>       <dbl> <dbl> <fct>     <dbl>   <dbl>  <dbl>   <dbl>  <dbl> <dbl>   <dbl>
#>  1     250.     0 308        254.   -4.10 0.229  0.00496   251.  254.       0
#>  2     259.     1 308        273.  -14.6  0.170  0.0402    262.  273.       0
#>  3     251.     2 308        293.  -42.2  0.127  0.226     272.  293.       0
#>  4     321.     3 308        313.    8.78 0.101  0.00731   283.  313.       0
#>  5     357.     4 308        332.   24.5  0.0910 0.0506    293.  332.       0
#>  6     415.     5 308        352.   62.7  0.0981 0.362     304.  352.       0
#>  7     382.     6 308        372.   10.5  0.122  0.0134    314.  372.       0
#>  8     290.     7 308        391. -101.   0.162  1.81      325.  391.       0
#>  9     431.     8 308        411.   19.6  0.219  0.106     335.  411.       0
#> 10     466.     9 308        431.   35.7  0.293  0.571     346.  431.       0
#> # … with 170 more rows, and 4 more variables: .sqrtXwt <dbl>, .sqrtrwt <dbl>,
#> #   .weights <dbl>, .wtres <dbl>

# For context, first let's cover predict.merMod()'s behaviour. ----------------

# If you want to make predictions conditioned on random effects, you need to
# provide data for the random effects groups you want to make predictions for.
# For example, here we make predictions for Subjects 308 and 310 on Days 0-3.
# The resulting vector is the same length as `newdata`.
predict(lmm1, newdata = expand.grid(Days = 0:3, Subject = c(308, 310)))
#>        1        2        3        4        5        6        7        8 
#> 253.6637 273.3299 292.9962 312.6624 212.4447 217.4631 222.4816 227.5000

# If we don't provide data for any Subjects, we get an error. This is expected
# and fine, but as we'll see later, augment.merMod() ignores this convention.
predict(lmm1, newdata = tibble(Days = 0:3))
#> Error in eval(predvars, data, env): object 'Subject' not found

# If we don't want to condition on the random effects, and instead want fixed
# effect predictions, we need to be explicit about that with `re.form = NA`.
# Here too the resulting vector is the same length as `newdata`.
predict(lmm1, newdata = tibble(Days = 0:3), re.form = NA)
#>        1        2        3        4 
#> 251.4051 261.8724 272.3397 282.8070
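```r
# A minimal sketch (hypothetical helper, not part of broom.mixed) of the
# contract predict.merMod() follows above: one prediction per row of
# `newdata`. An augmented data frame built from `newdata` could be checked
# against predict() called with the same `re.form`.
library(lme4)
lmm1 <- lmer(Reaction ~ Days + (Days | Subject), sleepstudy)

consistent_with_predict <- function(aug, model, newdata, re.form = NULL) {
  expected <- unname(predict(model, newdata = newdata, re.form = re.form))
  # The row count must match newdata before the values can be compared.
  nrow(aug) == nrow(newdata) &&
    isTRUE(all.equal(unname(aug$.fitted), expected))
}

nd <- data.frame(Days = 0:3)
good <- data.frame(nd, .fitted = unname(predict(lmm1, newdata = nd, re.form = NA)))
# Mimics the 180-row shape returned when newdata is silently ignored.
bad <- data.frame(Days = rep(0:3, 45), .fitted = unname(fitted(lmm1)))
consistent_with_predict(good, lmm1, nd, re.form = NA)  # TRUE
consistent_with_predict(bad, lmm1, nd, re.form = NA)   # FALSE: 180 rows, not 4
```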

# For more context, next let's cover augment_columns()'s behaviour ------------

# augment_columns() is a developer-facing function intended for use in the
# internals of augment methods. augment.merMod() uses it as a starting point
# and then wrangles its output further; that wrangling is what causes the
# problems shown below.

# augment_columns() has consistent behaviour with predict.merMod() in some but
# not all cases. Here I've simply repeated the same three predict() examples
# from above.

# Consistent with predict.merMod(), results are what you'd expect.
augment_columns(lmm1, newdata = expand.grid(Days = 0:3, Subject = c(308, 310)))
#> # A tibble: 8 × 3
#>    Days Subject .fitted
#>   <int>   <dbl>   <dbl>
#> 1     0     308    254.
#> 2     1     308    273.
#> 3     2     308    293.
#> 4     3     308    313.
#> 5     0     310    212.
#> 6     1     310    217.
#> 7     2     310    222.
#> 8     3     310    227.

# No error is thrown this time, even though no Subject data was provided. The
# result appears to be the predictions for all subjects (180 rows, one per
# observation in sleepstudy), with the Days vector recycled to that length.
# Nothing signals this: there is no Subject column to cross-reference against,
# and the Days column no longer corresponds to the .fitted column.
augment_columns(lmm1, newdata = tibble(Days = 0:3))
#> # A tibble: 180 × 2
#>     Days .fitted
#>    <int>   <dbl>
#>  1     0    254.
#>  2     1    273.
#>  3     2    293.
#>  4     3    313.
#>  5     0    332.
#>  6     1    352.
#>  7     2    372.
#>  8     3    391.
#>  9     0    411.
#> 10     1    431.
#> # … with 170 more rows
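# For comparison, the first 10 .fitted values above match Subject 308's
# predictions for Days 0-9:
predict(lmm1, newdata = tibble(Days = 0:9, Subject = 308))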
#>        1        2        3        4        5        6        7        8 
#> 253.6637 273.3299 292.9962 312.6624 332.3287 351.9950 371.6612 391.3275 
#>        9       10 
#> 410.9937 430.6600

# It also doesn't matter if you're explicit about the `re.form`; it still doesn't
# throw an error.
augment_columns(lmm1, newdata = tibble(Days = 0:3), re.form = ~ (Days | Subject))
#> # A tibble: 180 × 2
#>     Days .fitted
#>    <int>   <dbl>
#>  1     0    254.
#>  2     1    273.
#>  3     2    293.
#>  4     3    313.
#>  5     0    332.
#>  6     1    352.
#>  7     2    372.
#>  8     3    391.
#>  9     0    411.
#> 10     1    431.
#> # … with 170 more rows

# Consistent with predict.merMod(), results are what you'd expect.
augment_columns(lmm1, newdata = data.frame(Days = 0:3), re.form = NA)
#> # A tibble: 4 × 2
#>    Days .fitted
#>   <int>   <dbl>
#> 1     0    251.
#> 2     1    262.
#> 3     2    272.
#> 4     3    283.
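```r
# A quick check of the "all subjects" reading above, assuming the 180 .fitted
# values are the model's in-sample fitted values: fitted(lmm1) has one value
# per row of sleepstudy, and its first ten rows (Subject 308, Days 0-9) equal
# predict() for Subject 308 on Days 0-9.
library(lme4)
lmm1 <- lmer(Reaction ~ Days + (Days | Subject), sleepstudy)

in_sample <- unname(fitted(lmm1))
subject_308 <- unname(predict(lmm1, newdata = data.frame(Days = 0:9, Subject = 308)))
length(in_sample)                        # 180
all.equal(in_sample[1:10], subject_308)  # TRUE
```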

# Now let's look at augment.merMod()'s behaviour ------------------------------

## Predictions conditioned on random effects:

# This throws a warning due to some of the aforementioned wrangling that happens
# inside augment.merMod() after getting the data from augment_columns().
# Specifically, the `respCols` (.mu, .offset, etc.) that are getting bound to
# the augment_columns() data frame come from the original data used to fit the
# model, rather than the new data.
augment(lmm1, newdata = expand.grid(Days = 0:3, Subject = c(308, 310)))
#> Warning in indices[which(stats::complete.cases(original))] <- seq_len(nrow(x)):
#> number of items to replace is not a multiple of replacement length
#> # A tibble: 8 × 9
#>    Days Subject .fitted   .mu .offset .sqrtXwt .sqrtrwt .weights  .wtres
#>   <int>   <dbl>   <dbl> <dbl>   <dbl>    <dbl>    <dbl>    <dbl>   <dbl>
#> 1     0     308    254.  254.       0        1        1        1   -4.10
#> 2     1     308    273.  273.       0        1        1        1  -14.6 
#> 3     2     308    293.  293.       0        1        1        1  -42.2 
#> 4     3     308    313.  313.       0        1        1        1    8.78
#> 5     0     310    212.  332.       0        1        1        1   24.5 
#> 6     1     310    217.  352.       0        1        1        1   62.7 
#> 7     2     310    222.  372.       0        1        1        1   10.5 
#> 8     3     310    227.  391.       0        1        1        1 -101.

# As a consequence, the respCols don't actually correspond to the new data.
# This is clear if you look at the .mu and .wtres columns above: all the
# values come from the first 8 rows of the data used to fit the model, which
# in this case means they all come from Subject 308. This is bad.
lmm1@resp$mu[1:8]
#> [1] 253.6637 273.3299 292.9962 312.6624 332.3287 351.9950 371.6612 391.3275
lmm1@resp$wtres[1:8]
#> [1]   -4.103656  -14.625218  -42.195579    8.777359   24.523197   62.695136
#> [7]   10.542574 -101.178888
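```r
# A minimal sketch of one alternative, assuming .mu on new data should be the
# predicted mean for the rows of `newdata` rather than x@resp$mu, which always
# has one value per row of the data used to fit the model. predict.merMod()
# with type = "response" returns that mean; for this Gaussian LMM it equals
# the link-scale prediction. Columns that need an observed response (such as
# .wtres) have no analogue on new data and are left out here.
library(lme4)
lmm1 <- lmer(Reaction ~ Days + (Days | Subject), sleepstudy)

nd <- expand.grid(Days = 0:3, Subject = c(308, 310))
mu_new <- unname(predict(lmm1, newdata = nd, type = "response"))
length(mu_new)           # one value per row of nd
length(lmm1@resp$mu)     # always nrow(sleepstudy), whatever newdata is
```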

## Leaving Subject out of newdata:

# This has the same problems as augment_columns(). The respCols at least
# correspond to .fitted and .fixed now, but this should really throw an error
# instead.
augment(lmm1, newdata = tibble(Days = 0:3))
#> # A tibble: 180 × 9
#>     Days .fitted .fixed   .mu .offset .sqrtXwt .sqrtrwt .weights  .wtres
#>    <int>   <dbl>  <dbl> <dbl>   <dbl>    <dbl>    <dbl>    <dbl>   <dbl>
#>  1     0    254.   251.  254.       0        1        1        1   -4.10
#>  2     1    273.   262.  273.       0        1        1        1  -14.6 
#>  3     2    293.   272.  293.       0        1        1        1  -42.2 
#>  4     3    313.   283.  313.       0        1        1        1    8.78
#>  5     0    332.   293.  332.       0        1        1        1   24.5 
#>  6     1    352.   304.  352.       0        1        1        1   62.7 
#>  7     2    372.   314.  372.       0        1        1        1   10.5 
#>  8     3    391.   325.  391.       0        1        1        1 -101.  
#>  9     0    411.   335.  411.       0        1        1        1   19.6 
#> 10     1    431.   346.  431.       0        1        1        1   35.7 
#> # … with 170 more rows

## Fixed effect predictions:

# Similar to the problem with the predictions conditioned on random effects,
# the `respCols` have no correspondence to the new data. This is obvious if
# you make predictions on different days.
augment(lmm1, newdata = data.frame(Days = 0:3), re.form = NA)
#> Warning in indices[which(stats::complete.cases(original))] <- seq_len(nrow(x)):
#> number of items to replace is not a multiple of replacement length
#> # A tibble: 4 × 8
#>    Days .fitted   .mu .offset .sqrtXwt .sqrtrwt .weights .wtres
#>   <int>   <dbl> <dbl>   <dbl>    <dbl>    <dbl>    <dbl>  <dbl>
#> 1     0    251.  254.       0        1        1        1  -4.10
#> 2     1    262.  273.       0        1        1        1 -14.6 
#> 3     2    272.  293.       0        1        1        1 -42.2 
#> 4     3    283.  313.       0        1        1        1   8.78
augment(lmm1, newdata = data.frame(Days = 6:9), re.form = NA)
#> Warning in indices[which(stats::complete.cases(original))] <- seq_len(nrow(x)):
#> number of items to replace is not a multiple of replacement length
#> # A tibble: 4 × 8
#>    Days .fitted   .mu .offset .sqrtXwt .sqrtrwt .weights .wtres
#>   <int>   <dbl> <dbl>   <dbl>    <dbl>    <dbl>    <dbl>  <dbl>
#> 1     6    314.  254.       0        1        1        1  -4.10
#> 2     7    325.  273.       0        1        1        1 -14.6 
#> 3     8    335.  293.       0        1        1        1 -42.2 
#> 4     9    346.  313.       0        1        1        1   8.78

Created on 2023-05-30 with reprex v2.0.2
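In the `re.form = NA` examples above, .fitted is right and only the respCols are wrong. As a cross-check, population-level predictions are the fixed-effects design matrix for `newdata` multiplied by `fixef()`. This sketch assumes the fixed part of lmm1's formula is `~ Days`, which holds for this model but would need adapting for others.

```r
library(lme4)
lmm1 <- lmer(Reaction ~ Days + (Days | Subject), sleepstudy)

nd <- data.frame(Days = 6:9)
# Fixed-effects design matrix (intercept + Days); assumed to match the fixed
# part of lmm1's formula.
X <- model.matrix(~ Days, data = nd)
by_hand <- drop(X %*% fixef(lmm1))
via_predict <- predict(lmm1, newdata = nd, re.form = NA)
cbind(by_hand, via_predict)
```

Both columns agree with the .fitted values for Days 6-9 above, while the reported .mu values (254, 273, ...) do not.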

@bbolker
Owner

bbolker commented May 30, 2023

Thanks! Given the number of open issues and their heterogeneity I really think I need to go through and tag them with 'feature-request'/'enhancement' etc. so I can prioritize them and fix the ones that really need to be fixed (I would put this one in that category ... maybe I'll add an 'infelicity' tag [that's Bill Venables's neutral term for "it's not technically a bug but it's definitely bad behaviour"] that's just below 'bug' in priority ...)

@mccarthy-m-g
Author

You're welcome! Haha, I like the infelicity tag. It's probably also worth checking whether this issue applies to any of the other functions that rely on augment_columns().
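One rough way to find candidates is to search the deparsed source of every `augment.*` function in the broom.mixed namespace for `augment_columns`. This is a text-search heuristic, not a call graph: it can miss indirect calls and will flag any mention of the name.

```r
# List broom.mixed augment methods whose source mentions augment_columns.
ns <- asNamespace("broom.mixed")
augment_fns <- ls(ns, pattern = "^augment\\.")
mentions_augment_columns <- vapply(augment_fns, function(f) {
  any(grepl("augment_columns", deparse(get(f, envir = ns)), fixed = TRUE))
}, logical(1))
uses_augment_columns <- augment_fns[mentions_augment_columns]
uses_augment_columns
```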

Here's a very rough rewrite of augment.merMod(), basically just adding some conditionals to the existing code. It might be as simple as something like this.

library(tibble)
library(lme4)
#> Loading required package: Matrix
library(broom)
library(broom.mixed)

lmm1 <- lmer(Reaction ~ Days + (Days | Subject), sleepstudy)

augment.merMod <- function(x, data = stats::model.frame(x), newdata, ...) {
  # Augment the original data used to fit the model
  if (missing(newdata)) {
    # no newdata supplied: augment the data used to fit the model
    newdata <- NULL
    ret <- suppressMessages(augment_columns(x, data, newdata, se.fit = NULL, ...))

    # add predictions with no random effects (population means)
    predictions <- stats::predict(x, re.form = NA)
    # some cases, such as values returned from nlmer, return more than one
    # prediction per observation. Not clear how those cases would be tidied
    if (length(predictions) == nrow(ret)) {
      ret$.fixed <- predictions
    }

    # columns to extract from resp reference object
    # these include relevant ones that could be present in lmResp, glmResp,
    # or nlsResp objects

    respCols <- c(
      "mu", "offset", "sqrtXwt", "sqrtrwt", "weights",
      "wtres", "gam", "eta"
    )
    cols <- lapply(respCols, function(cc) x@resp[[cc]])
    names(cols) <- paste0(".", respCols)

    ## remove too-long fields and empty fields
    n_vals <- vapply(cols,length,1L)
    min_n <- min(n_vals[n_vals>0])

    cols <- dplyr::bind_cols(cols[n_vals==min_n])

    cols <- broom.mixed:::insert_NAs(cols, ret)
    if (length(cols) > 0) {
      ret <- dplyr::bind_cols(ret, cols)
    }

    return(broom.mixed:::unrowname(ret))

  # Make predictions on new data
  } else {
    ret <- suppressMessages(augment_columns(x, data, newdata, se.fit = NULL, ...))

    # Throw an error when re.form isn't specified, and there's no grouping
    # variable in newdata. This is fragile but just intended for demonstration.
    # Note: Can't use missing() since re.form comes from the ... args.
    if (!methods::hasArg(re.form) && ncol(stats::model.frame(x)) != ncol(ret)) {
      stop("No data provided for grouping variable.")
    }

    # add predictions on newdata with no random effects (population means)
    predictions <- stats::predict(x, newdata, re.form = NA)
    # some cases, such as values returned from nlmer, return more than one
    # prediction per observation. Not clear how those cases would be tidied
    if (length(predictions) == nrow(ret)) {
      ret$.fixed <- predictions
    }

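    # Placeholder: the remaining respCols (.sqrtXwt, .wtres, and so on) would
    # also be filled with NA here; `etc.` just stands in for them.
    tibble::tibble(ret, .mu = NA, .offset = NA, etc. = NA)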

  }
}

augment(lmm1)
#> # A tibble: 180 × 14
#>    Reaction  Days Subject .fitted  .resid   .hat .cooksd .fixed   .mu .offset
#>       <dbl> <dbl> <fct>     <dbl>   <dbl>  <dbl>   <dbl>  <dbl> <dbl>   <dbl>
#>  1     250.     0 308        254.   -4.10 0.229  0.00496   251.  254.       0
#>  2     259.     1 308        273.  -14.6  0.170  0.0402    262.  273.       0
#>  3     251.     2 308        293.  -42.2  0.127  0.226     272.  293.       0
#>  4     321.     3 308        313.    8.78 0.101  0.00731   283.  313.       0
#>  5     357.     4 308        332.   24.5  0.0910 0.0506    293.  332.       0
#>  6     415.     5 308        352.   62.7  0.0981 0.362     304.  352.       0
#>  7     382.     6 308        372.   10.5  0.122  0.0134    314.  372.       0
#>  8     290.     7 308        391. -101.   0.162  1.81      325.  391.       0
#>  9     431.     8 308        411.   19.6  0.219  0.106     335.  411.       0
#> 10     466.     9 308        431.   35.7  0.293  0.571     346.  431.       0
#> # … with 170 more rows, and 4 more variables: .sqrtXwt <dbl>, .sqrtrwt <dbl>,
#> #   .weights <dbl>, .wtres <dbl>
augment(lmm1, newdata = expand.grid(Days = 0:3, Subject = c(308, 310)))
#> # A tibble: 8 × 7
#>    Days Subject .fitted .fixed .mu   .offset etc. 
#>   <int>   <dbl>   <dbl>  <dbl> <lgl> <lgl>   <lgl>
#> 1     0     308    254.   251. NA    NA      NA   
#> 2     1     308    273.   262. NA    NA      NA   
#> 3     2     308    293.   272. NA    NA      NA   
#> 4     3     308    313.   283. NA    NA      NA   
#> 5     0     310    212.   251. NA    NA      NA   
#> 6     1     310    217.   262. NA    NA      NA   
#> 7     2     310    222.   272. NA    NA      NA   
#> 8     3     310    227.   283. NA    NA      NA
augment(lmm1, newdata = tibble(Days = 0:3))
#> Error in augment.merMod(lmm1, newdata = tibble(Days = 0:3)): No data provided for grouping variable.
augment(lmm1, newdata = data.frame(Days = 0:3), re.form = NA)
#> # A tibble: 4 × 6
#>    Days .fitted .fixed .mu   .offset etc. 
#>   <int>   <dbl>  <dbl> <lgl> <lgl>   <lgl>
#> 1     0    251.   251. NA    NA      NA   
#> 2     1    262.   262. NA    NA      NA   
#> 3     2    272.   272. NA    NA      NA   
#> 4     3    283.   283. NA    NA      NA
augment(lmm1, newdata = data.frame(Days = 6:9), re.form = NA)
#> # A tibble: 4 × 6
#>    Days .fitted .fixed .mu   .offset etc. 
#>   <int>   <dbl>  <dbl> <lgl> <lgl>   <lgl>
#> 1     6    314.   314. NA    NA      NA   
#> 2     7    325.   325. NA    NA      NA   
#> 3     8    335.   335. NA    NA      NA   
#> 4     9    346.   346. NA    NA      NA

Created on 2023-05-30 with reprex v2.0.2
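As the draft's own comment says, the `ncol()` comparison is fragile: for example, a `newdata` with Days plus an unrelated column but no Subject would pass it. A sturdier sketch asks the model for its grouping factors via lme4's `getME(model, "flist")`. It assumes `re.form = NULL` means "condition on all random effects" and `re.form = NA` means "none"; a formula-valued `re.form` is treated like `NULL` (all grouping variables required), and nested grouping factors with names such as `"b:a"` are not handled. The helper name `check_grouping_vars()` is hypothetical.

```r
library(lme4)

# Hypothetical helper: stop early if newdata lacks a grouping variable that
# predict() would need for conditional predictions.
check_grouping_vars <- function(model, newdata, re.form = NULL) {
  if (length(re.form) == 1 && is.na(re.form)) {
    return(invisible(TRUE))  # population-level predictions need no groups
  }
  groups <- names(getME(model, "flist"))
  missing_groups <- setdiff(groups, names(newdata))
  if (length(missing_groups) > 0) {
    stop("newdata is missing grouping variable(s): ",
         paste(missing_groups, collapse = ", "),
         "; supply them or use re.form = NA", call. = FALSE)
  }
  invisible(TRUE)
}

lmm1 <- lmer(Reaction ~ Days + (Days | Subject), sleepstudy)
check_grouping_vars(lmm1, expand.grid(Days = 0:3, Subject = c(308, 310)))  # passes
check_grouping_vars(lmm1, data.frame(Days = 0:3), re.form = NA)           # passes
try(check_grouping_vars(lmm1, data.frame(Days = 0:3)))                     # error
```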

@bbolker added the infelicity (less than a bug, but worth fixing) label on May 30, 2023