TreeSHAP for conditional expectations of variable subsets #262

Open
samkodes opened this issue Jun 16, 2021 · 11 comments

@samkodes

samkodes commented Jun 16, 2021

Hi all,

I am intrigued by the experimental result of Aas, Jullum, and Løland (https://arxiv.org/abs/1903.10464) that TreeSHAP fails to capture covariate dependence in any meaningful way. Do you have any insight into why this may be?

I ask because the conditional dependence estimation procedures shapr implements, particularly the empirical method, seem very similar to the adaptive nearest-neighbour interpretation of random forests, e.g. the causal forests used by the grf package (Athey and Wager https://arxiv.org/abs/1902.07409). A TreeSHAP-like algorithm might be an effective way of calculating the conditional expectation for a subset of variables using the adaptive neighbourhoods already learned by an underlying random forest model.

I wonder if part of the reason TreeSHAP failed the tests is that it was run on a boosted model rather than a random forest (as far as I know, boosted trees don't have a nearest-neighbour interpretation). Would this be worth investigating further?

Update:

The intuition might be that because of the scale reduction in a boosted tree ensemble, removal of some covariates tends to make the resulting expectations rather unpredictable (mostly dependent on the high-variance initial large-scale trees), whereas the redundancy of a bagged tree ensemble means that the remaining covariates may still informatively partition the space.

If TreeSHAP on random forests does work for estimating conditional expectations, then it might be a viable option to build into the package even for non-forest underlying models. The focus would not be on estimating p(S' | S) (the distribution of the missing variables conditional on the present ones) but instead on estimating the required conditional expectation directly - E(f(S',S)|S).
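
To make the "estimate the conditional expectation directly" idea concrete, here is a minimal sketch (mine, not from any of the papers): fit a random forest surrogate that regresses the underlying model's predictions on the present variables only, and use its predictions as E(f(S',S)|S). The function name direct_cond_exp and all variable names are illustrative placeholders.

# Hedged sketch: approximate E(f(S',S)|S) by regressing model predictions on X_S.
library(ranger)

direct_cond_exp <- function(f_hat, X, S, x_test, num.trees = 400) {
  # f_hat: predictions f(X) of the model to be explained, on the training data
  # X: training features (data.frame); S: names of the conditioning variables
  # x_test: one-row data.frame holding the test point's values for the S columns
  dat <- data.frame(X[, S, drop = FALSE], .f = f_hat)
  surrogate <- ranger(.f ~ ., data = dat, num.trees = num.trees)
  predict(surrogate, data = x_test[, S, drop = FALSE])$predictions
}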

I would be happy to run some tests with random forests using the R treeshap package on the test data sets used in the paper if you could provide them?

@samkodes
Author

samkodes commented Jun 18, 2021

Early attempts at reproducing the stepwise-constant experiments from the paper with Gaussian data show some improvement for RF vs XGB, and, oddly enough, "extremely random" forests (ranger split criterion = extratrees) seem to have an edge over both. It doesn't seem to get as close as the Gaussian shapr option, however.

There is also this very interesting paper on "SHAFF" (https://arxiv.org/abs/2105.11724), which appears to have theoretical results for a method that replaces TreeSHAP with a "Projected Forest" transformation of an underlying random forest. The main difference vs TreeSHAP is essentially that intersections rather than unions of nodes are used (with a depth limit trick to ensure non-zero nodes). This fits with some of my intuition about the dependence of the method on the way the tree partitions space, and the diagram makes it clear in retrospect that intersection makes much more sense than union if an asymptotic guarantee is desired. OOOF! There are some smart people out there!

To clarify: The SHAFF approach uses intersections of leaf regions projected into the selected variable subspace. Their diagram suggests intuition for why this works. Though the authors do not discuss TreeSHAP, in their terms TreeSHAP uses unions of projected leaf regions; if some leaves are elongated in the selected variable subspace, this will result in some very distant data points being included in the conditional distribution. Conversely, the use of intersections will tend to ensure that all data points used for the conditional distribution are local. This intuition explains why some algorithms for tree growth will work better than others for TreeSHAP - any tree-growing algorithm that produces lots of leaves that are elongated in the selected variable subspace will do poorly at capturing the conditional distribution. Extremely random forests may do better than regular random forests because they may be less likely to have very elongated leaf nodes.
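
To illustrate the union-vs-intersection point in one projected dimension, here is a toy sketch (my own construction, not code from either paper); the intervals and data coordinates are made up:

# Hedged toy example: three leaves compatible with a test point project to these
# intervals in the selected 1-d subspace; compare which training points a
# union-based vs intersection-based neighbourhood would keep.
leaf_proj <- list(c(0, 10), c(2, 3), c(1, 4))   # projected leaf intervals [lo, hi]; the first is "elongated"
x_s       <- 2.5                                # test point's coordinate in the subspace
train_s   <- c(0.5, 1.5, 2.2, 2.8, 3.5, 9.0)    # training coordinates in the subspace

in_interval <- function(z, iv) z >= iv[1] & z <= iv[2]
stopifnot(all(sapply(leaf_proj, in_interval, z = x_s)))  # all three leaves are compatible with x_s

in_union        <- Reduce(`|`, lapply(leaf_proj, in_interval, z = train_s))
in_intersection <- Reduce(`&`, lapply(leaf_proj, in_interval, z = train_s))

train_s[in_union]         # keeps every point, including 0.5 and 9.0 from the elongated leaf
train_s[in_intersection]  # keeps only the points local to x_s: 2.2 and 2.8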

@samkodes
Author

samkodes commented Jun 20, 2021

Here are the experimental results. I reproduce the stepwise-constant function in the shapr paper and run a 7-dimensional version of the 10-dimensional problem, using Monte Carlo integration to get baseline Shapley values. Each model has 400 trees with default parameters; for the extratrees model, a single candidate variable per split only (i.e. fully random trees). It turns out that the apparent improvements of TreeSHAP on the RF and ET (extra-trees) models are mostly due to variable 7, the variable that does not appear in the stepwise-constant function, though for high values of rho (off-diagonal correlation) RF edges ahead of XGB on the other variables as well.
In the first figure below I facet by Shapley algorithm and by variable, and colour by model type. In the second figure below I facet by model type and variable, and colour by Shapley algorithm.
Rho values tested are {0, 0.1, 0.2, 0.3, 0.5, 0.8}
SHAFF experiments to follow hopefully in a day or two.
[Figure: comp_mod_type - faceted by Shapley algorithm and variable, coloured by model type]
[Figure: comp_shap_method - faceted by model type and variable, coloured by Shapley algorithm]
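
For reference, the data generation behind these runs looks roughly like this (a sketch; mean zero and the training size are my assumptions, and the stepwise-constant response itself is the one from the shapr paper and is not reproduced here):

# Hedged sketch: 7-dimensional equicorrelated Gaussian features, as described above.
library(MASS)

gen_gaussian <- function(n, d = 7, rho = 0.5) {
  Sigma <- matrix(rho, d, d)
  diag(Sigma) <- 1
  mvrnorm(n, mu = rep(0, d), Sigma = Sigma)   # mean zero is an assumption
}

X_train <- gen_gaussian(n = 2000, rho = 0.5)  # rho varied over {0, 0.1, 0.2, 0.3, 0.5, 0.8}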

@martinju
Member

Hi! Sorry for the late response here. This is very interesting! Some comments below:

  1. To answer your question about the intuition of why TreeSHAP does not work well, I don't have much to offer just now, unfortunately. Due to an extremely long journal review process, our experiments were actually conducted more than 2 years ago, so the details of the experiments are not very fresh. All scripts and results from our simulation experiments can be found here: https://github.com/NorskRegnesentral/shapr/tree/static_paper_experiments/inst/paper_experiments. It's a bit of a mess, so I can try to dig up any specific scripts you might be interested in -- just let me know.

  2. Great that you managed to reproduce the stepwise function and you got an interesting find there. I guess we should check our individual Shapley estimates to see if the same thing appeared in our 10-dim example.

  3. SHAFF is totally new to me, but looks interesting. Will have to look into it when I get time. Regarding the direct estimation of the conditional expectation, I did some simple experiments on that using a simple random forest and some GAMs, but for some reason (which is still a puzzle to me) it did not work well, so I eventually left the idea. Apparently, this paper tries to estimate the conditional expectation directly using a neural network: https://arxiv.org/abs/2006.01272, but there are lots of missing details and I have not been able to understand why it worked well for them but not in my simple examples.

  4. Please note that after we did our experiments on TreeSHAP, other variations on the algorithm have been added to the python shap library (see https://shap.readthedocs.io/en/latest/generated/shap.explainers.Tree.html#shap.explainers.Tree). We got our TreeSHAP estimates from the xgboost library in R, so as of now I don't know whether that corresponds to the "interventional" approach or the "tree_path_dependent" approach. I do think it is the latter, which is NOT the default in the python shap library, but I think it is worth double checking this.

@samkodes
Author

Thanks for the response, and for sharing the scripts. I will take a look - and perhaps continue the investigation. The TreeSHAP implementation I am using is the R 'treeshap' package, which I am pretty sure uses the "tree path dependent" approach; an "interventional" approach would not be expected to capture any dependencies in the covariate distribution.

It turns out that the SHAFF code here https://gitlab.com/drti/shaff is, like the paper, only an implementation of a global feature importance / variance estimate. Since I can't really hack C++ anyway, I've been experimenting with estimating pointwise conditional expectations using the "projected forest" idea from the SHAFF paper in my own simple R implementation, which takes a ranger forest as input.

On the stepwise function in 7 dimensions, the first result is encouraging. If I understand correctly, the usual kernel method (or perhaps the importance sampling method from the SHAFF paper) could then be used to estimate SHAP values from only a subset of the conditional expectations.
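
(For reference, by "the usual kernel method" I mean weighting sampled subsets with the standard Shapley kernel; a quick sketch of those weights for M = 7 features, not tied to any particular implementation:)

# Shapley kernel weight for a coalition of size s out of M features
# (the full and empty coalitions are handled separately / get effectively infinite weight).
shapley_kernel_weight <- function(M, s) (M - 1) / (choose(M, s) * s * (M - s))
round(sapply(1:6, shapley_kernel_weight, M = 7), 4)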

I am sampling 100 data points from the 7-dimensional data-generating Gaussian (off-diagonal rho=0.5) and for each point, estimating the conditional expectation of the stepwise function when the point is projected in each of the 128 (2^7) axis-aligned subspaces. For the projected-forest approach, the conditional expectation is estimated as the mean of observed outcomes for all data points falling in the intersection of all projected nodes that contain the test point; in practice we ensure sufficient data size / non-zero intersection by breadth-first traversal of the tree one level at a time and stopping when the current data gets too small. I'm then graphing the projected-forest (y axis) estimate of the conditional expectation against the exact (Monte-Carlo) conditional expectation (x-axis). Colouring is by the dimension of the subspace. The ranger forest is 400 trees fitted to 2000 data points.

In the first graph, you can see that the agreement is generally very good, with some structured noise around the extreme values - the kind of thing you might expect from a random forest's discretization of the space.

In the second graph, where I facet by test data point, you can see that this noise is due to specific test points which, I would assume, are in regions of the space where the function is not represented well by the random forest.

In the third graph, where I facet by subspace, you can see that agreement strength is not greatly affected by the choice of subspace.

[Figure: shaff-test-1 - projected-forest vs Monte Carlo conditional expectations, coloured by subspace dimension]
[Figure: shaff-test-2 - same comparison, faceted by test data point]
[Figure: shaff-test-3 - same comparison, faceted by subspace]

Here is sample code to make ideas concrete. Note the above experiments were not restricted to out-of-bag data for estimating conditional means (Athey & Wager's 'honest trees') but I doubt it matters much - this just gave me more data to work with.

# calculate conditional expectation after projection of a test point into a subspace using a ranger-fitted random forest as an adaptive nearest-neighbourhood
# this code implements the Projected Forest idea presented in the SHAFF paper (https://arxiv.org/abs/2105.11724) to estimate pointwise conditional expectations
# code by Sam Kaufman uhwuhna@gmail.com  6/22/2021

projPointCondExpOOB <- function( tree, # the tree object
                             current_depth, # counter
                             current_node_list, # nodes active for some copies of x
                             current_node_boundaries, # boundaries of current nodes
                             U, # subspace we are projecting in
                             test_point, # point being projected, whose values in U we are conditioning on
                             OOB_min_intersection_size, # controls max depth of projection
                             OOB_data_matrix, # for calculating conditional OOB expectation of intersection   ; could shrink with each level descent - may speed things up
                             # this matrix should have dp ids as last column, y as second-to-last-column
                             debug.print=0
){
# tree has tree$child.nodeIDs (two lists for left and right, 0 means leaf), tree$split.varIDs, tree$split.values
# smaller or equal go to left
# first list is left
# can check my interpretation using "treeInfo" function or predict(type=terminalNodes)
# 0 for a child.nodeID means terminal ; 
# nodes are 0-indexed though, so root node is node 0!
# split.varIDs are also 0-indexed


  if(debug.print)
    print(paste("Depth = ", current_depth, "n_current_nodes = ", length(current_node_list), "n_OOB_remaining=", nrow(OOB_data_matrix)))
  new_current_node_list <- list()
  new_current_node_boundaries <- list()

 # special case for null subspace - just return mean of OOB_data_matrix 
 if(sum(U)==0){
   if(debug.print)
     print("Projecting into null subspace! Returning mean of OOB data")
   return( list(OOB_samp_too_small=FALSE, 
               remaining_OOB_IDs=OOB_data_matrix[,ncol(OOB_data_matrix)], 
               b_int=list(b_min=rep(-Inf,length(test_point)),
                          b_max=rep(Inf,length(test_point))), 
               mean.OOB=mean(OOB_data_matrix[,ncol(OOB_data_matrix)-1]), 
                n.OOB=nrow(OOB_data_matrix)))
  }

  # check current node_boundaries for intersection in U variables
n_dims<- length(test_point)
int_max <- rep(Inf,n_dims)
int_min <- rep(-Inf,n_dims)
for(i in 1:length(current_node_boundaries)){
  nb_i <- current_node_boundaries[[i]]
  int_max <- pmin(int_max,nb_i$b_max)
  int_min <- pmax(int_min,nb_i$b_min)
}
OOB_remaining <- OOB_data_matrix
for(i in 1:n_dims){
  # filter U variables only !! 
  if(U[[i]]){
    OOB_col_i <- OOB_remaining[,i]
    OOB_remaining_row_i <- (OOB_col_i <= int_max[[i]] & OOB_col_i > int_min[[i]])
    OOB_remaining <- OOB_remaining[OOB_remaining_row_i,,drop=FALSE]
  }
}
if(nrow(OOB_remaining)<OOB_min_intersection_size){
  if(debug.print)
    print("OOB sample to0 small, returning")
   return(list(OOB_samp_too_small=TRUE))
 }
if(debug.print)
  print(paste("After intersecting, there are ", nrow(OOB_remaining)," data points remaining"))

n_new_nodes <- 0
n_leaves <- 0 

for( i in 1:length(current_node_list)){
  node_i_ID <- current_node_list[[i]]$nodeID  
   node_i_isleaf <- tree$child.nodeIDs[[1]][[node_i_ID+1]] ==0 # leafs have children 0-coded 
  if(node_i_isleaf){
     # simply copy this node and its boundary to new node list, but count it b/c if all new nodes are leaves we need to stop recursing
     n_leaves <- n_leaves+1
     n_new_nodes <- n_new_nodes+1
     new_current_node_list[[n_new_nodes]] <- current_node_list[[i]]
     new_current_node_boundaries[[n_new_nodes]] <- current_node_boundaries[[i]]
    next
  }

  node_i_b <- current_node_boundaries[[i]]

  node_i_splitvar <- tree$split.varIDs[[node_i_ID+1]]+1  # add 1 both times since ranger uses 0-indexing

  if( U[[node_i_splitvar]] ){ # if the split is in U, we need to test test_point and choose the appropriate child for the next level
    node_i_splitval <- tree$split.values[[node_i_ID+1]]
    if( test_point[[node_i_splitvar]]<= node_i_splitval){
      node_i_b$b_max[[node_i_splitvar]] <- pmin( node_i_b$b_max[[node_i_splitvar]], node_i_splitval)
      new_node_id <- tree$child.nodeIDs[[1]][[node_i_ID+1]]
    } else {
      node_i_b$b_min[[node_i_splitvar]] <- pmax( node_i_b$b_min[[node_i_splitvar]], node_i_splitval)
      new_node_id <- tree$child.nodeIDs[[2]][[node_i_ID+1]]
    }
    n_new_nodes <- n_new_nodes+1
    new_current_node_list[[n_new_nodes]] <- list(nodeID=new_node_id)
    new_current_node_boundaries[[n_new_nodes]] <- node_i_b
  
  } else { # otherwise we add both children to the list for the next level
    node_i_splitval <- tree$split.values[[node_i_ID+1]]
    node_i_b_l <- node_i_b
    node_i_b_r <- node_i_b
  
    # don't bother
    #node_i_b_l$b_max[[node_i_splitvar]] <- pmin( node_i_b$b_max[[node_i_splitvar]], node_i_splitval)
    new_node_id_l <- tree$child.nodeIDs[[1]][[node_i_ID+1]]
  
    # don't bother
    #node_i_b_r$b_min[[node_i_splitvar]] <- pmax( node_i_b$b_min[[node_i_splitvar]], node_i_splitval)
    new_node_id_r <- tree$child.nodeIDs[[2]][[node_i_ID+1]]
  
    n_new_nodes <- n_new_nodes+1
    new_current_node_list[[n_new_nodes]] <- list(nodeID=new_node_id_l)
    new_current_node_boundaries[[n_new_nodes]] <- node_i_b_l
  
  
    n_new_nodes <- n_new_nodes+1
    new_current_node_list[[n_new_nodes]] <- list(nodeID=new_node_id_r)
    new_current_node_boundaries[[n_new_nodes]] <- node_i_b_r
  
  
  }
}

# Are all new nodes copied leaves? If so, return the mean of the OOB data, the list of remaining OOB ids, and the intersection boundaries
 if( n_leaves == n_new_nodes){
  if(debug.print)
    print("All new nodes are leaves passed in")
  remaining_OOB_IDs <- OOB_remaining[,n_dims+2]
   b_int <- list(b_max=int_max, b_min=int_min)
   mean.OOB = mean(OOB_remaining[,n_dims+1])
  n.OOB <- nrow(OOB_remaining)
   return( list(OOB_samp_too_small=FALSE, remaining_OOB_IDs=remaining_OOB_IDs, b_int=b_int, mean.OOB=mean.OOB, n.OOB=n.OOB, max_depth=current_depth))
  } else {     # otherwise, recurse - breadth-first  
     
  ret.recurse <- projPointCondExpOOB( tree, # the tree object
                                     current_depth+1, # counter
                                    new_current_node_list, # nodes active for some copies of x
                                    new_current_node_boundaries, # boundaries of current nodes
                                    U, # subspace we are projecting in
                                    test_point, # point being projected, whose values in U we are conditioning on
                                    OOB_min_intersection_size, # controls max depth of projection
                                     OOB_remaining, # for calculating conditional OOB expectation of intersection; already shrunk to this level's intersection
                                     # this matrix should have dp ids as last column, y as second-to-last column
                                     debug.print=debug.print
  )
  # check if sample is too small and if needed return this level's estimate
  if(ret.recurse$OOB_samp_too_small){
     remaining_OOB_IDs <- OOB_remaining[,n_dims+2]
     b_int <- list(b_max=int_max, b_min=int_min)
     mean.OOB = mean(OOB_remaining[,n_dims+1])
     n.OOB <- nrow(OOB_remaining)
    return( list(OOB_samp_too_small=FALSE, remaining_OOB_IDs=remaining_OOB_IDs, b_int=b_int, mean.OOB=mean.OOB, n.OOB=n.OOB, max_depth=current_depth))
  } else {
    return(ret.recurse)
  }
 }

  } 


calcProjForestPointCondExpOOB <- function( R.forest,
                                       R.forest.inbag.counts,
                                       U, # subspace we are projecting in
                                       test_point, # point being projected, whose values in U we are conditioning on
                                       OOB_min_intersection_size, # controls max depth of projection 
                                       for.data.mat, # need to supply data as matrix with last col outcome
                                       use.inbag=TRUE,
                                       debug.print=0,
                                       print.every=50
){
  n.trees <- R.forest$num.trees
  tree.results <- list()
 for(i in 1:n.trees){
  if(debug.print>0){
     if( !(i%%print.every)){
      print(paste("condexpOOB for tree",i,"of",n.trees))
    }
  }
  tree_i <- list( child.nodeIDs=R.forest$child.nodeIDs[[i]],
                  split.varIDs=R.forest$split.varIDs[[i]],
                  split.values=R.forest$split.values[[i]] )
  if(use.inbag){
    # use all training data (i.e. do not restrict to this tree's out-of-bag sample)
    OOB_data_indices <- 1:nrow(for.data.mat)
  } else {
    # restrict to data points that are out-of-bag for tree i ("honest" estimates)
    OOB_data_indices <- which(R.forest.inbag.counts[[i]]==0)
  }
  OOB_data_i <- matrix(NA ,nrow=length(OOB_data_indices), ncol=ncol(for.data.mat)+1)
  OOB_data_i[ ,1:ncol(for.data.mat)] <- for.data.mat[OOB_data_indices,]
  OOB_data_i[,ncol(for.data.mat)+1] <- OOB_data_indices # data point ids in the last column
  tree.results[[i]] <- projPointCondExpOOB( tree_i, # the tree object
                                          0, # counter
                                          list(list(nodeID=0)), # start at the root node (ranger nodes are 0-indexed, see comments above)
                                          list(list(b_min=rep(-Inf,length(test_point)), b_max=rep(Inf,length(test_point)))), # boundaries of current nodes
                                          U, # subspace we are projecting in
                                          test_point, # point being projected, whose values in U we are conditioning on
                                          OOB_min_intersection_size, # controls max depth of projection
                                          OOB_data_i, # for calculating conditional OOB expectation of intersection; shrinks with each level of recursion
                                          # this matrix should have dp ids as last column, y as second-to-last column
                                          debug.print=pmax(debug.print-1,0)
  )
}
avg_cond_exp <- mean(sapply( tree.results, function(x){x$mean.OOB}))
return(list(avg_cond_exp=avg_cond_exp,tree.results=tree.results))

}
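
For anyone who wants to try this, a rough usage sketch (the training frame train_df, the feature names x1..x7, and the minimum intersection size are placeholders; I am also assuming the column order of for.data.mat matches ranger's internal variable indexing):

# Hedged usage sketch for the functions above.
library(ranger)

rf <- ranger(y ~ ., data = train_df, num.trees = 400, keep.inbag = TRUE)

feat     <- paste0("x", 1:7)
data_mat <- as.matrix(train_df[, c(feat, "y")])              # outcome must be the last column
U        <- c(TRUE, TRUE, FALSE, FALSE, FALSE, FALSE, FALSE) # condition on x1 and x2
x_star   <- as.numeric(train_df[1, feat])                    # test point to project

res <- calcProjForestPointCondExpOOB(rf$forest, rf$inbag.counts,
                                     U, x_star,
                                     OOB_min_intersection_size = 20,
                                     for.data.mat = data_mat,
                                     use.inbag = TRUE)
res$avg_cond_exp  # projected-forest estimate of the conditional expectation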

@salimamoukou

Hi @samkodes @martinju

In this paper [https://arxiv.org/pdf/2106.03820.pdf] you will find an empirical and theoretical analysis of the bias of the dependent TreeSHAP algorithm. It shows that the TreeSHAP algorithm is very biased, even incorrect, and proposes 2 new accurate estimators to compute the conditional probabilities for tree-based models. The algorithms are implemented in the ACV package: [https://github.com/salimamoukou/acv00].

@samkodes
Author

samkodes commented Jun 23, 2021

Hi @salimamoukou! Very interesting!

This clarifies some things re TreeSHAP - I initially struggled to understand why you claimed that TreeSHAP does not estimate your "reduced predictor" above (3.2), but now I understand: the issue is that the weights TreeSHAP assigns to each leaf are path-dependent, as you explain. Writing the leaf probability as a chain of iterated conditionals down the path, P(x \in L) = P(x \in I_d | x \in I_{d-1}) * P(x \in I_{d-1} | x \in I_{d-2}) * ..., TreeSHAP replaces every factor whose split is on a variable in the conditioning subset by 1, so the result is no longer clearly interpretable as a conditional probability. Your "Leaf estimator" weighting, in contrast, is (for the splits I_i above L that are on variables in the subset) P(x \in L | \bigcap_i {x \in I_i}) = P(x \in L and \bigcap_i {x \in I_i}) / P(\bigcap_i {x \in I_i}), and since L is a subset of each I_i this simplifies to P(x \in L) / P(\bigcap_i {x \in I_i}). Depending on the tree structure, this denominator is not the same as the product of the conditional factors that TreeSHAP sets to 1. And even the Leaf estimator is not really the "reduced predictor".
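
To put the two weightings side by side in the notation above (my own paraphrase of both, not a quote from either paper; S is the conditioning subset and the I_j are the node regions on the path to leaf L):

w_{\mathrm{TreeSHAP}}(L) = \prod_{j \text{ on the path to } L,\ \mathrm{var}(I_j) \notin S} P(x \in I_j \mid x \in I_{j-1})

w_{\mathrm{Leaf}}(L) = P(x \in L) \Big/ P\Big( \bigcap_{i:\ \mathrm{var}(I_i) \in S} \{ x \in I_i \} \Big)

(with w_{\mathrm{TreeSHAP}}(L) = 0 for leaves that are incompatible with x on the S-splits).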

I see a few relevant points of contact with the SHAFF projected forest idea.

First, both the path-dependency in TreeSHAP and the different conditioning for each leaf in the Leaf estimator might be avoided by using an intersection method as the SHAFF paper suggests. In this case we could consider only the leaves that overlap with the intersection of all splits for the relevant subspace encountered on all of the tree paths the test point goes down, and condition on the intersection of the projected splits. The downside is that a depth limit becomes necessary to ensure sufficient sample data size and non-empty intersection.

The second issue is the approach of weighting the leaf estimates, assumed by both TreeSHAP and your proposal. My concern is that certain leaves included in your C(S,x) (compatible leaves) may be elongated in the S subspace; in other words, they may include many points whose projections are not "local" to the projection of the test-point x. Using the leaf estimate for such a leaf will include the values of points that may not be appropriate for the conditional expectation. Though the premise of a tree model is that the underlying function is constant on the entire leaf region, there is always error, and the weight assigned to the leaf will affect how serious the error is. And even if there is zero error with respect to the underlying function, the weighting scheme will affect the accuracy of the conditional expectation. A weighting scheme that gives high weight to an elongated leaf which has most of its data far away from the pre-image of the test point may create bias (see below).

Rather than using these leaf values as given and reweighting, the SHAFF approach appears to construct a new neighbourhood using the intersection of the projections C(S,x). This reminds me of the adaptive-neighbourhood view of a random forest, which takes the leaves as defining a neighbourhood of a test point and averages over trees in the forest to define a weighting scheme; data points are weighted according to the proportion of trees in which they share a leaf node with the test point. The generalized random forest (grf) approach of Athey, Wager, and Tibshirani adopts this adaptive-neighbourhood perspective and uses the weights to fit a new distribution for each prediction (a similar approach is taken by the "distributional forest" approach of Schlosser, Hothorn, Stauffer, and Zeileis). If an intersection is used as in SHAFF, this intersection neighbourhood can be the new neighbourhood for averaging, adapted to the test point and variable subspace.
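
As a concrete sketch of that forest-weighting (mine; this is the simple share-of-trees variant described above, not grf's exact scheme, and rf / X_train / x_star are placeholders):

# Hedged sketch: random-forest adaptive-neighbourhood weights for one test point,
# using ranger's terminal node IDs.
library(ranger)

forest_neighbour_weights <- function(rf, X_train, x_star) {
  # x_star should be a one-row data.frame with the same feature columns as X_train
  leaf_train <- predict(rf, data = X_train, type = "terminalNodes")$predictions  # n x num.trees
  leaf_test  <- predict(rf, data = x_star,  type = "terminalNodes")$predictions  # 1 x num.trees
  same_leaf  <- sweep(leaf_train, 2, as.numeric(leaf_test), "==")                # TRUE if shared leaf in that tree
  rowMeans(same_leaf)  # share of trees in which each training point co-occurs with x_star
}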

Third, I wonder why the weighting scheme you propose seems to work. (I understand the intuition - that it is as close to conditioning on the exact value of the subspace coordinate as the given compatible leaf allows.) I can imagine situations where the weighting reduces the bias from elongated leaves - if, for example, the leaf under consideration has relatively low data density compared to the other leaves it overlaps with in the projection. However, if the leaf under consideration has relatively high data density compared to the other leaves it overlaps with (but low data density near the pre-image of the projected test point), I can imagine the weighting scheme creating very high bias. The effect of the weighting scheme would seem to be very dependent on the tree structure (including leaf values) and the data density. Why not instead weight each leaf by its share of the data in all of C(S,x) (all compatible leaves)?

I am not sure that "elongated leaves" is an issue in reality, though trees where it is a problem can easily be constructed.
These are my first thoughts only - I will try to read the rest of the paper (which looks like it has lots of other interesting stuff in it) and experiment with the code. Thanks!

@samkodes
Author

Here is an example of the "elongated leaves" problem in 2 dimensions. The tree has 4 leaves (heavy lines) and the true function is constant on each leaf. The selected variable subspace is the horizontal subspace. The test point x_1 is compatible with leaf A and leaf B. Data points are the blue circles. Depending on our definition of "local" in the horizontal subspace, very different conditional expectations will arise. The "Leaf estimator" will give very high weight to f(A)=2 and relatively low weight to f(B)=1. An intersection (SHAFF) approach would give equal weight to f(A) and f(B), because they each have 2 data points in the projected intersection (between the orange dotted lines)
[Figure: elongated leaves example - the 2-d tree partition with leaves A-D, the test point x_1, the data points, and the projected intersection between the orange dotted lines]

@martinju
Member

Thanks for the input @salimamoukou! I actually came across your paper a week ago or so, but didn't remember the TreeSHAP analysis. I have to read up a bit more to understand all aspects a bit better, but the last figure and example you present here @samkodes is very illustrative!

On a related note, the shapr package also has a tree-based approach implemented for estimating conditional distributions, see https://arxiv.org/pdf/2007.01027.pdf. The method is based on conditional inference trees (see https://www.jstor.org/stable/pdf/27594202.pdf for all the math), which should have more desirable properties for such tasks than standard CART, at least. Note that the approach we take is to use this to estimate the conditional distribution separately for every feature subset, no matter what the model f(x) being explained really is.

@salimamoukou

salimamoukou commented Jul 5, 2021

Hi, sorry for the late response,

First of all, thank you for all the interesting references.

The SHAFF approach seems very specific to random forests (not generalizable to boosted trees, for example). Moreover, their goal is different from ours: they try to estimate the true "generative" conditional expectation E[Y|X_S] and not the reduced predictor E[f(X) | X_S]. In the case of the reduced predictor, the only thing that varies is the weighting (not the leaf values).

However, an adaptation of the projected forest as you said (@samkodes: why not instead weight each leaf by its share of the data in all of C(S,x)?) seems to be a very good idea to solve the elongated leaves problem. The intuition convinces me, but I don't know yet how to justify it theoretically. Anyway, I can't wait to test this empirically. I'll keep you posted.
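
For concreteness, that weighting would look something like the sketch below (the function name and the toy numbers are made up):

# Hedged sketch: weight each compatible leaf in C(S,x) by its share of the training
# data, instead of the path-dependent or per-leaf conditional-probability weights.
data_share_estimate <- function(leaf_values, leaf_counts) {
  w <- leaf_counts / sum(leaf_counts)
  sum(w * leaf_values)
}

data_share_estimate(leaf_values = c(2, 1), leaf_counts = c(5, 5))  # hypothetical leaves A and B -> 1.5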

These are my first thoughts after a first reading of the papers: there may be some mistakes.

@martinju
Member

martinju commented Jul 8, 2021

Hi again @samkodes and @salimamoukou, and thank you for these interesting discussions!
I finally had a chance to look properly at SHAFF, and I agree that there are very interesting and promising tricks and ideas there! Regarding @salimamoukou's comment:

Moreover their goals are different from ours. They try to estimate the true "generative" conditional expectation E[Y|X_S] and not the reduced predictor: E[f(X) | X_S]. In the case of the reduced predictor, the only variant is the weighting (not the leaves values).

I believe that this could be easily fixed by replacing Y by f(X), right? At least I don't see why not at first sight. Their proof falls short whenever f is not continuous, though, but I am sure one could get around that with a "piecewise continuous with finitely many discontinuities" trick instead. Or am I missing something here? In my mind, you would then just use the random forest model (SHAFF) to estimate the conditional expectations, while the actual model you are explaining could be anything (including boosted trees). I do find it hard to tell from the SHAFF paper whether this is what the authors are trying to say or not, though.

As mentioned in point 3 of my first comment in this thread, I tried building a random forest to estimate E[f(X)|X_S] a few years back without any luck -- however, I did it brute force with no projection forests or importance sampling, so my issue might have been in the architecture.

I am also very intrigued by the elongated leaves problem, so please keep me informed of any tests on that end! @salimamoukou: I will also take a closer look at your paper and methods for tree based estimators of the conditional expectations -- it looks interesting!

@samkodes
Author

samkodes commented Jul 9, 2021

I think the SHAFF authors are thinking of using the partition of space learned by a forest model that's already been fit, simply for efficiency, but I think you're right, you could fit a forest to any other model you want to analyze and then run SHAFF on it. I suppose for efficiency the key would be to fit a single forest and then use the projection algorithm for each (sampled) subspace. Because the projection algorithm works on single trees, you could always try to apply it directly to boosted trees (for example), but of course they don't prove anything for that case, and I have no idea whether the proofs would go through (I'd guess not). The projected forest idea came out of an early paper on a different variable importance measure (https://arxiv.org/abs/2102.13347), and there are some comments there comparing it to TreeSHAP.

I did try running some benchmarks comparing SHAFF with your "Leaf estimator" @salimamoukou on the 7-dim version of the piecewise constant example from @martinju's paper, as well as comparing with a few different "Leaf estimator" variations. I will post them when I have a little more time to revisit. SHAFF seemed to beat the "Leaf estimator" until the dimensionality of the projected subspace got high, but it's not clear whether this is because the "minimum intersection size" in the SHAFF algorithm made it lose precision. Dimensionality may be a difficult problem for an intersection approach. Depending on the subspace sampling/weighting used (KernelSHAP or SHAFF's importance sampling method), this might have different implications for accuracy. Variations on the Leaf estimator (like share of data in C(S,x), or weighting by "subspace aspect ratio", i.e. the leaf's "vertical" (complement subspace) marginal density divided by its "horizontal" (selected subspace) marginal density, or a data-weighted "aspect ratio") seemed to perform worse than the proposed Leaf estimator.
