careful with nans
mwalmsley committed Jan 29, 2024
1 parent b652165 commit 938d7ca
Showing 4 changed files with 126 additions and 49 deletions.
2 changes: 1 addition & 1 deletion zoobot/pytorch/estimators/define_model.py
@@ -370,7 +370,7 @@ def dirichlet_loss(preds, labels, question_index_groups, sum_over_questions=Fals

# multiquestion_loss returns loss of shape (batch, question)
# torch.sum(multiquestion_loss, axis=1) gives loss of shape (batch). Equiv. to non-log product of question likelihoods.
multiq_loss = losses.calculate_multiquestion_loss(labels, preds, question_index_groups)
multiq_loss = losses.calculate_multiquestion_loss(labels, preds, question_index_groups, careful=True)
if sum_over_questions:
return torch.sum(multiq_loss, axis=1)
else:
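As a quick shape sketch of the comment above (toy tensors only, not Zoobot's real schema): calculate_multiquestion_loss returns a loss of shape (batch, question), and sum_over_questions collapses it to (batch).

import torch

batch_size, n_questions = 4, 3
multiq_loss = torch.rand(batch_size, n_questions)  # stand-in for losses.calculate_multiquestion_loss(...)

per_galaxy_loss = torch.sum(multiq_loss, axis=1)  # equiv. to non-log product of question likelihoods
assert per_galaxy_loss.shape == (batch_size,)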
114 changes: 109 additions & 5 deletions zoobot/pytorch/training/finetune.py
@@ -322,7 +322,7 @@ def __init__(
super().__init__(**super_kwargs)

logging.info('Using classification head and cross-entropy loss')
self.head = LinearClassifier(
self.head = LinearHead(
input_dim=self.encoder_dim,
output_dim=num_classes,
dropout_prob=self.dropout_prob
@@ -387,7 +387,7 @@ def predict_step(self, x: Union[list[torch.Tensor], torch.Tensor], batch_idx):
# see Abstract version
if isinstance(x, list) and len(x) == 1:
return self(x[0])
x = self.forward(x) # type: ignore # logits from LinearClassifier
x = self.forward(x) # type: ignore # logits from LinearHead
# then applies softmax
return F.softmax(x, dim=1)

@@ -407,6 +407,98 @@ def upload_images_to_wandb(self, outputs, batch, batch_idx):
caption=captions)



class FinetuneableZoobotRegressor(FinetuneableZoobotAbstract):
"""
Pretrained Zoobot model intended for finetuning on a regression problem.
See FinetuneableZoobotClassifier, above.
Args:
None besides those from FinetuneableZoobotAbstract, above (a single regression target with MSE loss, for now)
"""

def __init__(
self,
**super_kwargs) -> None:

super().__init__(**super_kwargs)

logging.info('Using regression head and MSE loss')
self.head = LinearHead(
input_dim=self.encoder_dim,
output_dim=1,
dropout_prob=self.dropout_prob
)
self.loss = mse_loss
# rmse metrics. loss is mse already.
self.train_rmse = tm.MeanSquaredError(squared=False)
self.val_rmse = tm.MeanSquaredError(squared=False)
self.test_rmse = tm.MeanSquaredError(squared=False)

def step_to_dict(self, y, y_pred, loss):
return {'loss': loss.mean(), 'predictions': y_pred, 'labels': y}

def on_train_batch_end(self, step_output, *args):
super().on_train_batch_end(step_output, *args)

self.train_rmse(step_output['predictions'], step_output['labels'])
self.log(
'finetuning/train_rmse',
self.train_rmse,
on_step=False,
on_epoch=True,
prog_bar=self.prog_bar
)

def on_validation_batch_end(self, step_output, *args):
super().on_validation_batch_end(step_output, *args)

self.val_rmse(step_output['predictions'], step_output['labels'])
self.log(
'finetuning/val_rmse',
self.val_rmse,
on_step=False,
on_epoch=True,
prog_bar=self.prog_bar
)

def on_test_batch_end(self, step_output, *args) -> None:
super().on_test_batch_end(step_output, *args)

self.test_rmse(step_output['predictions'], step_output['labels'])
self.log(
"finetuning/test_rmse",
self.test_rmse,
on_step=False,
on_epoch=True,
prog_bar=self.prog_bar
)


def predict_step(self, x: Union[list[torch.Tensor], torch.Tensor], batch_idx):
# see Abstract version
if isinstance(x, list) and len(x) == 1:
return self(x[0])
return self.forward(x)

# TODO
# def upload_images_to_wandb(self, outputs, batch, batch_idx):
# # self.logger is set by pl.Trainer(logger=) argument
# if (self.logger is not None) and (batch_idx == 0):
# x, y = batch
# y_pred_softmax = F.softmax(outputs['predictions'], dim=1)
# n_images = 5
# images = [img for img in x[:n_images]]
# captions = [f'Ground Truth: {y_i} \nPrediction: {y_p_i}' for y_i, y_p_i in zip(
# y[:n_images], y_pred_softmax[:n_images])]
# self.logger.log_image( # type: ignore
# key='val_images',
# images=images,
# caption=captions)

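# Illustrative sketch (toy values, not part of this diff): torchmetrics' MeanSquaredError(squared=False)
# reports RMSE and accumulates over batches, which is why the loss above stays plain MSE while the
# logged train/val/test metrics are RMSE.
import torch
import torchmetrics as tm

rmse = tm.MeanSquaredError(squared=False)
batch_rmse = rmse(torch.tensor([0.2, 0.9, 0.4]), torch.tensor([0.0, 1.0, 0.5]))  # RMSE for this batch
epoch_rmse = rmse.compute()  # RMSE accumulated over every batch seen since the last reset
rmse.reset()  # Lightning handles accumulation and reset automatically when the metric is passed to self.log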

class FinetuneableZoobotTree(FinetuneableZoobotAbstract):
"""
Pretrained Zoobot model intended for finetuning on a decision tree (i.e. GZ-like) problem.
@@ -447,18 +539,23 @@ def upload_images_to_wandb(self, outputs, batch, batch_idx):
# other functions are simply inherited from FinetuneableZoobotAbstract

# https://github.com/inigoval/byol/blob/1da1bba7dc5cabe2b47956f9d7c6277decd16cc7/byol_main/networks/models.py#L29
class LinearClassifier(torch.nn.Module):
class LinearHead(torch.nn.Module):
def __init__(self, input_dim, output_dim, dropout_prob=0.5):
# input dim is representation dim, output_dim is num classes
super(LinearClassifier, self).__init__()
super(LinearHead, self).__init__()
self.output_dim = output_dim
self.dropout = torch.nn.Dropout(p=dropout_prob)
self.linear = torch.nn.Linear(input_dim, output_dim)

def forward(self, x):
# returns logits, as recommended for CrossEntropy loss; if output_dim == 1, squeezed to shape (batch) for regression
x = self.dropout(x)
x = self.linear(x)
return x
if self.output_dim == 1:
return x.squeeze()
else:
return x



def cross_entropy_loss(y_pred, y, label_smoothing=0., weight=None):
@@ -468,6 +565,13 @@ def cross_entropy_loss(y_pred, y, label_smoothing=0., weight=None):
# will reduce myself
return F.cross_entropy(y_pred, y.long(), label_smoothing=label_smoothing, weight=weight, reduction='none')

def mse_loss(y_pred, y):
# y should be shape (batch) and float
# y_pred should be shape (batch), i.e. LinearHead output with output_dim=1, squeezed
# returns loss of shape (batch)
# will reduce myself
return F.mse_loss(y_pred, y, reduction='none')


def dirichlet_loss(y_pred, y, question_index_groups):
# aggregation equiv. to sum(axis=1).mean(), but fewer operations
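A minimal sketch of how the renamed LinearHead and the new mse_loss fit together for regression (assuming this commit is installed; an encoder dimension of 512 is an arbitrary example, the two names come from the diff above):

import torch
from zoobot.pytorch.training.finetune import LinearHead, mse_loss

encoder_dim, batch_size = 512, 8
head = LinearHead(input_dim=encoder_dim, output_dim=1, dropout_prob=0.5)

representation = torch.rand(batch_size, encoder_dim)  # stand-in for encoder output
y_pred = head(representation)  # output_dim=1, so forward() squeezes to shape (batch,)
y = torch.rand(batch_size)  # float regression targets, same shape as y_pred

loss = mse_loss(y_pred, y)  # reduction='none', so shape (batch,)
loss.mean().backward()  # step_to_dict() reduces with .mean() before backprop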
16 changes: 14 additions & 2 deletions zoobot/pytorch/training/losses.py
@@ -1,10 +1,11 @@
from typing import Tuple
import logging

import torch
import pyro


def calculate_multiquestion_loss(labels: torch.Tensor, predictions: torch.Tensor, question_index_groups: Tuple) -> torch.Tensor:
def calculate_multiquestion_loss(labels: torch.Tensor, predictions: torch.Tensor, question_index_groups: Tuple, careful=True) -> torch.Tensor:
"""
The full decision tree loss used for training GZ DECaLS models
@@ -19,6 +20,16 @@ def calculate_multiquestion_loss(labels: torch.Tensor, predictions: torch.Tensor
Returns:
torch.Tensor: neg. log likelihood of shape (batch, question).
"""
if careful:
# some models give occasional nans for all predictions on a specific galaxy/row
# these are inputs to the loss and only appear many epochs into training, so they are probably not caused by bad labels but by some instability during training
# handle this by replacing those predictions with a safe constant (concentration 1 everywhere, i.e. fully uncertain) and logging a warning
nan_prediction = torch.isnan(predictions) | torch.isinf(predictions)
if nan_prediction.any():
logging.warning(f'Found nan values in predictions: {predictions}')
safety_value = torch.ones(1, device=predictions.device, dtype=predictions.dtype) # fill with 1 everywhere i.e. fully uncertain
predictions = torch.where(condition=nan_prediction, input=safety_value, other=predictions)

# very important that question_index_groups is fixed and discrete, else tf.function autograph will mess up
q_losses = []
# will give shape errors if model output dim is not labels dim, which can happen if losses.py substrings are missing an answer
@@ -104,5 +115,6 @@ def dirichlet_loss(labels_for_q, concentrations_for_q):
def get_dirichlet_neg_log_prob(labels_for_q, total_count, concentrations_for_q):
# https://docs.pyro.ai/en/stable/distributions.html#dirichletmultinomial
# .int()s avoid rounding errors causing loss of around 1e-5 for questions with 0 votes
dist = pyro.distributions.DirichletMultinomial(total_count=total_count.int(), concentration=concentrations_for_q, is_sparse=False, validate_args=False)
dist = pyro.distributions.DirichletMultinomial(
total_count=total_count.int(), concentration=concentrations_for_q, is_sparse=False, validate_args=True)
return -dist.log_prob(labels_for_q.int()) # important minus sign
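A standalone sketch of the careful=True guard above, using toy tensors rather than real Dirichlet concentrations: non-finite predictions are replaced with a concentration of 1 (fully uncertain) before the loss is computed.

import logging
import torch

predictions = torch.tensor([[1.2, 0.8, float('nan')],
                            [2.0, 1.0, 3.0]])

nan_prediction = torch.isnan(predictions) | torch.isinf(predictions)
if nan_prediction.any():
    logging.warning(f'Found nan values in predictions: {predictions}')
    safety_value = torch.ones(1, device=predictions.device, dtype=predictions.dtype)
    predictions = torch.where(condition=nan_prediction, input=safety_value, other=predictions)
# predictions is now tensor([[1.2, 0.8, 1.0], [2.0, 1.0, 3.0]]), safe to pass to the Dirichlet loss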
43 changes: 2 additions & 41 deletions zoobot/pytorch/training/train_with_pytorch_lightning.py
@@ -268,13 +268,11 @@ def train_default_zoobot_from_scratch(
# these args are automatically logged
lightning_model = define_model.ZoobotTree(
output_dim=len(schema.label_cols),
# question_index_groups=schema.question_index_groups,
# NEW - pass these from schema, for better logging
question_answer_pairs=schema.question_answer_pairs,
dependencies=schema.dependencies,
architecture_name=architecture_name,
channels=channels,
# use_imagenet_weights=False,
test_time_dropout=True,
dropout_rate=dropout_rate,
learning_rate=learning_rate,
@@ -306,7 +304,6 @@ def train_default_zoobot_from_scratch(

early_stopping_callback = EarlyStopping(monitor=monitor_metric, patience=patience, check_finite=True)
callbacks = [checkpoint_callback, early_stopping_callback] + extra_callbacks
# callbacks = None

trainer = pl.Trainer(
num_sanity_val_steps=0,
@@ -321,14 +318,9 @@ def train_default_zoobot_from_scratch(
max_epochs=epochs,
default_root_dir=save_dir,
plugins=plugins,
gradient_clip_val=1. # new, for large models
# ,
# limit_train_batches=1,
# limit_val_batches=1
# use_distributed_sampler=use_distributed_sampler
gradient_clip_val=.3 # reduced from 1 to .3, having some nan issues
)


trainer.fit(lightning_model, datamodule) # uses batch size of datamodule

best_model_path = trainer.checkpoint_callback.best_model_path
@@ -337,44 +329,13 @@ def train_default_zoobot_from_scratch(
# also be careful not to test regularly, as this breaks train/val/test conceptual separation and may cause hparam overfitting
if datamodule.test_dataloader is not None:
logging.info(f'Testing on {checkpoint_callback.best_model_path} with single GPU. Be careful not to overfit your choices to the test data...')
# # test_trainer.validate(
# # model=lightning_model,
# # datamodule=datamodule,
# # ckpt_path=checkpoint_callback.best_model_path # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
# # )
# test_trainer = pl.Trainer(
# accelerator=accelerator,
# devices=1,
# precision=precision,
# logger=wandb_logger,
# default_root_dir=save_dir
# )
# if test_trainer.is_global_zero:
# test_datamodule = webdatamodule.WebDataModule(
# train_urls=None,
# val_urls=None,
# test_urls=test_urls,
# label_cols=schema.label_cols,
# batch_size=batch_size,
# num_workers=1, # 20 / 5 == 4, /2=2
# prefetch_factor=prefetch_factor,
# cache_dir=None,
# color=color,
# crop_scale_bounds=crop_scale_bounds,
# crop_ratio_bounds=crop_ratio_bounds,
# resize_after_crop=resize_after_crop
# )
datamodule.setup(stage='test')
# TODO with webdataset, no need for new trainer/datamodule (actually it breaks), but might still be needed with normal dataset?
trainer.test(
model=lightning_model,
datamodule=datamodule,
ckpt_path=checkpoint_callback.best_model_path # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
)
# else:
# logging.info('Not global zero, skipping test metrics')
# else:
# logging.info('No test dataloader found, skipping test metrics')


# explicitly update the model weights to the best checkpoint before returning
# (assumes only one checkpoint callback, very likely in practice)
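For reference, a minimal, hypothetical Trainer sketch isolating the two nan-related guards used above (EarlyStopping's check_finite and the reduced gradient_clip_val); the monitor name, patience, and device settings are placeholders, not values from this repo.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

early_stopping_callback = EarlyStopping(
    monitor='validation/loss',  # placeholder metric name
    patience=8,  # placeholder patience
    check_finite=True  # stop training if the monitored metric becomes nan or inf
)

trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=100,
    callbacks=[early_stopping_callback],
    gradient_clip_val=0.3  # reduced from 1.0 in this commit to help with nan issues
)
# trainer.fit(lightning_model, datamodule)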
