
Commit

minor cleanup
mwalmsley committed Mar 8, 2024
1 parent 477c075 commit d185505
Showing 4 changed files with 8 additions and 835 deletions.
7 changes: 5 additions & 2 deletions README.md
@@ -133,15 +133,18 @@ CUDA 12.1 for PyTorch 2.1.0:
conda activate zoobot39_torch
conda install -c conda-forge cudatoolkit=12.1


### Recent release features (v2.0.0)

- New pretrained architectures: EfficientnetV2 S/M/L, MaxViT tiny/small/base, ViT tiny/small, and more.
- New pretrained architectures: ConvNeXT, EfficientNetV2, MaxViT, and more. Each in several sizes.
- Reworked finetuning procedure. All these architectures are finetuneable through a common method.
- Reworked finetuning options. Batch norm finetuning removed. Cosine schedule option added.
- Reworked finetuning saving/loading. Auto-downloads encoder from HuggingFace.
- Now supports regression finetuning (as well as multi-class and binary). See `pytorch/examples/finetuning`
- Updated `timm` to 0.9.10, allowing latest model architectures. Previously downloaded checkpoints may not load correctly!
- (internal until published) GZ Evo v2 now includes Cosmic Dawn (HSC H2O). Significant performance improvement on HSC finetuning. Also now includes GZ UKIDSS (dragged from our archives).
- Updated `pytorch` to `2.1.0`
- Added support for webdatasets (only recommended for large-scale distributed training)
- Improved per-question logging when training from scratch
- Added option to compile encoder for max speed (not recommended for finetuning, only for pretraining).
- Deprecates TensorFlow. The CS research community focuses on PyTorch and new frameworks like JAX.

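Taken together, the release notes above describe a single finetuning entry point that auto-downloads a pretrained encoder from HuggingFace and attaches a classification (or regression) head. The sketch below is a minimal illustration of that workflow, assuming the v2 `FinetuneableZoobotClassifier` accepts a HuggingFace Hub checkpoint via `name` and that `finetune.get_trainer` wraps a Lightning trainer; the checkpoint name, keyword arguments, and datamodule are assumptions for illustration, not taken from this commit.

```python
# Minimal finetuning sketch (illustrative; checkpoint name and kwargs are assumptions)
from zoobot.pytorch.training import finetune

# attach a new classification head to a pretrained encoder,
# auto-downloaded from HuggingFace as described in the release notes
model = finetune.FinetuneableZoobotClassifier(
    name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',  # assumed checkpoint name
    num_classes=2,
    learning_rate=1e-4,   # defaults visible in FinetuneableZoobotAbstract.__init__ below
    weight_decay=0.05,
    dropout_prob=0.5,
)

# datamodule: any LightningDataModule yielding (image, label) batches, e.g. from galaxy-datasets
# trainer = finetune.get_trainer(save_dir='results/finetune', max_epochs=50)  # assumed signature
# trainer.fit(model, datamodule)
```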
184 changes: 0 additions & 184 deletions zoobot/pytorch/training/debug_split.ipynb

This file was deleted.

83 changes: 3 additions & 80 deletions zoobot/pytorch/training/finetune.py
@@ -70,7 +70,7 @@ def __init__(
weight_decay=0.05,
learning_rate=1e-4, # 10x lower than typical, you may like to experiment
dropout_prob=0.5,
always_train_batchnorm=True,
always_train_batchnorm=False, # temporarily deprecated
prog_bar=True,
visualize_images=False, # upload examples to wandb, good for debugging
seed=42,
@@ -104,10 +104,6 @@ def __init__(
self.encoder = encoder

# TODO read as encoder property
# if isinstance(self.encoder, CustomMAEEncoder):
# self.encoder_dim = 256 # vit hidden dim, assuming average pool over seq dim
# # self.encoder_dim = 9216
# else:
self.encoder_dim = define_model.get_encoder_dim(self.encoder)

# for backwards compat.
@@ -129,7 +125,8 @@ def __init__(

self.always_train_batchnorm = always_train_batchnorm
if self.always_train_batchnorm:
logging.info('always_train_batchnorm=True, so all batch norm layers will be finetuned')
raise NotImplementedError('Temporarily deprecated, always_train_batchnorm=True not supported')
# logging.info('always_train_batchnorm=True, so all batch norm layers will be finetuned')

self.train_loss_metric = tm.MeanMetric()
self.val_loss_metric = tm.MeanMetric()
@@ -156,26 +153,6 @@ def configure_optimizers(self):
*batch norm layers may optionally still have updated statistics using always_train_batchnorm
"""


# from foundation.models.mae_lightly import CustomMAEEncoder
# if isinstance(self.encoder, CustomMAEEncoder):
# logging.info('Using custom optimizer for MAE encoder')
# # equivalent to usual, but in param_group format
# head_param_groups = [
# {'params': self.head.parameters(),
# 'weight_decay': self.weight_decay,
# 'lr_scale': 1. # no lr decay on head
# }
# ]
# # now custom bit for the encoder
# encoder_param_groups = self.encoder.get_param_groups(self.weight_decay, self.lr_decay)
# n_param_groups_to_tune = self.n_blocks * 2 # finetune top N. First layer is pos embedding, then pairs of decay/no decay, 18 pairs by default
# if n_param_groups_to_tune > len(encoder_param_groups):
# logging.warning('more param groups (blocks*2) specified to finetune than available')
# encoder_param_groups_to_tune = encoder_param_groups[-n_param_groups_to_tune:]
# param_groups = encoder_param_groups_to_tune + head_param_groups
# return torch.optim.AdamW(param_groups, lr=self.learning_rate)

lr = self.learning_rate
params = [{"params": self.head.parameters(), "lr": lr}]
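With the MAE-specific branch deleted, the remaining lines of `configure_optimizers` build optimizer parameter groups, starting from the head at the full learning rate. The generic snippet below is a sketch of that discriminative-learning-rate pattern with `torch.optim.AdamW`; the block iteration and decay factor are illustrative assumptions, not the exact logic retained in `finetune.py`.

```python
# Generic sketch of per-block learning-rate decay via AdamW param groups.
# The decay factor and block ordering here are illustrative assumptions.
import torch

def build_param_groups(head, encoder_blocks, lr=1e-4, lr_decay=0.75):
    # the new head always trains at the full learning rate
    param_groups = [{'params': head.parameters(), 'lr': lr}]
    # later (deeper) encoder blocks keep more of the learning rate;
    # earlier blocks are decayed geometrically
    for depth, block in enumerate(reversed(list(encoder_blocks))):
        param_groups.append({'params': block.parameters(), 'lr': lr * lr_decay ** (depth + 1)})
    return param_groups

# e.g. optimizer = torch.optim.AdamW(build_param_groups(head, blocks), weight_decay=0.05)
```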
@@ -753,57 +730,3 @@ def get_trainer(
)

return trainer

# TODO check exactly which layers get FTd
# def is_tuneable(block_of_layers):
# if len(list(block_of_layers.parameters())) == 0:
# logging.info('Skipping block with no params')
# logging.info(block_of_layers)
# return False
# else:
# # currently, allowed to include batchnorm
# return True

# def get_batch_norm_params_lighting(parent_module, checked_params=set(), batch_norm_params=[]):

# modules = parent_module.modules()
# for module in modules:
# if id(module) not in checked_params:
# checked_params.add(id(module))
# if isinstance(module, torch.nn.BatchNorm2d):
# batch_norm_params += module.parameters()
# else:
# checked_params, batch_norm_params = get_batch_norm_params_lighting(module, checked_params, batch_norm_params)

# return checked_params, batch_norm_params



# when ready (don't peek often, you'll overfit)
# trainer.test(model, dataloaders=datamodule)

# return model, checkpoint_callback.best_model_path
# trainer.callbacks[checkpoint_callback].best_model_path?

# def investigate_structure():

# from zoobot.pytorch.estimators import define_model


# model = define_model.get_plain_pytorch_zoobot_model(output_dim=1280, include_top=False)

# # print(model)
# # with include_top=False, first and only child is EffNet
# effnet_with_pool = list(model.children())[0]

# # 0th is actually EffNet, 1st and 2nd are AvgPool and Identity
# effnet = list(effnet_with_pool.children())[0]

# for layer_n, layer in enumerate(effnet.children()):
# # first bunch are Sequential module wrapping e.g. 3 MBConv blocks
# print('\n', layer_n)
# if isinstance(layer, torch.nn.Sequential):
# print(layer)
# # so the blocks to finetune are each Sequential (repeated MBConv) block
# # and other blocks can be left alone
# # (also be careful to leave batch-norm alone)
