
Exact polytime solution for linear regression #130

Closed
dswatson opened this issue Apr 18, 2024 · 9 comments
Labels
enhancement New feature or request

Comments

@dswatson

I may be wrong, but it looks like kernelshap takes a sampling approach even in the case of linear regression. There is a simple, closed-form solution in this case, at least under the assumption that features are independent (which is always the case in kernelshap, no?). See Eq. 3 of https://arxiv.org/abs/1903.10464.
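For context, the closed-form result under independence attributes to feature $j$ the amount $\beta_j (x_j - \bar{x}_j)$, where $\bar{x}_j$ is the feature's background mean. A minimal sketch in R (the object names here are made up for illustration; this is not kernelshap's API):

# Closed-form SHAP for a plain lm() under feature independence:
# phi_j = beta_j * (x_j - mean of x_j over the background data)
fit <- lm(Sepal.Width ~ Sepal.Length + Petal.Length, data = iris)
X  <- head(iris[, c("Sepal.Length", "Petal.Length")])  # rows to explain
bg <- iris[, c("Sepal.Length", "Petal.Length")]        # background data
beta <- coef(fit)[colnames(X)]
S <- sweep(sweep(as.matrix(X), 2L, colMeans(bg)), 2L, beta, "*")
phi0 <- mean(predict(fit, newdata = bg))  # baseline: average prediction

# Sanity check: baseline + row sums of S reproduce the predictions
stopifnot(all.equal(unname(rowSums(S) + phi0), unname(predict(fit, newdata = X))))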

@mayer79
Collaborator

mayer79 commented Apr 18, 2024

Yes. But since linear regressions typically involve splines/polynomials and interactions, things might get tricky, at least when explaining the original features.

@dswatson
Author

I see. So this would apply in cases where the lm function itself encodes those transformations, correct? Like lm(y ~ I(x^2) + x*z) or something? Perhaps there could be an internal check for such operations, so kernelshap knows whether to apply the closed-form solution, potentially with a warning?

@mayer79
Collaborator

mayer79 commented Apr 18, 2024

Exactly. We could actually think of a flag additive = TRUE, also in the new permshap() function. Or simply add an additiveshap() explainer.

@mayer79 mayer79 added the enhancement New feature or request label Apr 18, 2024
@mayer79
Collaborator

mayer79 commented May 22, 2024

> I see. So this would apply in cases where the lm function itself encodes those transformations, correct? Like lm(y ~ I(x^2) + x*z) or something? Perhaps there could be an internal check for such operations, so kernelshap knows whether to apply the closed-form solution, potentially with a warning?

I was thinking a bit about how to implement additiveshap(X, bg_X, weights = NULL). One option is using predict(..., type = "terms"). This is the approach that {fastshap} uses for its exact linear explainer (via lm()).

library(mgcv)
library(splines)

# Factor + natural spline: one "terms" column per model term
fit <- lm(Sepal.Width ~ Species + ns(Sepal.Length, df = 5), data = iris)
predict(fit, newdata = head(iris), type = "terms")

#     Species ns(Sepal.Length, df = 5)
# 1 0.7462667              -0.29796891
# 2 0.7462667              -0.46501987
# 3 0.7462667              -0.61525943
# 4 0.7462667              -0.68560767
# 5 0.7462667              -0.38366037
# 6 0.7462667              -0.05111989

# The same idea works for a GAM smoother
fit <- gam(Sepal.Width ~ Species + s(Sepal.Length), data = iris)
predict(fit, newdata = head(iris), type = "terms")

#   Species s(Sepal.Length)
# 1       0      -0.2921379
# 2       0      -0.4304525
# 3       0      -0.5733386
# 4       0      -0.6452925
# 5       0      -0.3604078
# 6       0      -0.1055424

# With an interaction, the interaction shows up as its own column
fit <- lm(Sepal.Width ~ Species * ns(Sepal.Length, df = 5), data = iris)
predict(fit, newdata = head(iris), type = "terms")

#   Species ns(Sepal.Length, df = 5) Species:ns(Sepal.Length, df = 5)
# 1 2.27277                 2.593177                        -4.437103
# 2 2.27277                 2.430295                        -4.437103
# 3 2.27277                 2.285862                        -4.437103
# 4 2.27277                 2.218714                        -4.437103
# 5 2.27277                 2.509137                        -4.437103
# 6 2.27277                 2.870235                        -4.437103

This is quite neat because we don't even need to collapse values from the Species dummies. Interactions would pop out as separate columns.

Things that are not (yet) clear to me:

  1. Your reference (and others I have in mind) clearly states that the SHAP values equal $(X - \bar{X})\beta$, while via "terms" we use non-centered $X$. I think centering is important for calculating SHAP importance.
  2. Related to 1 is the calculation of a meaningful baseline value (relevant only for waterfall-like charts). Without centering, it equals the intercept.

Can we use some "post-hoc" centering of the SHAP values, e.g., subtracting from the terms $X\beta$ their average calculated from the background data?
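One way this could look, as a rough sketch (here, fit is one of the additive fits above, X holds the rows to explain, and bg_X stands for a hypothetical background sample):

# Post-hoc centering: subtract the background average of each term column
S  <- predict(fit, newdata = X,    type = "terms")
S0 <- predict(fit, newdata = bg_X, type = "terms")
S_centered <- sweep(S, 2L, colMeans(S0))

# Baseline = average background prediction; by construction,
# baseline + rowSums(S_centered) reproduces predict(fit, newdata = X)
baseline <- mean(predict(fit, newdata = bg_X))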

@dswatson
Author

Cool! I did not know about the type = "terms" option. It looks like that's exactly what we want, right? The function appears to do mean centering by default (see ?predict.lm). This also gives you the baseline phi0 for free. The only thing that remains would be some check/warning about assuming independence when a single input variable defines multiple columns in the design matrix, e.g., due to interactions or polynomial expansion.
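For the record, the baseline is attached to the centered terms as an attribute (using one of the lm() fits above):

S <- predict(fit, newdata = head(iris), type = "terms")
attr(S, "constant")  # phi0: the average prediction on the training data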

@mayer79
Collaborator

mayer79 commented May 23, 2024

Neat - was not aware of the centering!

Regarding your check: I am actually quite confident that we don't need to explicitly deal with this:

  • If the user makes use of poly(), ns(), dummy encoding etc., then we get a single column for that feature. Perfect.
  • If the user adds interactions, then the output will contain 1 extra column for each interaction, which seems acceptable, given that the name of the function is additive_shap().

The API could be:

additive_shap(object, X = NULL)

And the output would look like that of kernelshap::permshap().

@dswatson
Author

Sounds good! I thought your original concern was that, while the independence assumption is probably false in many real-world cases, it is self-evidently bogus when one variable is a deterministic function of another (e.g., a polynomial fit with features x and x^2). Then again, hopefully users are aware of this.

@mayer79
Collaborator

mayer79 commented May 23, 2024

Yeah...

Actually, there is a technical complication: Take, e.g., the model

y ~ log(x) + log(z / x)

In this case, "terms" contains the two columns log(x) and log(z / x), but there is no obvious mapping to the columns x and z in the data matrix.

In such cases, we could ask the user to provide a 1:m mapping between the names of X and the names of "terms", e.g., mapping = list(x = 'log(x)', z = 'log(z / x)'). Or: mapping = list(x = c('log(x)', 'log(z / x)')).

For interactions, I think such a mapping cannot make sense; we can raise an error there. Here is a draft:

additive_shap <- function(object, X, verbose = TRUE, mapping = NULL) {
  stopifnot(inherits(object, c("gam", "glm", "lm")))

  # Interactions break the feature <-> term correspondence
  tt <- terms(object)
  if (any(attr(tt, "order") > 1)) {
    stop("Additive SHAP not appropriate for models with interactions.")
  }

  txt <- "Exact additive SHAP via predict(..., type = 'terms')"
  if (verbose) {
    message(txt)
  }

  # Inspired by fastshap:::explain.lm(..., exact = TRUE)
  S <- stats::predict(object, newdata = X, type = "terms")

  # We need a 1:m mapping from columns in X to columns in "terms"
  if (is.null(mapping)) {
    xcols <- all.vars(formula(object), unique = FALSE)[-1L]
    if (anyDuplicated(xcols) || length(xcols) != ncol(S)) {
      stop(
        "Unclear 1:m mapping between columns in X and columns in 'terms'. ",
        "Pass an explicit mapping as ",
        "list(name_in_x = c('name_in_terms_1', 'name_in_terms_2'), ...)"
      )
    }
    colnames(S) <- xcols
  } else {
    S <- map_shap(S, X, mapping = mapping)  # TBD
  }

  baseline <- as.vector(attr(S, "constant"))

  out <- list(S = S, X = X, baseline = baseline, txt = txt)
  class(out) <- "additive_shap"
  return(out)
}
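
The map_shap() helper marked TBD above could, for instance, collapse the "terms" columns onto the original features by summation. A hypothetical sketch (not part of the draft):

# Hypothetical helper: sum the "terms" columns assigned to each feature.
# Assumes `mapping` claims every column of S exactly once.
map_shap <- function(S, X, mapping) {
  stopifnot(setequal(unlist(mapping), colnames(S)))
  out <- vapply(
    mapping,
    function(cols) rowSums(S[, cols, drop = FALSE]),
    FUN.VALUE = numeric(nrow(S))
  )
  attr(out, "constant") <- attr(S, "constant")
  out
}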

Edit: Working on a better solution...

@mayer79 mayer79 mentioned this issue May 25, 2024
@mayer79
Collaborator

mayer79 commented May 26, 2024

Implemented via #132
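
A usage sketch, assuming the merged function kept the draft's interface (see #132 and the package documentation for the final API):

library(splines)
fit <- lm(Sepal.Width ~ Species + ns(Sepal.Length, df = 5), data = iris)
s <- kernelshap::additive_shap(fit, X = iris)
s$S         # one column of SHAP values per model term
s$baseline  # baseline value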

@mayer79 mayer79 closed this as completed May 26, 2024