Skip to content

Commit

Permalink
small ssl tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Jan 3, 2024
1 parent d109603 commit d99d193
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 27 deletions.
28 changes: 4 additions & 24 deletions zoobot/pytorch/estimators/define_model.py
Expand Up @@ -59,17 +59,11 @@ def __init__(


def setup_metrics(self):
# these are ignored unless output dim = 2
self.accuracy_metrics = torch.nn.ModuleDict({
'train/accuracy': torchmetrics.Accuracy(task='binary'),
'validation/accuracy': torchmetrics.Accuracy(task='binary'),
})

self.val_accuracy = torchmetrics.Accuracy(task='binary')

self.loss_metrics = torch.nn.ModuleDict({
'train/loss': torchmetrics.MeanMetric(nan_strategy='error'),
'validation/loss': torchmetrics.MeanMetric(nan_strategy='error'),
'train/supervised_loss': torchmetrics.MeanMetric(nan_strategy='error'),
'validation/supervised_loss': torchmetrics.MeanMetric(nan_strategy='error'),
})

# TODO handle when schema doesn't exist
Expand Down Expand Up @@ -100,7 +94,7 @@ def make_step(self, batch, step_name):
predictions = self(x) # by default, these are Dirichlet concentrations
loss = self.calculate_loss_and_update_loss_metrics(predictions, labels, step_name)
outputs = {'loss': loss, 'predictions': predictions, 'labels': labels}
self.update_other_metrics(outputs, step_name=step_name)
# self.update_other_metrics(outputs, step_name=step_name)
return outputs

def configure_optimizers(self):
Expand Down Expand Up @@ -144,15 +138,6 @@ def log_all_metrics(self):
self.log_dict(self.question_loss_metrics, on_step=False, on_epoch=True, logger=True)
self.log_dict(self.campaign_loss_metrics, on_step=False, on_epoch=True, logger=True)

if hasattr(self, 'accuracy_metrics'):
self.log_dict(
self.accuracy_metrics,
on_epoch=True,
on_step=False,
prog_bar=True,
logger=True
)




Expand Down Expand Up @@ -292,7 +277,7 @@ def calculate_loss_and_update_loss_metrics(self, predictions, labels, step_name)
# for DDP strategy, batch size is constant (batches are not divided, data pool is divided)
# so this will be the global per-example mean
loss = torch.mean(torch.sum(multiq_loss, axis=1))
self.loss_metrics[step_name + '/loss'](loss)
self.loss_metrics[step_name + '/supervised_loss'](loss)
return loss


Expand All @@ -319,11 +304,6 @@ def configure_optimizers(self):
return optimizer # no scheduler


def update_other_metrics(self, outputs, step_name):

if outputs['predictions'].shape[1] == 2:
self.accuracy_metrics[step_name + '/accuracy'](outputs['predictions'], torch.argmax(outputs['labels'], dim=1, keepdim=False)),


def update_per_question_loss_metric(self, multiq_loss, step_name):
# log questions individually
Expand Down
2 changes: 1 addition & 1 deletion zoobot/pytorch/training/finetune.py
Expand Up @@ -52,7 +52,7 @@ class FinetuneableZoobotAbstract(pl.LightningModule):
weight_decay (float, optional): AdamW weight decay arg (i.e. L2 penalty). Defaults to 0.05.
learning_rate (float, optional): AdamW learning rate arg. Defaults to 1e-4.
dropout_prob (float, optional): P of dropout before final output layer. Defaults to 0.5.
freeze_batchnorm (bool, optional): If True, do not update batchnorm stats during finetuning. Defaults to True.
always_train_batchnorm (bool, optional): If True, do not update batchnorm stats during finetuning. Defaults to True.
prog_bar (bool, optional): Print progress bar during finetuning. Defaults to True.
visualize_images (bool, optional): Upload example images to WandB. Good for debugging but slow. Defaults to False.
seed (int, optional): random seed to use. Defaults to 42.
Expand Down
1 change: 0 additions & 1 deletion zoobot/pytorch/training/losses.py
Expand Up @@ -26,7 +26,6 @@ def calculate_multiquestion_loss(labels: torch.Tensor, predictions: torch.Tensor
q_indices = question_index_groups[q_n]
q_start = q_indices[0]
q_end = q_indices[1]

q_loss = dirichlet_loss(labels[:, q_start:q_end+1], predictions[:, q_start:q_end+1])

q_losses.append(q_loss)
Expand Down
2 changes: 1 addition & 1 deletion zoobot/pytorch/training/train_with_pytorch_lightning.py
Expand Up @@ -277,7 +277,7 @@ def train_default_zoobot_from_scratch(

extra_callbacks = extra_callbacks if extra_callbacks else []

monitor_metric = 'validation/loss'
monitor_metric = 'validation/supervised_loss'

# used later for checkpoint_callback.best_model_path
checkpoint_callback = ModelCheckpoint(
Expand Down

0 comments on commit d99d193

Please sign in to comment.