Commit c14637e

remove mae
mwalmsley committed Mar 5, 2024
1 parent 6b754b9 commit c14637e
Showing 1 changed file with 19 additions and 20 deletions.
39 changes: 19 additions & 20 deletions zoobot/pytorch/training/finetune.py
@@ -16,9 +16,6 @@
 import torchmetrics as tm
 import timm
 
-
-from foundation.models.mae_lightly import CustomMAEEncoder
-
 from zoobot.pytorch.training import losses
 from zoobot.pytorch.estimators import define_model
 from zoobot.shared import schemas
@@ -159,24 +156,26 @@ def configure_optimizers(self):
         *batch norm layers may optionally still have updated statistics using always_train_batchnorm
         """
 
 
-        if isinstance(self.encoder, CustomMAEEncoder):
-            logging.info('Using custom optimizer for MAE encoder')
-            # equivalent to usual, but in param_group format
-            head_param_groups = [
-                {'params': self.head.parameters(),
-                 'weight_decay': self.weight_decay,
-                 'lr_scale': 1.  # no lr decay on head
-                }
-            ]
-            # now custom bit for the encoder
-            encoder_param_groups = self.encoder.get_param_groups(self.weight_decay, self.lr_decay)
-            n_param_groups_to_tune = self.n_blocks * 2  # finetune top N. First layer is pos embedding, then pairs of decay/no decay, 18 pairs by default
-            if n_param_groups_to_tune > len(encoder_param_groups):
-                logging.warning('more param groups (blocks*2) specified to finetune than available')
-            encoder_param_groups_to_tune = encoder_param_groups[-n_param_groups_to_tune:]
-            param_groups = encoder_param_groups_to_tune + head_param_groups
-            return torch.optim.AdamW(param_groups, lr=self.learning_rate)
+        # from foundation.models.mae_lightly import CustomMAEEncoder
+        # if isinstance(self.encoder, CustomMAEEncoder):
+        #     logging.info('Using custom optimizer for MAE encoder')
+        #     # equivalent to usual, but in param_group format
+        #     head_param_groups = [
+        #         {'params': self.head.parameters(),
+        #          'weight_decay': self.weight_decay,
+        #          'lr_scale': 1.  # no lr decay on head
+        #         }
+        #     ]
+        #     # now custom bit for the encoder
+        #     encoder_param_groups = self.encoder.get_param_groups(self.weight_decay, self.lr_decay)
+        #     n_param_groups_to_tune = self.n_blocks * 2  # finetune top N. First layer is pos embedding, then pairs of decay/no decay, 18 pairs by default
+        #     if n_param_groups_to_tune > len(encoder_param_groups):
+        #         logging.warning('more param groups (blocks*2) specified to finetune than available')
+        #     encoder_param_groups_to_tune = encoder_param_groups[-n_param_groups_to_tune:]
+        #     param_groups = encoder_param_groups_to_tune + head_param_groups
+        #     return torch.optim.AdamW(param_groups, lr=self.learning_rate)
 
         lr = self.learning_rate
         params = [{"params": self.head.parameters(), "lr": lr}]
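The pattern in this hunk (both the removed MAE branch and the code that remains) is per-parameter-group optimization: configure_optimizers builds one group for the head plus one group per encoder block to be finetuned, and passes them all to a single AdamW so that blocks further from the head can receive smaller learning rates. A minimal standalone sketch of that idea, assuming a hypothetical list of encoder blocks and a head module rather than the actual Zoobot classes:

import torch
from torch import nn

def build_finetune_optimizer(encoder_blocks: list[nn.Module], head: nn.Module,
                             lr: float = 1e-4, lr_decay: float = 0.75,
                             weight_decay: float = 0.05, n_blocks: int = 2) -> torch.optim.AdamW:
    # the head is always trained, at the full learning rate
    param_groups = [{'params': head.parameters(), 'lr': lr, 'weight_decay': weight_decay}]
    # finetune only the top n_blocks of the encoder; each block further from the
    # head gets a geometrically smaller learning rate (layer-wise lr decay).
    # Blocks below the top n_blocks are left out of the optimizer entirely, i.e. frozen.
    for depth, block in enumerate(reversed(encoder_blocks[-n_blocks:]), start=1):
        param_groups.append({
            'params': block.parameters(),
            'lr': lr * (lr_decay ** depth),
            'weight_decay': weight_decay,
        })
    return torch.optim.AdamW(param_groups)

# usage sketch: optimizer = build_finetune_optimizer(list(encoder.blocks), head, n_blocks=2)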
