Commit e8aa6b6

continuing metric rework

mwalmsley committed Dec 20, 2023
1 parent 554826b commit e8aa6b6

Showing 2 changed files with 54 additions and 49 deletions.
102 changes: 53 additions & 49 deletions zoobot/pytorch/estimators/define_model.py
@@ -60,26 +60,26 @@ def __init__(

def setup_metrics(self):
# these are ignored unless output dim = 2
- self.accuracy_metrics = torchmetrics.MetricCollection({
+ self.accuracy_metrics = torch.nn.ModuleDict({
'train/accuracy': torchmetrics.Accuracy(task='binary'),
'validation/accuracy': torchmetrics.Accuracy(task='binary'),
})

self.val_accuracy = torchmetrics.Accuracy(task='binary')

- self.loss_metrics = torchmetrics.MetricCollection({
+ self.loss_metrics = torch.nn.ModuleDict({
'train/loss': torchmetrics.MeanMetric(nan_strategy='error'),
'validation/loss': torchmetrics.MeanMetric(nan_strategy='error'),
})

# TODO handle when schema doesn't exist
question_metric_dict = {}
- for step_name in ['train', 'validation']:
+ for step_name in ['train', 'validation']: # TODO test
question_metric_dict.update({
step_name + '/question_loss/' + question.text: torchmetrics.MeanMetric(nan_strategy='ignore')
for question in self.schema.questions
})
- self.question_loss_metrics = torchmetrics.MetricCollection(question_metric_dict)
+ self.question_loss_metrics = torch.nn.ModuleDict(question_metric_dict)

campaigns = schema_to_campaigns(self.schema)
campaign_metric_dict = {}
@@ -88,7 +88,7 @@ def setup_metrics(self):
step_name + '/campaign_loss/' + campaign: torchmetrics.MeanMetric(nan_strategy='ignore')
for campaign in campaigns
})
- self.campaign_loss_metrics = torchmetrics.MetricCollection(campaign_metric_dict)
+ self.campaign_loss_metrics = torch.nn.ModuleDict(campaign_metric_dict)


def forward(self, x):
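The metric containers above move from torchmetrics.MetricCollection to torch.nn.ModuleDict: either way each metric is registered as a submodule (so Lightning still places its state on the right device), but a plain ModuleDict leaves updating, computing and resetting fully manual. A minimal standalone sketch of that update/compute cycle, with made-up values:

    import torch
    import torchmetrics

    # metrics keyed by step name; ModuleDict registers each metric as a submodule
    loss_metrics = torch.nn.ModuleDict({
        'train/loss': torchmetrics.MeanMetric(nan_strategy='error'),
        'validation/loss': torchmetrics.MeanMetric(nan_strategy='error'),
    })

    # per batch: calling a metric updates its running state
    loss_metrics['train/loss'](torch.tensor(0.7))
    loss_metrics['train/loss'](torch.tensor(0.5))

    # per epoch: compute() reads the aggregate, reset() clears it for the next epoch
    print(loss_metrics['train/loss'].compute())  # tensor(0.6000)
    loss_metrics['train/loss'].reset()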
@@ -98,43 +98,47 @@ def forward(self, x):
def make_step(self, batch, step_name):
x, labels = batch
predictions = self(x) # by default, these are Dirichlet concentrations
- loss = self.calculate_and_log_loss(predictions, labels, step_name)
- return {'loss': loss, 'predictions': predictions, 'labels': labels}
-
- def calculate_and_log_loss(self, predictions, labels, step_name):
- raise NotImplementedError('Must be subclassed')
+ loss = self.calculate_loss_and_update_loss_metrics(predictions, labels, step_name)
+ outputs = {'loss': loss, 'predictions': predictions, 'labels': labels}
+ self.update_other_metrics(outputs, step_name=step_name)
+ return outputs

def configure_optimizers(self):
raise NotImplementedError('Must be subclassed')

def training_step(self, batch, batch_idx):
return self.make_step(batch, step_name='train')

- def on_train_batch_end(self, outputs, *args):
- self.update_metrics(outputs, step_name='train')

def validation_step(self, batch, batch_idx):
return self.make_step(batch, step_name='validation')

- def on_validation_batch_end(self, outputs, *args):
- self.update_metrics(outputs, step_name='validation')


def test_step(self, batch, batch_idx):
return self.make_step(batch, step_name='test')

- def on_test_batch_end(self, outputs, *args):
- self.update_metrics(outputs, step_name='test')
+ # def on_train_batch_end(self, outputs, *args):
+ # pass
+
+ # def on_validation_batch_end(self, outputs, *args):
+ # pass

def on_train_epoch_end(self) -> None:
- self.log_all_metrics(step_name='train')
+ # called *after* on_validation_epoch_end, confusingly
+ # do NOT log_all_metrics here.
+ # logging a metric resets it, and on_validation_epoch_end just logged and reset everything, so you will only log nans
+ pass

def on_validation_epoch_end(self) -> None:
- self.log_all_metrics(step_name='validation')
+ # raise ValueError('val epoch end')
+ # called at end of val epoch, but BEFORE on_train_epoch_end
+ self.log_all_metrics() # logs all metrics, so can do in val only

- def update_metrics(self, outputs, step_name):
+ def calculate_loss_and_update_loss_metrics(self, predictions, labels, step_name):
raise NotImplementedError('Must be subclassed')

+ def update_other_metrics(self, outputs, step_name):
+ raise NotImplementedError('Must be subclassed')
+
- def log_all_metrics(self, step_name):
+ def log_all_metrics(self):

self.log_dict(self.loss_metrics, on_epoch=True, on_step=False, prog_bar=True, logger=True)
self.log_dict(self.question_loss_metrics, on_step=False, on_epoch=True, logger=True)
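The epoch-end hooks change because, as the new comments explain, a logged torchmetrics object is computed and then reset at the end of the epoch, and on_train_epoch_end fires after on_validation_epoch_end; logging everything a second time in on_train_epoch_end would therefore read from just-reset metrics and record nans. A small standalone illustration of that reset behaviour, with arbitrary values:

    import torch
    import torchmetrics

    metric = torchmetrics.MeanMetric()
    metric(torch.tensor(1.0))
    metric(torch.tensor(3.0))

    print(metric.compute())  # tensor(2.) -- the aggregate for the epoch
    metric.reset()           # roughly what logging the metric object triggers at epoch end
    print(metric.compute())  # nan (with a warning): nothing updated since the reset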
@@ -280,14 +284,15 @@ def __init__(
logging.info('Zoobot __init__ complete')


- def calculate_and_log_loss(self, predictions, labels, step_name):
+ def calculate_loss_and_update_loss_metrics(self, predictions, labels, step_name):
# self.loss_func returns shape of (galaxy, question), mean to ()
multiq_loss = self.loss_func(predictions, labels, sum_over_questions=False)
- self.log_loss_per_question(multiq_loss, prefix=step_name)
+ self.update_per_question_loss_metric(multiq_loss, step_name=step_name)
# sum over questions and take a per-device mean
# for DDP strategy, batch size is constant (batches are not divided, data pool is divided)
# so this will be the global per-example mean
loss = torch.mean(torch.sum(multiq_loss, axis=1))
+ self.loss_metrics[step_name + '/loss'](loss)
return loss
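The scalar handed to the step's loss MeanMetric above is the per-question loss summed over questions, then averaged over the examples on this device. A toy illustration of that reduction with an invented (galaxy, question) loss tensor:

    import torch

    # pretend per-question losses for a batch of 3 galaxies and 2 questions
    multiq_loss = torch.tensor([[0.2, 0.4],
                                [0.1, 0.3],
                                [0.5, 0.1]])

    per_galaxy_loss = torch.sum(multiq_loss, axis=1)  # shape (3,): summed over questions
    loss = torch.mean(per_galaxy_loss)                # scalar mean over the batch
    print(loss)  # tensor(0.5333)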


@@ -314,48 +319,47 @@ def configure_optimizers(self):
return optimizer # no scheduler


- def update_metrics(self, outputs, step_name):
- self.loss_metrics[step_name + '/loss'](outputs['loss'])
+ def update_other_metrics(self, outputs, step_name):

if outputs['predictions'].shape[1] == 2:
self.accuracy_metrics[step_name + '/accuracy'](outputs['predictions'], torch.argmax(outputs['labels'], dim=1, keepdim=False)),



- def log_loss_per_question(self, multiq_loss, prefix):
+ def update_per_question_loss_metric(self, multiq_loss, step_name):
# log questions individually
# TODO need schema attribute or similar to have access to question names, this will do for now
# unlike Finetuneable..., does not use TorchMetrics, simply logs directly
# TODO could use TorchMetrics and for q in schema, self.q_metric loop

- if hasattr(self, 'schema'):
+ # if hasattr(self, 'schema'):
# use schema metadata to log intelligently
# will have schema if question_answer_pairs and dependencies are passed to __init__
# assume that questions are named like smooth-or-featured-CAMPAIGN
- for question_n, question in enumerate(self.schema.questions):
- # for logging comparison, want to ignore loss on unlablled examples, i.e. take mean ignoring zeros
- # could sum, but then this would vary with batch size
- nontrivial_loss_mask = multiq_loss[:, question_n] > 0 # 'zero' seems to be ~5e-5 floor in practice
+ for question_n, question in enumerate(self.schema.questions):
+ # for logging comparison, want to ignore loss on unlablled examples, i.e. take mean ignoring zeros
+ # could sum, but then this would vary with batch size
+ nontrivial_loss_mask = multiq_loss[:, question_n] > 0 # 'zero' seems to be ~5e-5 floor in practice

- this_question_metric = self.question_loss_metrics[prefix + '/question_loss/' + question.text]
- this_question_metric(torch.mean(multiq_loss[nontrivial_loss_mask, question_n]))
+ this_question_metric = self.question_loss_metrics[step_name + '/question_loss/' + question.text]
+ # raise ValueError
+ this_question_metric(torch.mean(multiq_loss[nontrivial_loss_mask, question_n]))

- campaigns = schema_to_campaigns(self.schema)
- for campaign in campaigns:
- campaign_questions = [q for q in self.schema.questions if campaign in q.text]
- campaign_q_indices = [self.schema.questions.index(q) for q in campaign_questions] # shape (num q in this campaign e.g. 10)
+ campaigns = schema_to_campaigns(self.schema)
+ for campaign in campaigns:
+ campaign_questions = [q for q in self.schema.questions if campaign in q.text]
+ campaign_q_indices = [self.schema.questions.index(q) for q in campaign_questions] # shape (num q in this campaign e.g. 10)

- # similarly to per-question, only include in mean if (any) q in this campaign has a non-trivial loss
- nontrivial_loss_mask = multiq_loss[:, campaign_q_indices].sum(axis=1) > 0 # shape batch size
+ # similarly to per-question, only include in mean if (any) q in this campaign has a non-trivial loss
+ nontrivial_loss_mask = multiq_loss[:, campaign_q_indices].sum(axis=1) > 0 # shape batch size

- this_campaign_metric = self.campaign_loss_metrics[prefix + '/campaign_loss/' + campaign]
- this_campaign_metric(torch.mean(multiq_loss[nontrivial_loss_mask][:, campaign_q_indices]))
+ this_campaign_metric = self.campaign_loss_metrics[step_name + '/campaign_loss/' + campaign]
+ this_campaign_metric(torch.mean(multiq_loss[nontrivial_loss_mask][:, campaign_q_indices]))

- else:
- # fallback to logging with question_n
- for question_n in range(multiq_loss.shape[1]):
- self.log(f'{prefix}/questions/question_{question_n}_loss:0', torch.mean(multiq_loss[:, question_n]), on_epoch=True, on_step=False, sync_dist=True)
+ # else:
+ # # fallback to logging with question_n
+ # for question_n in range(multiq_loss.shape[1]):
+ # self.log(f'{step_name}/questions/question_{question_n}_loss:0', torch.mean(multiq_loss[:, question_n]), on_epoch=True, on_step=False, sync_dist=True)
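The per-question and per-campaign metrics only average over galaxies whose loss for that question is non-trivial: a loss of (near) zero means the question was never answered for that galaxy, so including it would drag the mean down. A toy version of the masking, with invented losses where the second galaxy is unlabelled for this question:

    import torch

    question_loss = torch.tensor([0.40, 0.0, 0.25])  # second galaxy has no answers to this question
    nontrivial = question_loss > 0                   # mask out the unlabelled galaxy
    print(torch.mean(question_loss[nontrivial]))     # tensor(0.3250)
    print(torch.mean(question_loss))                 # tensor(0.2167), diluted by the unlabelled row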




1 change: 1 addition & 0 deletions zoobot/pytorch/training/train_with_pytorch_lightning.py
@@ -298,6 +298,7 @@ def train_default_zoobot_from_scratch(
# callbacks = None

trainer = pl.Trainer(
+ num_sanity_val_steps=0,
log_every_n_steps=150,
accelerator=accelerator,
devices=devices, # per node
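The only trainer change is num_sanity_val_steps=0, which skips the short validation sanity check Lightning normally runs before training; a plausible reason, not stated in the commit, is to avoid updating the validation metrics before any real training step. The flag in isolation:

    import pytorch_lightning as pl

    # skip the pre-training sanity validation run (default is 2 batches)
    trainer = pl.Trainer(num_sanity_val_steps=0, max_epochs=1)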