Possible bug in the predict subroutine SurvData_HazardModel #94

Closed
prockenschaub opened this issue Apr 30, 2024 · 3 comments

@prockenschaub

Problem

I am looking to use the models obtained via JMbayes2 for risk prediction. However, like others before me (#72, #81, #90), I noticed that the AUC of the joint model was considerably worse than that of a simple Cox regression using baseline variables. The explanations that had been given so far were that a) different definitions of AUC/c-index might have been used (#72, #90) and b) the longitudinal covariate may not be a good predictor (#81).

Additional experiments

In my case, the performance of the joint model was so much worse that I decided to dig a little deeper. In particular, I looked into the correlation between the linear predictor at t=0 and the cumulative incidence returned by predict(). Even if the longitudinal marker is a bad predictor -- i.e., it does not add information beyond the baseline covariates -- we would still expect these two quantities to be at least moderately correlated. I adapted the prediction tutorial as follows:

# Fit a single longitudinal model
fm2 <- lme(prothrombin ~ year * sex, data = pbc2,
           random = ~ year | id, control = lmeControl(opt = 'optim'))

# Fit a single Cox model
pbc2.id$event <- as.numeric(pbc2.id$status != "alive")
CoxFit <- coxph(Surv(years, event) ~ drug + age, data = pbc2.id)

# Use both in the joint model
jointFit <- jm(CoxFit, list(fm2), time_var = "year", cores = 1)

# Predict for the first 100 patients at time 0. 
# For simplicity, we can use pbc2.id here instead of subsetting
t0 <- 0
ND <- pbc2.id[1:100, ]
ND$event <- 0
ND$years <- t0

# Perform a prediction of the longitudinal marker using only information at 
# time 0. This gives us the value that the joint model is likely using under 
# the hood in prediction of the survival. Note that the prediction is close to 
# but not equal to the actually observed prothrombin value at time 0! 
predLong1 <- predict(jointFit, newdata = ND, return_newdata = TRUE, cores = 1)
predLong1

# Approximate a linear predictor from the coefs of `summary(jointFit)` and the 
# predicted values at time 0.
coefs = coef(jointFit)
lp = with(predLong1, coefs$gammas[1] * (drug == "D-penicil")  + coefs$gammas[2] * age + coefs$association[1] * pred_prothrombin) 

# Predict survival at time 5. For some reason, this also includes a prediction 
# at time 0? Filter away the prediction at time 0. 
predSurv <- predict(jointFit, newdata = ND, process = "event", times = 5, return_newdata = TRUE, cores = 1)
predSurv = predSurv[predSurv$year == 5, ]

# Compare the linear predictor and the CIF
cor(lp, predSurv$pred_CIF)

Counter to my expectations, there was practically no correlation ($\rho \approx 0.03$) between the linear predictor and the cumulative incidence. This seemed odd.

Potential bug

I started debugging the above code to see what was going on under the hood. I stumbled across SurvData_HazardModel(), which, to the best of my (limited) understanding, expands the prediction frame with the Gauss-Kronrod (GK) quadrature points that are used to calculate the hazard between prediction times. The GK points are given in the times_to_fill argument.
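As a sanity check on what those filled-in times should look like, here is a minimal sketch. My assumption, based on the values I observed with GK_k = 7, is that they are the standard 7-point Gauss-Legendre nodes rescaled from [-1, 1] to the integration interval [0, t]; the node values below are hard-coded from standard tables:

```r
# Hypothetical illustration (not JMbayes2 code): 7-point Gauss-Legendre
# nodes on [-1, 1], rescaled to an arbitrary interval [lower, upper].
gl7 <- c(-0.9491079123, -0.7415311856, -0.4058451514, 0,
          0.4058451514,  0.7415311856,  0.9491079123)

rescale_nodes <- function(nodes, lower, upper) {
  (upper - lower) / 2 * nodes + (upper + lower) / 2
}

round(rescale_nodes(gl7, 0, 5), 7)
# 0.1272302 0.6461720 1.4853871 2.5000000 3.5146129 4.3538280 4.8727698
```

These match the year values that appear for patient 2 in the expanded data frame shown further below.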

There might be a bug when trying to do so. Consider predicting with the joint model from before, but for readability we will set GK_k = 7 and only predict for the first two patients:

debug(JMbayes2:::SurvData_HazardModel)
jointFit <- jm(CoxFit, list(fm2), time_var = "year", cores = 1, control = list(GK_k = 7))
predict(jointFit, newdata = ND[1:2, ], process = "event", times = 5, return_newdata = TRUE)

The first call to SurvData_HazardModel() occurs in prepare_Data_preds(), and there was no problem there in my example; we can continue past it by typing c at the debug prompt. In the second call to SurvData_HazardModel(), which occurs in prepare_DataE_preds(), something seems to go awry. length(times_to_fill) is a multiple of nrow(data), so the dataset needs to be broadcast so that each patient is matched with their corresponding GK points. However, in our example, we end up with the following:

     id years status      drug      age    sex      year
1     1     0   dead D-penicil 58.76684 female 0.0000000
1.1   1     0   dead D-penicil 58.76684 female 0.0000000
1.2   1     0   dead D-penicil 58.76684 female 0.0000000
1.3   1     0   dead D-penicil 58.76684 female 0.0000000
1.4   1     0   dead D-penicil 58.76684 female 0.0000000
1.5   1     0   dead D-penicil 58.76684 female 0.0000000
1.6   1     0   dead D-penicil 58.76684 female 0.0000000
2     2     0  alive D-penicil 56.44782 female 0.1272302
2.1   2     0  alive D-penicil 56.44782 female 0.6461720
2.2   2     0  alive D-penicil 56.44782 female 1.4853871
2.3   2     0  alive D-penicil 56.44782 female 2.5000000
2.4   2     0  alive D-penicil 56.44782 female 3.5146129
2.5   2     0  alive D-penicil 56.44782 female 4.3538280
2.6   2     0  alive D-penicil 56.44782 female 4.8727698
1.7   1     0   dead D-penicil 58.76684 female 0.0000000
1.8   1     0   dead D-penicil 58.76684 female 0.0000000
1.9   1     0   dead D-penicil 58.76684 female 0.0000000
1.10  1     0   dead D-penicil 58.76684 female 0.0000000
1.11  1     0   dead D-penicil 58.76684 female 0.0000000
1.12  1     0   dead D-penicil 58.76684 female 0.0000000
1.13  1     0   dead D-penicil 58.76684 female 0.0000000
2.7   2     0  alive D-penicil 56.44782 female 0.1272302
2.8   2     0  alive D-penicil 56.44782 female 0.6461720
2.9   2     0  alive D-penicil 56.44782 female 1.4853871
2.10  2     0  alive D-penicil 56.44782 female 2.5000000
2.11  2     0  alive D-penicil 56.44782 female 3.5146129
2.12  2     0  alive D-penicil 56.44782 female 4.3538280
2.13  2     0  alive D-penicil 56.44782 female 4.8727698

Patient 1 ends up with all zeros, whereas Patient 2 gets the same 7 points between times 0 and 5 twice; instead, each patient should get one copy of each set of points. A rep(data, times = 2) situation seems to be happening where the correct behaviour would be rep(data, each = 2). This gives a clear reason why the correlation between the linear predictor and the cumulative incidence was so low above.
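To illustrate the difference on a toy example (a hypothetical minimal sketch, not JMbayes2 code):

```r
# Two patients. With `times`, the whole block of ids is repeated wholesale,
# so blocks alternate; with `each`, every id is repeated in place, keeping
# each patient's rows together -- which is what the broadcast here needs.
ids <- c(1, 2)

rep(ids, times = 2)  # 1 2 1 2 -- blocks alternate (the buggy pattern)
rep(ids, each = 2)   # 1 1 2 2 -- each patient's rows stay together
```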

Potential solution

I haven't fully understood all the parts of JMbayes2 in which SurvData_HazardModel() is used, nor whether prediction is the only place where the bug arises. I was therefore only able to try the following quick-and-dirty fix of SurvData_HazardModel():

SurvData_HazardModel <- function (times_to_fill, data, times_data, ids,
                                   time_var, index = NULL, fix_it = FALSE) {
    unq_ids <- unique(ids)
    fids <- factor(ids, levels = unq_ids)
    # checks
    if (is.null(index)) {
        index <- match(ids, unq_ids)
    }
    if (length(times_to_fill) != length(unq_ids) && is.null(index)) {
        stop("length 'times_to_fill' does not match the length of unique 'ids'.")
    }
    if (nrow(data) != length(fids)) {
        stop("the number of rows of 'data' does not match the length of 'ids'.")
    }
    if (nrow(data) != length(times_data)) {
        stop("the number of rows of 'data' does not match the length of 'times_data'.")
    }
    spl_times <- split(times_data, fids)

    first_val_zero <- sapply(spl_times[index], "[", 1L) != 0
    spl_times <- lapply(spl_times[index], function (x) if (x[1L] == 0) x else c(0, x))
    ind <- mapply2(findInterval, x = times_to_fill, vec = spl_times,
                   all.inside = first_val_zero)
    rownams_id <- split(row.names(data), fids)
    if (fix_it) {
        rownams_id <- rownams_id[index]
        if (length(rownams_id) < length(ind)) {
            rownams_id <- rep(rownams_id, each = length(ind) / length(rownams_id))
        }
        ind <- mapply2(`[`, rownams_id, ind)
    } else {
        ind <- mapply2(`[`, rownams_id[index], ind)
    }

    data <- data[unlist(ind, use.names = FALSE), ]
    data[[time_var]] <- unlist(times_to_fill, use.names = FALSE)
    row.names(data) <- seq_len(nrow(data))
    data
}

and its corresponding call in prepare_DataE_preds():

dataS_H <- SurvData_HazardModel(split(st, row(st)), dataS, last_times,
                                paste0(idT, "_", strata), time_var,
                                rep(index, each = control$GK_k), fix_it = TRUE)

This only fixes the replication issue during prediction, leaving any other calls in JMbayes2 untouched. Repeating my example above with this code, I get a very high correlation ($\rho \approx 1$).
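For anyone who wants to try this without reinstalling the package, one way is to hot-patch the loaded namespace. This is a rough sketch under the assumption that you have hand-edited copies of both the helper and its caller (the names SurvData_HazardModel_fixed and prepare_DataE_preds_fixed are mine, not part of the package):

```r
# Quick-and-dirty namespace patch for testing only. Both functions need
# replacing: the fixed helper itself, and the caller whose call site now
# passes fix_it = TRUE. The *_fixed objects are hypothetical hand-edited
# copies of the package internals.
library(JMbayes2)
assignInNamespace("SurvData_HazardModel", SurvData_HazardModel_fixed,
                  ns = "JMbayes2")
assignInNamespace("prepare_DataE_preds", prepare_DataE_preds_fixed,
                  ns = "JMbayes2")
```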

Open questions

I haven't been able to figure out yet why the example in the tutorial is able to achieve such a high AUC despite this error. Nevertheless, when applying my fix, I at least see a moderate increase from the $AUC = 0.8088$ reported in the tutorial to $AUC \approx 0.83$.

@drizopoulos (Owner)

Thanks for the detailed example. I will study it and see what changes will need to be made to the underlying code of predict().

@drizopoulos (Owner)

I changed the underlying code in the current development version on GitHub. If you have some time, could you check your examples to see if the results are better now?

@prockenschaub (Author)

I reran my examples and the output looks good: agreement between the linear predictor and the predictions is high again.
