Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch implementation of Time-to-event (CoxProportionalHazard) models #353

Open
wants to merge 6 commits into
base: dev
Choose a base branch
from

Conversation

matte-esse
Copy link
Collaborator

@matte-esse matte-esse commented Mar 26, 2024

This branch was branched off of https://github.com/jamesdolezal/slideflow/tree/fixes . Merging #351 first is advisable so that all changes here are relevant to time-to-event CPH models.

Notes:

Loss

For the torch models the loss for the time-to-event model was called CoxProportionalHazardsLoss this name is more specific than negative_log_likelihood which is used for Tensorflow models.

Changes in TF

Trainer._process_outcome_labels was created to make the code more consistent with the torch implementation

Main implementation

The main crux of implementing CPH models is handling the events indicators. In torch when the model is created its inputs are reduced to - 1 to account for the events indicators which are expected to be the first input_header.

Importantly, the same is not done for the Trainer. :class:CPHTrainer(Trainer): is specific to CPH models so it knows how to handle the first input_header.

Since the loss needs the events indicators its input are different depending on the type of Trainer. Because of this Trainer and CPHTrainer implement different versions of the method. :meth:_forward_pass_and_loss

:meth:eval_from_model Saves now pulls out the events indicators so they can be fed into the loss and saved in the dataframe to compute the metrics.

Test

Tests were updates so that CPH models are no longer only tested when backend is TF

One FIXME::

If every sample in the batch is censored this line ensures that the loss is 0 and not nan.

@matte-esse matte-esse changed the title Implementation of CoxProportionalHazard torch implementation of CoxProportionalHazard Mar 26, 2024
@matte-esse matte-esse changed the title torch implementation of CoxProportionalHazard torch implementation of Time-to-event (CoxProportionalHazard) models Mar 27, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant