Skip to content

Commit

Permalink
add cosine support
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Mar 1, 2024
1 parent 772660d commit 90c33f5
Showing 1 changed file with 34 additions and 6 deletions.
40 changes: 34 additions & 6 deletions zoobot/pytorch/training/finetune.py
Expand Up @@ -9,6 +9,7 @@
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -67,7 +68,6 @@ def __init__(
zoobot_checkpoint_loc=None,
# ...or directly pass any model to use as encoder (if you do this, you will need to keep it around for later)
encoder=None,
n_epochs=100, # TODO early stopping
n_blocks=0, # how many layers deep to FT
lr_decay=0.75,
weight_decay=0.05,
Expand All @@ -77,7 +77,12 @@ def __init__(
prog_bar=True,
visualize_images=False, # upload examples to wandb, good for debugging
seed=42,
n_layers=0 # for backward compat., n_blocks preferred
n_layers=0, # for backward compat., n_blocks preferred
# these args are for the optional learning rate scheduler, best not to use unless you've tuned everything else already
cosine_schedule=False,
warmup_epochs=10,
max_cosine_epochs=100,
max_learning_rate_reduction_factor=0.01
):
super().__init__()

Expand Down Expand Up @@ -120,7 +125,11 @@ def __init__(
self.lr_decay = lr_decay
self.weight_decay = weight_decay
self.dropout_prob = dropout_prob
self.n_epochs = n_epochs

self.cosine_schedule = cosine_schedule
self.warmup_epochs = warmup_epochs
self.max_cosine_epochs = max_cosine_epochs
self.max_learning_rate_reduction_factor = max_learning_rate_reduction_factor

self.always_train_batchnorm = always_train_batchnorm
if self.always_train_batchnorm:
Expand Down Expand Up @@ -213,8 +222,25 @@ def configure_optimizers(self):
# Initialize AdamW optimizer
opt = torch.optim.AdamW(params, weight_decay=self.weight_decay) # lr included in params dict

return opt

if self.cosine_schedule:
from lightly.utils.scheduler import CosineWarmupScheduler # new dependency for zoobot, TBD - maybe just copy
return {
"optimizer": opt,
"lr_scheduler": {
"scheduler": CosineWarmupScheduler(
optimizer=opt,
warmup_epochs=self.warmup_epochs,
max_epochs=self.max_cosine_epochs,
start_value=self.learning_rate,
end_value=self.learning_rate * self.max_learning_rate_reduction_factor,
),
'interval': 'epoch',
"frequency": 1
}
}
else:
return opt


def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.encoder(x)
Expand Down Expand Up @@ -669,10 +695,12 @@ def get_trainer(
patience=patience
)

learning_rate_monitor_callback = LearningRateMonitor(logging_interval='epoch')

# Initialise pytorch lightning trainer
trainer = pl.Trainer(
logger=logger,
callbacks=[checkpoint_callback, early_stopping_callback],
callbacks=[checkpoint_callback, early_stopping_callback, learning_rate_monitor_callback],
max_epochs=max_epochs,
accelerator=accelerator,
devices=devices,
Expand Down

0 comments on commit 90c33f5

Please sign in to comment.