Latest finetuning changes #113

Merged · 13 commits · Mar 21, 2024
2 changes: 1 addition & 1 deletion .github/workflows/run_CI.yml
@@ -14,7 +14,7 @@ jobs:
strategy:
fail-fast: true
matrix:
python-version: ["3.8", "3.9"] # zoobot should support these (many academics not on 3.9)
python-version: ["3.9"] # zoobot should support these
experimental: [false]
include:
- python-version: "3.10" # test the next python version but allow it to fail
3 changes: 2 additions & 1 deletion .gitignore
@@ -167,4 +167,5 @@ hparams.yaml

data/pretrained_models

- *.tar
+ *.tar
+ *.ckpt
14 changes: 0 additions & 14 deletions Dockerfile.tf

This file was deleted.

11 changes: 0 additions & 11 deletions docker-compose-tf.yml

This file was deleted.

6 changes: 4 additions & 2 deletions setup.py
@@ -20,7 +20,7 @@
"Environment :: GPU :: NVIDIA CUDA"
],
packages=setuptools.find_packages(),
python_requires=">=3.8", # recommend 3.9 for new users. TF needs >=3.7.2, torchvision>=3.8
python_requires=">=3.9", # bumped to 3.9 for typing
extras_require={
'pytorch-cpu': [
# A100 GPU currently only seems to support cuda 11.3 on manchester cluster, let's stick with this version for now
@@ -112,7 +112,9 @@
'pyarrow', # to read parquet, which is very handy for big datasets
# for saving metrics to weights&biases (cloud service, free within limits)
'wandb',
'webdataset', # for reading webdataset files
'huggingface_hub', # login may be required
'setuptools', # no longer pinned
- 'galaxy-datasets>=0.0.15' # for dataset loading in both TF and Torch (see github/mwalmsley/galaxy-datasets)
+ 'galaxy-datasets>=0.0.17' # for dataset loading in both TF and Torch (see github/mwalmsley/galaxy-datasets)
]
)
3 changes: 2 additions & 1 deletion tests/pytorch/test_define_model.py
@@ -10,6 +10,7 @@ def schema():
def test_ZoobotTree_init(schema):
model = define_model.ZoobotTree(
output_dim=12,
question_index_groups=schema.question_index_groups,
question_answer_pairs=schema.question_answer_pairs,
dependencies=schema.dependencies
)

43 changes: 43 additions & 0 deletions tests/test_from_hub.py
@@ -0,0 +1,43 @@
import pytest

import timm
import torch


def test_get_encoder():
model = timm.create_model("hf_hub:mwalmsley/zoobot-encoder-efficientnet_b0", pretrained=True)
assert model(torch.rand(1, 3, 224, 224)).shape == (1, 1280)


def test_get_finetuned():
# checkpoint_loc = 'https://huggingface.co/mwalmsley/zoobot-finetuned-is_tidal/resolve/main/3.ckpt' pickle problem via lightning
# checkpoint_loc = '/home/walml/Downloads/3.ckpt' # works when downloaded manually

from huggingface_hub import hf_hub_download

REPO_ID = "mwalmsley/zoobot-finetuned-is_tidal"
FILENAME = "FinetuneableZoobotClassifier.ckpt"

downloaded_loc = hf_hub_download(
repo_id=REPO_ID,
filename=FILENAME,
)
from zoobot.pytorch.training import finetune
model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(downloaded_loc, map_location='cpu') # hub_name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2)

def test_get_finetuned_class_method():

from zoobot.pytorch.training import finetune

model = finetune.FinetuneableZoobotClassifier.load_from_name('mwalmsley/zoobot-finetuned-is_tidal', map_location='cpu')
assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2)

# def test_get_finetuned_from_local():
# # checkpoint_loc = '/home/walml/repos/zoobot/tests/convnext_nano_finetuned_linear_is-lsb.ckpt'
# checkpoint_loc = '/home/walml/repos/zoobot-foundation/results/finetune/is-lsb/debug/checkpoints/4.ckpt'

# from zoobot.pytorch.training import finetune
# # if originally trained with a direct in-memory checkpoint, must specify the hub name manually. otherwise it's saved as an hparam.
# model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(checkpoint_loc, map_location='cpu') # hub_name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', )
# assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2)
4 changes: 2 additions & 2 deletions zoobot/pytorch/datasets/webdatamodule.py
@@ -80,8 +80,8 @@ def make_image_transform(self, mode="train"):
crop_ratio_bounds=self.crop_ratio_bounds,
resize_after_crop=self.resize_after_crop,
pytorch_greyscale=not self.color,
- to_float=True # wrong, webdataset rgb decoder already converts to 0-1 float
- # TODO this must be changed! will be different for new model training runs
+ to_float=False # True was wrong, webdataset rgb decoder already converts to 0-1 float
+ # TODO now changed on dev branch will be different for new model training runs
) # A.Compose object

# logging.warning('Minimal augmentations for speed test')
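To see why this flag matters: webdataset's rgb decoder already returns float32 images scaled to [0, 1], so leaving to_float=True would rescale a second time and effectively blank the images. A minimal sketch of the failure mode, assuming to_float maps to an albumentations ToFloat(max_value=255) step (the exact galaxy-datasets internals are not shown in this diff):

    import numpy as np
    import albumentations as A

    decoded = np.random.rand(64, 64, 3).astype(np.float32)  # webdataset "rgb" decode: already 0-1 floats

    doubled = A.Compose([A.ToFloat(max_value=255)])(image=decoded)["image"]
    print(doubled.max())  # ~0.004: dividing 0-1 floats by 255 again crushes the signal

    untouched = A.Compose([A.NoOp()])(image=decoded)["image"]
    print(untouched.max())  # ~1.0: to_float=False leaves the decoded floats alone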
2 changes: 2 additions & 0 deletions zoobot/pytorch/datasets/webdataset_utils.py
@@ -63,6 +63,8 @@ def df_to_wds(df: pd.DataFrame, label_cols, save_loc: str, n_shards: int, sparse
# in augs that could be 0.x-1.0, and here a pre-crop to 0.8 i.e. 340px
# but this would change the centering
# let's stick to small boundary crop and 0.75-0.85 in augs

# turn these off for current euclidized images, already 300x300
A.CenterCrop(
height=400,
width=400,
22 changes: 11 additions & 11 deletions zoobot/pytorch/estimators/define_model.py
@@ -177,7 +177,6 @@ class ZoobotTree(GenericLightningModule):

Args:
output_dim (int): Output dimension of model's head e.g. 34 for predicting a 34-answer decision tree.
- question_index_groups (List): Mapping of which label indices are part of the same question. See :ref:`training_on_vote_counts`.
architecture_name (str, optional): Architecture to use. Passed to timm. Must be in timm.list_models(). Defaults to "efficientnet_b0".
channels (int, optional): Num. input channels. Probably 3 or 1. Defaults to 1.
test_time_dropout (bool, optional): Apply dropout at test time, to pretend to be Bayesian. Defaults to True.
@@ -192,7 +191,7 @@ def __init__(
self,
output_dim: int,
# in the simplest case, this is all zoobot needs: grouping of label col indices as questions
- question_index_groups: List=None,
+ # question_index_groups: List=None,
# BUT
# if you pass these, it enables better per-question and per-survey logging (because we have names)
# must be passed as simple dicts, not objects, so can't just pass schema in
@@ -219,7 +218,6 @@ def __init__(
super().__init__(
# these all do nothing, they are simply saved by lightning as hparams
output_dim,
- question_index_groups,
question_answer_pairs,
dependencies,
architecture_name,
@@ -236,13 +234,12 @@ def __init__(

logging.info('Generic __init__ complete - moving to Zoobot __init__')

- if question_answer_pairs is not None:
- logging.info('question_index_groups/dependencies passed to Zoobot, constructing schema in __init__')
- # assert question_index_groups is None, "Don't pass both question_index_groups and question_answer_pairs/dependencies"
- assert dependencies is not None
- self.schema = schemas.Schema(question_answer_pairs, dependencies)
- # replace with schema-derived version
- question_index_groups = self.schema.question_index_groups
+ # logging.info('question_index_groups/dependencies passed to Zoobot, constructing schema in __init__')
+ # assert question_index_groups is None, "Don't pass both question_index_groups and question_answer_pairs/dependencies"
+ assert dependencies is not None
+ self.schema = schemas.Schema(question_answer_pairs, dependencies)
+ # replace with schema-derived version
+ question_index_groups = self.schema.question_index_groups

self.setup_metrics()

@@ -480,4 +477,7 @@ def schema_to_campaigns(schema):
if __name__ == '__main__':
encoder = get_pytorch_encoder(channels=1)
dim = get_encoder_dim(encoder, channels=1)
- print(dim)
+ print(dim)
+
+
+ ZoobotTree.load_from_checkpoint
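For clarity on the new call signature above, a small sketch of constructing ZoobotTree with question_answer_pairs and dependencies rather than question_index_groups. The two-question tree here is a toy example; real projects would typically pull these dicts from galaxy-datasets, and output_dim must equal the total number of answers:

    from zoobot.pytorch.estimators import define_model

    question_answer_pairs = {
        'smooth-or-featured': ['_smooth', '_featured-or-disk', '_artifact'],
        'disk-edge-on': ['_yes', '_no'],
    }
    dependencies = {
        'smooth-or-featured': None,  # asked of every galaxy
        'disk-edge-on': 'smooth-or-featured_featured-or-disk',  # only asked if featured
    }

    model = define_model.ZoobotTree(
        output_dim=5,  # 3 + 2 answers across the two questions
        question_answer_pairs=question_answer_pairs,
        dependencies=dependencies,
    )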
96 changes: 76 additions & 20 deletions zoobot/pytorch/training/finetune.py
@@ -43,11 +43,17 @@ class FinetuneableZoobotAbstract(pl.LightningModule):
Both :class:`FinetuneableZoobotClassifier` and :class:`FinetuneableZoobotTree`
can (and should) be passed any of these arguments to customise finetuning.

- You could subclass this class to solve new finetuning tasks (like regression) - see :ref:`advanced_finetuning`.
+ Any FinetuneableZoobot model can be loaded in one of three ways:
+ - HuggingFace name e.g. FinetuneableZoobotX(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...). Recommended.
+ - Any PyTorch model in memory e.g. FinetuneableZoobotX(encoder=some_model, ...)
+ - ZoobotTree checkpoint e.g. FinetuneableZoobotX(zoobot_checkpoint_loc='path/to/zoobot_tree.ckpt', ...)
+
+ You could subclass this class to solve new finetuning tasks - see :ref:`advanced_finetuning`.

Args:
- checkpoint_loc (str, optional): Path to encoder checkpoint to load (likely a saved ZoobotTree). Defaults to None.
- encoder (pl.LightningModule, optional): Alternatively, pass an encoder directly. Load with :func:`zoobot.pytorch.training.finetune.load_pretrained_encoder`.
+ name (str, optional): Name of a model on HuggingFace Hub e.g. 'hf_hub:mwalmsley/zoobot-encoder-convnext_nano'. Defaults to None.
+ encoder (torch.nn.Module, optional): A PyTorch model already loaded in memory
+ zoobot_checkpoint_loc (str, optional): Path to ZoobotTree lightning checkpoint to load. Loads with :func:`zoobot.pytorch.training.finetune.load_pretrained_encoder`. Defaults to None.
encoder_dim (int, optional): Output dimension of encoder. Defaults to 1280 (EfficientNetB0's encoder dim).
lr_decay (float, optional): For each layer i below the head, reduce the learning rate by lr_decay ^ i. Defaults to 0.75.
weight_decay (float, optional): AdamW weight decay arg (i.e. L2 penalty). Defaults to 0.05.
@@ -61,25 +67,39 @@ class FinetuneableZoobotAbstract(pl.LightningModule):

def __init__(
self,
# can provide either zoobot_checkpoint_loc, and will load this model as encoder...
zoobot_checkpoint_loc=None,

# load a pretrained timm encoder saved on huggingface hub
# (aimed at most users, easiest way to load published models)
name=None,

# ...or directly pass any model to use as encoder (if you do this, you will need to keep it around for later)
encoder=None,
# (aimed at tinkering with new architectures e.g. SSL)
encoder=None, # use any torch model already loaded in memory (must have .forward() method)

# load a pretrained zoobottree model and grab the encoder (a timm model)
# requires the exact same zoobot version used for training, not very portable
# (aimed at supervised experiments)
zoobot_checkpoint_loc=None,

# finetuning settings
n_blocks=0, # how many layers deep to FT
lr_decay=0.75,
weight_decay=0.05,
learning_rate=1e-4, # 10x lower than typical, you may like to experiment
dropout_prob=0.5,
always_train_batchnorm=False, # temporarily deprecated
prog_bar=True,
visualize_images=False, # upload examples to wandb, good for debugging
seed=42,
n_layers=0, # for backward compat., n_blocks preferred
# these args are for the optional learning rate scheduler, best not to use unless you've tuned everything else already
cosine_schedule=False,
warmup_epochs=10,
max_cosine_epochs=100,
max_learning_rate_reduction_factor=0.01
max_learning_rate_reduction_factor=0.01,
# escape hatch for 'from scratch' baselines
from_scratch=False,
# debugging utils
prog_bar=True,
visualize_images=False, # upload examples to wandb, good for debugging
seed=42
):
super().__init__()

@@ -94,17 +114,22 @@ def __init__(
self.save_hyperparameters(ignore=['encoder']) # never serialise the encoder, way too heavy
# if you need the encoder to recreate, pass when loading checkpoint e.g.
# FinetuneableZoobotTree.load_from_checkpoint(loc, encoder=encoder)

if zoobot_checkpoint_loc is not None:
assert encoder is None, 'Cannot pass both checkpoint to load and encoder to use'
self.encoder = load_pretrained_zoobot(zoobot_checkpoint_loc)

if name is not None:
assert encoder is None, 'Cannot pass both name and encoder to use'
self.encoder = timm.create_model(name, pretrained=True)
self.encoder_dim = self.encoder.num_features

elif zoobot_checkpoint_loc is not None:
assert encoder is None, 'Cannot pass both checkpoint to load and encoder to use'
self.encoder = load_pretrained_zoobot(zoobot_checkpoint_loc) # extracts the timm encoder
self.encoder_dim = self.encoder.num_features
else:
assert zoobot_checkpoint_loc is None, 'Cannot pass both checkpoint to load and encoder to use'
assert encoder is not None, 'Must pass either checkpoint to load or encoder to use'
self.encoder = encoder

# TODO read as encoder property
self.encoder_dim = define_model.get_encoder_dim(self.encoder)
assert zoobot_checkpoint_loc is None, 'Cannot pass both checkpoint to load and encoder to use'
assert encoder is not None, 'Must pass either checkpoint to load or encoder to use'
self.encoder = encoder
# work out encoder dim 'manually'
self.encoder_dim = define_model.get_encoder_dim(self.encoder)

# for backwards compat.
if n_layers:
@@ -123,6 +148,8 @@ def __init__(
self.max_cosine_epochs = max_cosine_epochs
self.max_learning_rate_reduction_factor = max_learning_rate_reduction_factor

self.from_scratch = from_scratch

self.always_train_batchnorm = always_train_batchnorm
if self.always_train_batchnorm:
raise NotImplementedError('Temporarily deprecated, always_train_batchnorm=True not supported')
@@ -159,6 +186,11 @@ def configure_optimizers(self):

logging.info(f'Encoder architecture to finetune: {type(self.encoder)}')

if self.from_scratch:
logging.warning('self.from_scratch is True, training everything and ignoring all settings')
params += [{"params": self.encoder.parameters(), "lr": lr}]
return torch.optim.AdamW(params, weight_decay=self.weight_decay)

if isinstance(self.encoder, timm.models.EfficientNet): # includes v2
# TODO for now, these count as separate layers, not ideal
early_tuneable_layers = [self.encoder.conv_stem, self.encoder.bn1]
@@ -345,6 +377,13 @@ def on_test_batch_end(self, outputs: dict, batch, batch_idx: int, dataloader_idx

def upload_images_to_wandb(self, outputs, batch, batch_idx):
raise NotImplementedError('Must be subclassed')

@classmethod
def load_from_name(cls, name: str, **kwargs):
downloaded_loc = download_from_name(cls.__name__, name, **kwargs)
return cls.load_from_checkpoint(downloaded_loc, **kwargs) # trained on GPU, may need map_location='cpu' if you get a device error
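A quick usage sketch for this new classmethod (not part of the diff; the repo name is the one exercised in tests/test_from_hub.py, and map_location is only needed when loading a GPU-trained checkpoint on CPU):

    import torch
    from zoobot.pytorch.training import finetune

    model = finetune.FinetuneableZoobotClassifier.load_from_name(
        'mwalmsley/zoobot-finetuned-is_tidal', map_location='cpu')
    assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2)  # binary is_tidal classifier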





@@ -364,6 +403,8 @@ class FinetuneableZoobotClassifier(FinetuneableZoobotAbstract):

"""



def __init__(
self,
num_classes: int,
@@ -730,3 +771,18 @@ def get_trainer(
)

return trainer


def download_from_name(class_name: str, hub_name: str, **kwargs):
from huggingface_hub import hf_hub_download

if hub_name.startswith('hf_hub:'):
logging.info('Passed name with hf_hub: prefix, dropping prefix')
repo_id = hub_name.split('hf_hub:')[1]
else:
repo_id = hub_name
downloaded_loc = hf_hub_download(
repo_id=repo_id,
filename=f"{class_name}.ckpt"
)
return downloaded_loc
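download_from_name simply resolves a Hub repo to the checkpoint named after the loading class, the same {ClassName}.ckpt convention used in tests/test_from_hub.py. A hedged sketch of the equivalent direct call:

    from huggingface_hub import hf_hub_download

    ckpt = hf_hub_download(
        repo_id='mwalmsley/zoobot-finetuned-is_tidal',
        filename='FinetuneableZoobotClassifier.ckpt',  # f'{class_name}.ckpt' convention
    )
    # ckpt can then be passed to FinetuneableZoobotClassifier.load_from_checkpoint(...)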
28 changes: 21 additions & 7 deletions zoobot/pytorch/training/representations.py
@@ -1,17 +1,31 @@
import logging
import pytorch_lightning as pl

from timm import create_model


class ZoobotEncoder(pl.LightningModule):
# very simple wrapper to turn pytorch model into lightning module
# useful when we want to use lightning to make predictions with our encoder
# (i.e. to get representations)

def __init__(self, encoder, pyramid=False) -> None:
super().__init__()
def __init__(self, encoder):
logging.info('ZoobotEncoder: using provided in-memory encoder')
self.encoder = encoder # plain pytorch module e.g. Sequential
if pyramid:
raise NotImplementedError('Will eventually support resetting timm classifier to get FPN features')


def forward(self, x):
if isinstance(x, list) and len(x) == 1:
return self(x[0])
return self.encoder(x)

@classmethod
def load_from_name(cls, name: str):
"""
e.g. ZoobotEncoder.load_from_name('hf_hub:mwalmsley/zoobot-encoder-convnext_nano')
Args:
name (str): huggingface hub name to load

Returns:
nn.Module: timm model
"""
timm_model = create_model(name)
return cls(timm_model)
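A short usage sketch for the new classmethod (hub name taken from the docstring above; the representation dimension depends on the architecture, e.g. 640 for convnext_nano):

    import torch
    from zoobot.pytorch.training.representations import ZoobotEncoder

    encoder = ZoobotEncoder.load_from_name('hf_hub:mwalmsley/zoobot-encoder-convnext_nano')
    with torch.no_grad():
        representation = encoder(torch.rand(1, 3, 224, 224))  # e.g. shape (1, 640)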