torch
implementation of Time-to-event (CoxProportionalHazard) models
#353
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This branch was branched off of https://github.com/jamesdolezal/slideflow/tree/fixes . Merging #351 first is advisable so that all changes here are relevant to time-to-event CPH models.
Notes:
Loss
For the torch models the loss for the time-to-event model was called CoxProportionalHazardsLoss this name is more specific than negative_log_likelihood which is used for Tensorflow models.
Changes in TF
Trainer._process_outcome_labels
was created to make the code more consistent with the torch implementationMain implementation
The main crux of implementing CPH models is handling the events indicators. In torch when the model is created its inputs are reduced to - 1 to account for the events indicators which are expected to be the first
input_header
.Importantly, the same is not done for the Trainer.
:class:CPHTrainer(Trainer):
is specific to CPH models so it knows how to handle the firstinput_header
.Since the loss needs the events indicators its input are different depending on the type of Trainer. Because of this
Trainer
andCPHTrainer
implement different versions of the method.:meth:_forward_pass_and_loss
:meth:eval_from_model
Saves now pulls out the events indicators so they can be fed into the loss and saved in the dataframe to compute the metrics.Test
Tests were updates so that CPH models are no longer only tested when backend is TF
One
FIXME:
:If every sample in the batch is censored this line ensures that the loss is 0 and not
nan
.