train effnetv2xl on evo
mwalmsley committed Jan 19, 2024
1 parent c8442db commit b652165
Showing 1 changed file with 30 additions and 14 deletions.
44 changes: 30 additions & 14 deletions zoobot/pytorch/training/finetune.py
@@ -15,6 +15,9 @@
import torchmetrics as tm
import timm


from foundation.models.mae_lightly import CustomMAEEncoder

from zoobot.pytorch.training import losses
from zoobot.pytorch.estimators import define_model
from zoobot.shared import schemas
@@ -98,7 +101,12 @@ def __init__(
assert encoder is not None, 'Must pass either checkpoint to load or encoder to use'
self.encoder = encoder

self.encoder_dim = define_model.get_encoder_dim(self.encoder)
# TODO read as encoder property
if isinstance(self.encoder, CustomMAEEncoder):
self.encoder_dim = 256  # ViT hidden dim, assuming average pooling over the sequence dim
# self.encoder_dim = 9216
else:
self.encoder_dim = define_model.get_encoder_dim(self.encoder)
self.n_blocks = n_blocks

# for backwards compat.
@@ -125,15 +133,31 @@ def __init__(
self.prog_bar = prog_bar
self.visualize_images = visualize_images

def configure_optimizers(self):

if isinstance(self.encoder, CustomMAEEncoder):
logging.info('Using custom optimizer for MAE encoder')
# equivalent to the usual head-only setup, but expressed as param groups
head_param_groups = [
{'params': self.head.parameters(),
'weight_decay': self.weight_decay,
'lr_scale': 1. # no lr decay on head
}
]
# now custom bit for the encoder
encoder_param_groups = self.encoder.get_param_groups(self.weight_decay, self.lr_decay)
n_param_groups_to_tune = self.n_blocks * 2  # finetune the top N blocks: the first group is the pos embedding, then pairs of decay/no-decay groups (18 pairs by default)
if n_param_groups_to_tune > len(encoder_param_groups):
logging.warning('more param groups (blocks*2) specified to finetune than available')
encoder_param_groups_to_tune = encoder_param_groups[-n_param_groups_to_tune:]
param_groups = encoder_param_groups_to_tune + head_param_groups
return torch.optim.AdamW(param_groups, lr=self.learning_rate)

lr = self.learning_rate
params = [{"params": self.head.parameters(), "lr": lr}]

# architecture = self.encoder.default_config['architecture']
logging.info(f'Encoder architecture to finetune: {type(self.encoder)}')

# if 'efficientnet' in architecture:
if isinstance(self.encoder, timm.models.EfficientNet):
# TODO for now, these count as separate layers, not ideal
early_tuneable_layers = [self.encoder.conv_stem, self.encoder.bn1]
@@ -152,16 +176,9 @@ def configure_optimizers(self):
]
elif isinstance(self.encoder, timm.models.MaxxVit):
blocks_to_tune = self.encoder.stem + [stage for stage in self.encoder.stages]
# [
# getattr as obj.0 is not allowed (why does timm call them 0!?)
# getattr(self.encoder.stages, '0'),
# getattr(self.encoder.stages, '1'),
# getattr(self.encoder.stages, '2'),
# getattr(self.encoder.stages, '3'),
# ]
else:
raise ValueError(f'Encoder architecture not automatically recognised: {type(self.encoder)}')

assert self.n_blocks <= len(
blocks_to_tune
), f"Network only has {len(blocks_to_tune)} tuneable blocks, {self.n_blocks} specified for finetuning"
@@ -181,8 +198,6 @@ def configure_optimizers(self):
"lr": lr * (self.lr_decay**i)
})

logging.debug(params)

# optionally, for the remaining blocks (not otherwise finetuned), you can choose to still finetune the batchnorm layers
for i, block in enumerate(remaining_blocks):
if self.always_train_batchnorm:
@@ -200,6 +215,7 @@ def configure_optimizers(self):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.encoder(x)
x = self.head(x)
# TODO the encoder output shape changes with the input shape (of course), so it needs to be specified explicitly or skipped
return x

def make_step(self, batch):
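A note on the hard-coded encoder_dim in the first hunk: for the MAE/ViT encoder, the commit assumes the token sequence is average-pooled to a single 256-dim feature. A minimal standalone sketch of that pooling (shapes are illustrative assumptions, not taken from CustomMAEEncoder):

import torch

tokens = torch.randn(8, 197, 256)  # (batch, sequence length, hidden dim); assumed shapes for illustration
pooled = tokens.mean(dim=1)        # average pool over the sequence dim
assert pooled.shape == (8, 256)    # matches the hard-coded encoder_dim of 256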

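Both the new MAE branch and the existing timm branch of configure_optimizers follow a layer-wise learning-rate decay recipe: the head and the deepest encoder blocks train at the full learning rate, and each earlier block's rate is scaled down by another factor of lr_decay. A minimal standalone sketch of that pattern with toy modules (names and values are assumptions, not the CustomMAEEncoder.get_param_groups API):

import torch
from torch import nn

encoder_blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])  # stand-in for encoder blocks
head = nn.Linear(16, 2)                                                # stand-in for the finetuning head
base_lr, lr_decay, weight_decay = 1e-4, 0.75, 0.05                     # assumed hyperparameters

# the head always trains at the full learning rate
param_groups = [{'params': head.parameters(), 'lr': base_lr, 'weight_decay': weight_decay}]

# the deepest block gets the full lr; each earlier block is scaled down by another factor of lr_decay
for depth, block in enumerate(reversed(list(encoder_blocks))):
    param_groups.append({
        'params': block.parameters(),
        'lr': base_lr * (lr_decay ** depth),
        'weight_decay': weight_decay,
    })

optimizer = torch.optim.AdamW(param_groups, lr=base_lr)

Here each group carries an explicit 'lr' rather than an 'lr_scale' hint, so AdamW applies the decayed rates directly.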

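Finally, for the always_train_batchnorm path at the end of configure_optimizers, the idea is to keep only the batchnorm parameters of blocks that are otherwise left frozen. A minimal standalone sketch (toy block; not the zoobot helper that selects these layers):

from torch import nn

block = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())  # stand-in for a block that is not finetuned
batchnorm_params = [
    param
    for module in block.modules()
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d))
    for param in module.parameters()
]
# these parameters can be appended as an extra param group so the batchnorm affine weights still adapt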