Allow Y.hat input for causal_survival_forest()? #1410

Open
bcjaeger opened this issue Apr 22, 2024 · 4 comments
Labels: experimental (An experimental feature branch not intended for merge.), feature

Comments

@bcjaeger
Contributor

Hello,

Thank you for developing grf. It's great!

Would it be feasible to allow Y.hat to be an input for causal survival forests? I see from the code in causal_survival_forest that a couple of intermediate values are taken from the output of survival_forest(), so it may not be straightforward to plug in a Y.hat from a separate routine. I am interested in using a forest object from aorsf::orsf() (https://github.com/ropensci/aorsf). If aorsf::orsf() could provide those intermediate outputs, would it be feasible to allow an aorsf forest object to be used here?
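For context on why Y.hat is not a plain drop-in vector: with a binary treatment, grf's causal_survival_forest builds Y.hat from two counterfactual survival curves via E[f(T) | X] = e(X) E[f(T) | X, W = 1] + (1 - e(X)) E[f(T) | X, W = 0]. The base-R sketch below shows that step, assuming S1 and S0 are n x length(grid) matrices of survival probabilities on a shared, sorted grid. The helper names are hypothetical, and the RMST integral is our reading of what grf's internal expected_survival does (a step-function integral from 0 to the last grid point), not a copy of it.

```r
# Hypothetical helpers (not part of grf): build Y.hat from counterfactual
# survival curves. S1, S0: n x length(grid) matrices of P(T > grid[j] | X, W = w).

# RMST up to max(grid): integrate the right-continuous step function S(t),
# which is 1 before grid[1] and equals S[, j] on [grid[j], grid[j + 1]).
rmst_from_survival <- function(S, grid) {
  widths <- diff(c(0, grid))
  S.left <- cbind(1, S[, -ncol(S), drop = FALSE])
  as.vector(S.left %*% widths)
}

# Survival probability at `horizon`, read off the same step function.
survival_at <- function(S, grid, horizon) {
  j <- findInterval(horizon, grid)
  if (j == 0) rep(1, nrow(S)) else S[, j]
}

# E[f(T) | X] = e(X) E[f(T) | X, W = 1] + (1 - e(X)) E[f(T) | X, W = 0]
compose_Y_hat <- function(S1, S0, W.hat, grid,
                          target = c("RMST", "survival.probability"),
                          horizon = NULL) {
  target <- match.arg(target)
  f <- if (target == "RMST") {
    function(S) rmst_from_survival(S, grid)
  } else {
    function(S) survival_at(S, grid, horizon)
  }
  W.hat * f(S1) + (1 - W.hat) * f(S0)
}

grid <- c(1, 2, 3)
S1 <- matrix(c(0.5, 0.5, 0.5), nrow = 1)
S0 <- matrix(1, nrow = 1, ncol = 3)
compose_Y_hat(S1, S0, W.hat = 0.5, grid = grid)                          # 2.5
compose_Y_hat(S1, S0, 0.5, grid, "survival.probability", horizon = 2.5)  # 0.75
```

This is why an external model has to supply full survival curves (or the ingredients to compute them) rather than a single Y.hat vector: the same curves are reused later for the Q(x) and censoring terms.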

@erikcs
Member

erikcs commented Apr 23, 2024

Hi @bcjaeger, thank you! Oblique forests look very interesting!

It's a good question. causal_survival_forest has many complicated nuisance components, and for simplicity we opted not to support user-specified estimates. I think one route could be to add a specialized causal_survival_forest.fit entry point.

But before going down that road, how about stitching together your own causal_survival_forest and trying that out first? Here is a template for a causal_survival_forest.custom that you can copy directly into an R session and experiment with (using ::: to access grf's private methods).

I added five TODO comments marking where you could replace grf's survival-based estimates with your own. Please let me know if any of these are unclear.
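One of the pieces a custom version needs is the pair of counterfactual survival curves S(t, x, 1) and S(t, x, 0), which grf gets from a single survival model fit on (X, W) by predicting with W set to 1 and then to 0 for everyone (an "S-learner"). A model-agnostic sketch, assuming a hypothetical `predict_surv(data)` that returns an n x length(grid) matrix of survival probabilities; the toy exponential model is made up purely to exercise the helper.

```r
# Hypothetical helper: counterfactual survival curves from one model fit on
# (X, W), obtained by overwriting the treatment column ("S-learner").
# `predict_surv` is any function mapping a data.frame to an
# n x length(grid) matrix of survival probabilities.
counterfactual_survival <- function(predict_surv, data, treatment = "w") {
  data1 <- data
  data1[[treatment]] <- 1   # everyone treated
  data0 <- data
  data0[[treatment]] <- 0   # everyone untreated
  list(S1 = predict_surv(data1), S0 = predict_surv(data0))
}

# Toy model for illustration only: exponential survival with hazard 1 + w.
grid <- c(0.5, 1, 2)
toy_predict <- function(d) exp(-outer(1 + d$w, grid))

d <- data.frame(x = c(0.2, 0.7, 0.4), w = c(1, 0, 1))
cf <- counterfactual_survival(toy_predict, d)
cf$S1  # every row equals exp(-2 * grid)
cf$S0  # every row equals exp(-1 * grid)
```

Working on copies of the data means the observed treatment column never has to be restored afterwards; for out-of-bag predictions the model itself still has to support predicting on modified training rows.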

@bcjaeger
Contributor Author

Awesome idea. I checked out the TODO steps and everything looked very clear. I will try this soon and let you know how it goes. Thank you!

@bcjaeger
Contributor Author

I've made a little progress. The reprex below runs with the development version of aorsf, but not with the current version on CRAN: I made a small update to aorsf to allow out-of-bag predictions on modified versions of the training data. While experimenting with some real data, I ran into efficiency issues with a large Y.grid, so I made two additions that were not part of the TODOs you left (marked with comments that start with 'bcj addition').

causal_survival_forest.custom <- function(
    X, Y, W, D,
    W.hat = NULL,
    target = c("RMST", "survival.probability"),
    horizon = NULL,
    failure.times = NULL,
    num.trees = 2000,
    sample.weights = NULL,
    clusters = NULL,
    equalize.cluster.weights = FALSE,
    sample.fraction = 0.5,
    mtry = min(ceiling(sqrt(ncol(X)) + 20), ncol(X)),
    min.node.size = 5,
    honesty = TRUE,
    honesty.fraction = 0.5,
    honesty.prune.leaves = TRUE,
    alpha = 0.05,
    imbalance.penalty = 0,
    stabilize.splits = TRUE,
    ci.group.size = 2,
    tune.parameters = "none",
    compute.oob.predictions = TRUE,
    num.threads = NULL,
    seed = runif(1, 0, .Machine$integer.max)) {

  target <- match.arg(target)
  if (is.null(horizon) || !is.numeric(horizon) || length(horizon) != 1) {
    stop("The `horizon` argument defining the estimand is required.")
  }

  has.missing.values <- grf:::validate_X(X, allow.na = TRUE)
  grf:::validate_sample_weights(sample.weights, X)
  Y <- grf:::validate_observations(Y, X)
  W <- grf:::validate_observations(W, X)
  D <- grf:::validate_observations(D, X)
  clusters <- grf:::validate_clusters(clusters, X)
  samples.per.cluster <- grf:::validate_equalize_cluster_weights(equalize.cluster.weights, clusters, sample.weights)
  num.threads <- grf:::validate_num_threads(num.threads)
  if (any(Y < 0)) {
    stop("The event times must be non-negative.")
  }
  if (!all(D %in% c(0, 1))) {
    stop("The censor values can only be 0 or 1.")
  }
  if (sum(D) == 0) {
    stop("All observations are censored.")
  }
  if (target == "RMST") {
    # f(T) <- min(T, horizon)
    D[Y >= horizon] <- 1
    Y[Y >= horizon] <- horizon
    fY <- Y
  } else {
    # f(T) <- 1{T > horizon}
    fY <- as.numeric(Y > horizon)
  }
  if (is.null(failure.times)) {

    Y.grid <- sort(unique(Y))

    # bcj addition 1
    # a large Y.grid slows computation down, so replace it with an
    # evenly spaced grid of 100 points when it has more than 100.
    if(length(Y.grid) > 100){
      Y.grid <- seq(min(Y.grid), max(Y.grid), length.out = 100)
    }

  } else if (min(Y) < min(failure.times)) {
    stop("If provided, `failure.times` should be a grid starting on or before min(Y).")
  } else {

    # bcj addition 2
    # coarsen a long failure.times grid (> 100 points) for computational efficiency
    if(length(failure.times) > 100){
      failure.times.orig <- failure.times
      failure.times <- seq(min(failure.times),
                           max(failure.times),
                           length.out = 100)

      # snap each point of the coarse grid to the nearest original failure
      # time so the result is a subset of the original grid. Subsampling the
      # original times directly would be cheaper, but snapping keeps the
      # grid more evenly spaced.
      for(i in seq_along(failure.times)){
        closest_index <- which.min(abs(failure.times[i] - failure.times.orig))
        failure.times[i] <- failure.times.orig[closest_index]
      }

      # snapping can introduce duplicates, so drop them.
      failure.times <- unique(failure.times)

    }

    Y.grid <- failure.times
  }
  if (length(Y.grid) <= 2) {
    stop("The number of distinct event times should be more than 2.")
  }
  if (horizon < min(Y.grid)) {
    stop("`horizon` cannot be before the first event.")
  }
  if (nrow(X) > 5000 && length(Y.grid) / nrow(X) > 0.1) {
    warning(paste0("The number of events are more than 10% of the sample size. ",
                   "To reduce the computational burden of fitting survival and ",
                   "censoring curves, consider discretizing the event values `Y` or ",
                   "supplying a coarser grid with the `failure.times` argument. "), immediate. = TRUE)
  }

  if (is.null(W.hat)) {
    forest.W <- grf::regression_forest(X, W, num.trees = max(50, num.trees / 4),
                                       sample.weights = sample.weights, clusters = clusters,
                                       equalize.cluster.weights = equalize.cluster.weights,
                                       sample.fraction = sample.fraction, mtry = mtry,
                                       min.node.size = 5, honesty = TRUE,
                                       honesty.fraction = 0.5, honesty.prune.leaves = TRUE,
                                       alpha = alpha, imbalance.penalty = imbalance.penalty,
                                       ci.group.size = 1, tune.parameters = tune.parameters,
                                       compute.oob.predictions = TRUE,
                                       num.threads = num.threads, seed = seed)
    W.hat <- predict(forest.W)$predictions
  } else if (length(W.hat) == 1) {
    W.hat <- rep(W.hat, nrow(X))
  } else if (length(W.hat) != nrow(X)) {
    stop("W.hat has incorrect length.")
  }
  W.centered <- W - W.hat

  args.nuisance <- list(failure.times = failure.times,
                        num.trees = max(50, min(num.trees / 4, 500)),
                        sample.weights = sample.weights,
                        clusters = clusters,
                        equalize.cluster.weights = equalize.cluster.weights,
                        sample.fraction = sample.fraction,
                        mtry = mtry,
                        min.node.size = 15,
                        honesty = TRUE,
                        honesty.fraction = 0.5,
                        honesty.prune.leaves = TRUE,
                        alpha = alpha,
                        prediction.type = "Nelson-Aalen", # to guarantee non-zero estimates.
                        compute.oob.predictions = TRUE,
                        num.threads = num.threads,
                        seed = seed)

  # Compute survival-based nuisance components (https://arxiv.org/abs/2001.09887)
  # m(x) relies on the survival function conditional on only X, while Q(x) relies on the conditioning (X, W).
  # Instead of fitting two separate survival forests, we can use the forest fit on (X, W) to compute m(X)
  # using the identity
  # E[f(T) | X] = e(X) E[f(T) | X, W = 1] + (1 - e(X)) E[f(T) | X, W = 0]
  # (for this to work W has to be binary).

  # orsf would throw an error if columns were unnamed
  if(is.null(colnames(X))) colnames(X) <- paste("x", seq(ncol(X)), sep = "_")

  orsf_data <- as.data.frame(cbind(y = Y, d = D, w = W, X))

  # this is to prevent aorsf from throwing an error when it
  # encounters event times of 0. I should remove this assertion
  # from aorsf, but for now:
  orsf_data$y <- pmax(orsf_data$y, .Machine$double.eps)

  # default is to use unique event times
  if(is.null(args.nuisance$failure.times)){
    args.nuisance$failure.times <- sort(unique(Y[D==1]))
  }

  # plugging in inputs from args.nuisance where possible
  sf.survival <- aorsf::orsf(
    data = orsf_data,
    formula = y + d ~ .,
    n_tree = args.nuisance$num.trees,
    weights = args.nuisance$sample.weights,
    sample_fraction = args.nuisance$sample.fraction,
    mtry = args.nuisance$mtry,
    # use min node size for both leaf stopping criteria;
    # this will usually lead to shallower trees.
    leaf_min_obs = args.nuisance$min.node.size,
    leaf_min_events = args.nuisance$min.node.size,
    oobag_pred_type = "surv",
    oobag_pred_horizon = args.nuisance$failure.times,
    tree_seeds = round(args.nuisance$seed)
  )

  binary.W <- all(W %in% c(0, 1))

  if (binary.W) {

    # The survival function conditioning on being treated S(t, x, 1) estimated with an "S-learner".
    # The development version of aorsf supports OOB predictions on modified training
    # samples, so we set w to 1 and then to 0, predict out-of-bag, and restore W.

    orsf_data$w <- 1
    S1.hat <- predict(sf.survival, new_data = orsf_data, oobag = TRUE)
    orsf_data$w <- 0
    S0.hat <- predict(sf.survival, new_data = orsf_data, oobag = TRUE)
    orsf_data$w <- W

    if (target == "RMST") {
      Y.hat <- W.hat * grf:::expected_survival(S1.hat, sf.survival$pred_horizon) +
        (1 - W.hat) * grf:::expected_survival(S0.hat, sf.survival$pred_horizon)
    } else {
      horizonS.index <- findInterval(horizon, sf.survival$pred_horizon)
      if (horizonS.index == 0) {
        Y.hat <- rep(1, nrow(X))
      } else {
        Y.hat <- W.hat * S1.hat[, horizonS.index] + (1 - W.hat) * S0.hat[, horizonS.index]
      }
    }

  } else {
    # Ignoring this code branch for simplicity's sake
    stop("Custom survival models + continuous treatment not implemented")

    # If continuous W fit a separate survival forest to estimate E[f(T) | X].
    # sf.Y <- do.call(grf::survival_forest, c(list(X = X, Y = Y, D = D), args.nuisance))
    # SY.hat <- predict(sf.Y)$predictions
    # if (target == "RMST") {
    #   Y.hat <- grf:::expected_survival(SY.hat, sf.Y$failure.times)
    # } else {
    #   horizonS.index <- findInterval(horizon, sf.Y$failure.times)
    #   if (horizonS.index == 0) {
    #     Y.hat <- rep(1, nrow(X))
    #   } else {
    #     Y.hat <- SY.hat[, horizonS.index]
    #   }
    # }
  }

  # The conditional survival function S(t, x, w) used to construct Q(x).
  S.hat <- predict(sf.survival, oobag = TRUE, pred_horizon = Y.grid)

  if (!identical(dim(S.hat), c(length(Y), length(Y.grid)))) stop("Wrong S.hat prediction dims")

  # The conditional survival function for the censoring process S_C(t, x, w).
  orsf_data$d <- 1 - D

  sf.censor <- aorsf::orsf_update(sf.survival,
                                  data = orsf_data,
                                  # default split_min_stat is about 3,
                                  # setting it to 10 makes trees shallower
                                  split_min_stat = 10,
                                  oobag_pred_horizon = Y.grid)

  C.hat <- sf.censor$pred_oobag

  if (!identical(dim(C.hat), c(length(Y), length(Y.grid)))) stop("Wrong C.hat prediction dims")

  if (target == "survival.probability") {
    # Evaluate psi up to horizon
    D[Y > horizon] <- 1
    Y[Y > horizon] <- horizon
  }

  Y.index <- findInterval(Y, Y.grid) # (invariance: Y.index > 0)
  C.Y.hat <- C.hat[cbind(seq_along(Y.index), Y.index)] # Pick out P[Ci > Yi | Xi, Wi]

  if (target == "RMST" && any(C.Y.hat <= 0.05)) {
    warning(paste("Estimated censoring probabilities go as low as:", round(min(C.Y.hat), 5),
                  "- an identifying assumption is that there exists a fixed positive constant M",
                  "such that the probability of observing an event past the maximum follow-up time ",
                  "is at least M (i.e. P(T > horizon | X) > M).",
                  "This warning appears when M is less than 0.05, at which point causal survival forest",
                  "can not be expected to deliver reliable estimates."), immediate. = TRUE)
  } else if (target == "RMST" && any(C.Y.hat < 0.2)) {
    warning(paste("Estimated censoring probabilities are lower than 0.2",
                  "- an identifying assumption is that there exists a fixed positive constant M",
                  "such that the probability of observing an event past the maximum follow-up time ",
                  "is at least M (i.e. P(T > horizon | X) > M)."))
  } else if (target == "survival.probability" && any(C.Y.hat <= 0.001)) {
    warning(paste("Estimated censoring probabilities go as low as:", round(min(C.Y.hat), 5),
                  "- forest estimates will likely be very unstable, a larger target `horizon`",
                  "is recommended."), immediate. = TRUE)
  } else if (target == "survival.probability" && any(C.Y.hat < 0.05)) {
    warning(paste("Estimated censoring probabilities are lower than 0.05",
                  "and forest estimates may not be stable. Using a smaller target `horizon`",
                  "may help."))
  }

  psi <- grf:::compute_psi(S.hat, C.hat, C.Y.hat, Y.hat, W.centered,
                           D, fY, Y.index, Y.grid, target, horizon)
  grf:::validate_observations(psi[["numerator"]], X)
  grf:::validate_observations(psi[["denominator"]], X)

  data <- grf:::create_train_matrices(X,
                                      treatment = W.centered,
                                      survival.numerator = psi[["numerator"]],
                                      survival.denominator = psi[["denominator"]],
                                      censor = D,
                                      sample.weights = sample.weights)

  args <- list(num.trees = num.trees,
               clusters = clusters,
               samples.per.cluster = samples.per.cluster,
               sample.fraction = sample.fraction,
               mtry = mtry,
               min.node.size = min.node.size,
               honesty = honesty,
               honesty.fraction = honesty.fraction,
               honesty.prune.leaves = honesty.prune.leaves,
               alpha = alpha,
               imbalance.penalty = imbalance.penalty,
               stabilize.splits = stabilize.splits,
               ci.group.size = ci.group.size,
               compute.oob.predictions = compute.oob.predictions,
               num.threads = num.threads,
               seed = seed)

  forest <- grf:::do.call.rcpp(grf:::causal_survival_train, c(data, args))
  class(forest) <- c("causal_survival_forest", "grf")
  forest[["seed"]] <- seed
  forest[["_psi"]] <- psi
  forest[["X.orig"]] <- X
  forest[["Y.orig"]] <- Y
  forest[["W.orig"]] <- W
  forest[["D.orig"]] <- D
  forest[["Y.hat"]] <- Y.hat
  forest[["W.hat"]] <- W.hat
  forest[["sample.weights"]] <- sample.weights
  forest[["clusters"]] <- clusters
  forest[["equalize.cluster.weights"]] <- equalize.cluster.weights
  forest[["has.missing.values"]] <- has.missing.values
  forest[["target"]] <- target
  forest[["horizon"]] <- horizon

  forest
}


n <- 500
p <- 5
X <- matrix(runif(n * p), n, p)
W <- rbinom(n, 1, 0.5)
horizon <- 1
failure.time <- pmin(rexp(n) * X[, 1] + W, horizon)
censor.time <- 2 * runif(n)
Y <- round(pmin(failure.time, censor.time), 2)
D <- as.integer(failure.time <= censor.time)

# grf causal survival forest
csf.orig <- grf::causal_survival_forest(X, Y, W, D, horizon = horizon, seed = 42)
grf::average_treatment_effect(csf.orig)
#>   estimate    std.err 
#> 0.59288638 0.02277714
head(predict(csf.orig))
#>   predictions
#> 1   0.5627453
#> 2   0.4707053
#> 3   0.6011034
#> 4   0.6892941
#> 5   0.5707982
#> 6   0.5513207

# your custom CS forest
csf.custom <- causal_survival_forest.custom(X, Y, W, D, horizon = horizon, seed = 42)
grf::average_treatment_effect(csf.custom)
#>   estimate    std.err 
#> 0.59504669 0.02245838
head(predict(csf.custom))
#>   predictions
#> 1   0.5710990
#> 2   0.4703588
#> 3   0.6050363
#> 4   0.6909468
#> 5   0.5722461
#> 6   0.5575252

Created on 2024-04-30 with reprex v2.1.0
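For reference, the nearest-time snapping loop in "bcj addition 2" can also be written without a loop using findInterval(). A minimal sketch; coarsen_grid is a hypothetical name, and ties are broken toward the earlier time, which matches which.min() when the original times are sorted.

```r
# Hypothetical helper: coarsen a time grid to at most `max.points` values,
# snapping each point of an evenly spaced grid to the nearest original time
# so the result is a subset of the original times.
coarsen_grid <- function(times, max.points = 100) {
  times <- sort(unique(times))
  if (length(times) <= max.points) return(times)
  target <- seq(min(times), max(times), length.out = max.points)
  left <- findInterval(target, times)        # times[left] <= target
  right <- pmin(left + 1L, length(times))    # next time up, clamped at the end
  # pick the closer neighbour; ties go to the earlier time
  nearest <- ifelse(target - times[left] <= times[right] - target, left, right)
  unique(times[nearest])                     # snapping may create duplicates
}

times <- 1:1000
g <- coarsen_grid(times)
length(g)  # 100
range(g)   # 1 1000
```

Because the result only contains observed times, the first grid point is still min(times), so findInterval(Y, Y.grid) stays positive for every observation.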

@erikcs
Member

erikcs commented Apr 30, 2024

Very cool! Yes, Y.grid shouldn't be too large. We emit a warning with suggestions when it is, but further down the line it could be nice to add some automated grid selection. (PS, in case it's useful to keep in mind while making your modifications: the CSF code expects S.hat and C.hat to be indexed by the same time grid.)
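If an external model returns curves on its own time grid, one way to put S.hat and C.hat on a shared Y.grid is to read each step function at the new time points. A minimal sketch, assuming right-continuous step-function survival curves that equal 1 before the first time point; regrid_survival is a hypothetical helper, not part of grf or aorsf.

```r
# Hypothetical helper: evaluate step-function survival curves estimated on
# `from.grid` (sorted, increasing) at the times in `to.grid`.
# S is an n x length(from.grid) matrix; survival is taken to be 1 before
# from.grid[1] and constant between consecutive grid points.
regrid_survival <- function(S, from.grid, to.grid) {
  j <- findInterval(to.grid, from.grid)   # 0 means "before the first time"
  cbind(1, S)[, j + 1, drop = FALSE]      # column 1 holds the leading 1s
}

S <- matrix(c(0.9, 0.8, 0.5,
              0.7, 0.6, 0.4), nrow = 2, byrow = TRUE)
S.new <- regrid_survival(S, from.grid = c(1, 2, 3), to.grid = c(0.5, 1, 2.5, 4))
S.new
# row 1: 1.0 0.9 0.8 0.5
# row 2: 1.0 0.7 0.6 0.4
```

Reading the curve this way carries the last estimate forward past the end of the source grid, which is a modelling choice; a model that can predict directly at Y.grid (as the reprex does via `pred_horizon = Y.grid`) avoids it.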

@erikcs added the experimental (An experimental feature branch not intended for merge.) and feature labels on May 1, 2024