Skip to content

Commit

Permalink
Merge pull request #106 from mwalmsley/dev
Browse files Browse the repository at this point in the history
Add multiclass support
  • Loading branch information
mwalmsley committed Aug 1, 2023
2 parents c3fc04b + 0c3fb82 commit ad78db8
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 24 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Expand Up @@ -163,4 +163,6 @@ data/cosmic_dawn*.parquet

results

hparams.yaml
hparams.yaml

data/pretrained_models
9 changes: 8 additions & 1 deletion README.md
Expand Up @@ -149,7 +149,14 @@ CUDA 11.2 and CUDNN 8.1 for TensorFlow 2.10.0:
conda install -c conda-forge cudatoolkit=11.2 cudnn=8.1.0
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/ # add this environment variable

### Latest features (v1.0.0)
### Latest minor features (v1.0.4)

- Now supports multi-class finetuning. See `pytorch/examples/finetuning/finetune_multiclass_classification.py`
- Removed `simplejpeg` dependency due to M1 install issue.
- Pinned `timm` version to ensure MaxViT models load correctly. Models supporting the latest `timm` will follow.
- (internal until published) GZ Evo v2 now includes Cosmic Dawn (HSC). Significant performance improvement on HSC finetuning.

### Latest major features (v1.0.0)

v1.0.0 recognises that most of the complexity in this repo is training Zoobot from scratch, but most non-GZ users will probably simply want to load the pretrained Zoobot and finetune it on their data.

Expand Down
9 changes: 5 additions & 4 deletions benchmarks/pytorch/run_benchmarks.sh
Expand Up @@ -16,8 +16,9 @@ SEED=$RANDOM
# effnet, greyscale and color
# sbatch --job-name=evo_py_gr_eff_224_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
# sbatch --job-name=evo_py_gr_eff_300_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=300,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
# sbatch --job-name=evo_py_co_eff_224_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,COLOR_STRING=--color,GPUS=2,SEED=$SEED $TRAIN_JOB
sbatch --job-name=evo_py_co_eff_300_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=128,RESIZE_AFTER_CROP=300,DATASET=gz_evo,COLOR_STRING=--color,GPUS=2,SEED=$SEED $TRAIN_JOB
sbatch --job-name=evo_py_co_eff_224_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,COLOR_STRING=--color,GPUS=2,SEED=$SEED $TRAIN_JOB
# sbatch --job-name=evo_py_co_eff_300_$SEED --export=ARCHITECTURE=efficientnet_b0,BATCH_SIZE=128,RESIZE_AFTER_CROP=300,DATASET=gz_evo,COLOR_STRING=--color,GPUS=2,SEED=$SEED $TRAIN_JOB

# and resnet18
# sbatch --job-name=evo_py_gr_res18_224_$SEED --export=ARCHITECTURE=resnet18,BATCH_SIZE=256,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
# sbatch --job-name=evo_py_gr_res18_300_$SEED --export=ARCHITECTURE=resnet18,BATCH_SIZE=256,RESIZE_AFTER_CROP=300,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
Expand All @@ -27,8 +28,8 @@ sbatch --job-name=evo_py_co_eff_300_$SEED --export=ARCHITECTURE=efficientnet_b0,
# and with max-vit tiny because hey transformers are cool

# smaller batch size due to memory
# sbatch --job-name=evo_py_gr_vittiny_224_$SEED --export=ARCHITECTURE=maxvit_tiny_224,BATCH_SIZE=128,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
# sbatch --job-name=evo_py_co_vittiny_224_$SEED --export=ARCHITECTURE=maxvit_tiny_224,BATCH_SIZE=128,RESIZE_AFTER_CROP=224,DATASET=gz_evo,COLOR_STRING=--color,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
sbatch --job-name=evo_py_gr_vittiny_224_$SEED --export=ARCHITECTURE=maxvit_tiny_224,BATCH_SIZE=128,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
sbatch --job-name=evo_py_co_vittiny_224_$SEED --export=ARCHITECTURE=maxvit_tiny_224,BATCH_SIZE=128,RESIZE_AFTER_CROP=224,DATASET=gz_evo,COLOR_STRING=--color,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB

# and max-vit small (works badly)
# sbatch --job-name=evo_py_gr_vitsmall_224_$SEED --export=ARCHITECTURE=maxvit_small_224,BATCH_SIZE=64,RESIZE_AFTER_CROP=224,DATASET=gz_evo,MIXED_PRECISION_STRING=--mixed-precision,GPUS=2,SEED=$SEED $TRAIN_JOB
Expand Down
1 change: 0 additions & 1 deletion docs/requirements.txt
Expand Up @@ -9,7 +9,6 @@ torch == 1.10.1
torchvision == 0.11.2
torchaudio == 0.10.1
pytorch-lightning==1.6.5 # 1.7 requires protobuf version incompatible with tensorflow/tensorboard. Otherwise works.
simplejpeg
albumentations
pyro-ppl == 1.8.0
pytorch-galaxy-datasets == 0.0.1
Expand Down
17 changes: 7 additions & 10 deletions setup.py
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="zoobot",
version="1.0.3",
version="1.0.4",
author="Mike Walmsley",
author_email="walmsleymk1@gmail.com",
description="Galaxy morphology classifiers",
Expand All @@ -29,11 +29,11 @@
'torchvision == 0.13.1+cpu',
'torchaudio == 0.12.1',
'pytorch-lightning >= 2.0.0',
'simplejpeg',
# 'simplejpeg',
'albumentations',
'pyro-ppl == 1.8.0',
'torchmetrics == 0.11.0',
'timm'
'timm == 0.6.12'
],
'pytorch_m1': [
# as above but without the +cpu (and the extra-index-url in readme has no effect)
Expand All @@ -42,11 +42,10 @@
'torchvision == 0.13.1',
'torchaudio == 0.12.1',
'pytorch-lightning >= 2.0.0',
'simplejpeg',
'albumentations',
'pyro-ppl == 1.8.0',
'torchmetrics == 0.11.0',
'timm'
'timm == 0.6.12'
],
# as above but without pytorch itself
# for GPU, you will also need e.g. cudatoolkit=11.3, 11.6
Expand All @@ -56,19 +55,17 @@
'torchvision == 0.13.1+cu113',
'torchaudio == 0.12.1',
'pytorch-lightning >= 2.0.0',
'simplejpeg',
'albumentations',
'pyro-ppl == 1.8.0',
'torchmetrics == 0.11.0',
'timm'
'timm == 0.6.12'
],
'pytorch_colab': [
'pytorch-lightning >= 2.0.0',
'simplejpeg',
'albumentations',
'pyro-ppl>=1.8.0',
'torchmetrics==0.11.0',
'timm'
'timm == 0.6.12'
],
'tensorflow': [
'tensorflow == 2.10.0', # 2.11.0 turns on XLA somewhere which then fails on multi-GPU...TODO
Expand Down Expand Up @@ -105,6 +102,6 @@
# for saving metrics to weights&biases (cloud service, free within limits)
'wandb',
'setuptools==59.5.0', # wandb logger incompatibility
'galaxy-datasets==0.0.12' # for dataset loading in both TF and Torch (renamed from pytorch-galaxy-datasets)
'galaxy-datasets==0.0.14' # for dataset loading in both TF and Torch (renamed from pytorch-galaxy-datasets)
]
)
@@ -0,0 +1,94 @@
"""Example: finetune a pretrained Zoobot encoder for three-class classification.

The script downloads the demo rings catalogs, overwrites the training labels
with random integers in {0, 1, 2} (purely so there is a multi-class target to
train on), finetunes the classification head for one epoch on CPU, reloads the
best checkpoint and writes predictions for the test catalog to a CSV file.
"""
import logging
import os

import numpy as np

from zoobot.pytorch.training import finetune
from zoobot.pytorch.predictions import predict_on_catalog
from galaxy_datasets import demo_rings
from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule


if __name__ == '__main__':

    logging.basicConfig(level=logging.INFO)

    zoobot_dir = '/Users/user/repos/zoobot'  # TODO set to directory where you cloned Zoobot

    # load in catalogs of images and labels to finetune on
    # each catalog should be a dataframe with columns of "id_str", "file_loc", and any labels
    # here I'm using galaxy-datasets to download some premade data - check it out for examples
    data_dir = '/Users/user/repos/galaxy-datasets/roots/demo_rings'  # TODO set to any directory. rings dataset will be downloaded here
    train_catalog, _ = demo_rings(root=data_dir, download=True, train=True)
    test_catalog, _ = demo_rings(root=data_dir, download=True, train=False)

    # wondering about "label_cols"?
    # This is a list of catalog columns which should be used as labels.
    # To support more complicated labels, Zoobot expects a list of columns.
    # A list with one element works fine.
    # TODO should use Galaxy MNIST as my example here
    label_cols = ['ring']

    # For this multi-class demo the label column holds one integer class per galaxy: 0, 1 or 2.
    # The demo catalog has no real three-class label, so the training labels are
    # replaced with random integers. The model cannot learn anything meaningful
    # from them; this only exercises the multi-class code path.
    # NOTE: only train_catalog is relabelled; test_catalog keeps its downloaded 'ring' column.
    num_classes = 3
    train_catalog['ring'] = np.random.randint(low=0, high=num_classes, size=len(train_catalog))

    # load a pretrained checkpoint saved here
    checkpoint_loc = os.path.join(zoobot_dir, 'data/pretrained_models/pytorch/effnetb0_greyscale_224px.ckpt')

    # save the finetuning results here
    save_dir = os.path.join(zoobot_dir, 'results/pytorch/finetune/finetune_multiclass_classification')

    datamodule = GalaxyDataModule(
        label_cols=label_cols,
        catalog=train_catalog,  # very small, as a demo
        batch_size=32
    )
    # to inspect a batch, uncomment:
    # datamodule.setup()
    # for images, labels in datamodule.train_dataloader():
    #     print(images.shape)
    #     print(labels.shape)
    #     exit()

    model = finetune.FinetuneableZoobotClassifier(
        checkpoint_loc=checkpoint_loc,
        num_classes=num_classes,
        n_layers=0  # only updating the head weights. Set e.g. 1, 2 to finetune deeper.
    )
    # under the hood, this does:
    # encoder = finetune.load_pretrained_encoder(checkpoint_loc)
    # model = finetune.FinetuneableZoobotClassifier(encoder=encoder, ...)

    # finetune on the three-class labels (one epoch on CPU, as a quick demo)
    trainer = finetune.get_trainer(save_dir, accelerator='cpu', max_epochs=1)
    trainer.fit(model, datamodule)
    # can now use this model or saved checkpoint to make predictions on new data. Well done!

    # pretending we want to load from scratch:
    best_checkpoint = trainer.checkpoint_callback.best_model_path
    finetuned_model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(best_checkpoint)

    predict_on_catalog.predict(
        test_catalog,
        finetuned_model,
        n_samples=1,
        label_cols=label_cols,
        save_loc=os.path.join(save_dir, 'finetuned_predictions.csv')
        # trainer_kwargs={'accelerator': 'gpu'}
    )
    # Under the hood, this is essentially doing:
    #
    #   import pytorch_lightning as pl
    #   predict_trainer = pl.Trainer(devices=1, max_epochs=-1)
    #   predict_datamodule = GalaxyDataModule(
    #       label_cols=None,  # important, else you will get "conv2d() received an invalid combination of arguments"
    #       predict_catalog=test_catalog,
    #       batch_size=32
    #   )
    #   preds = predict_trainer.predict(finetuned_model, predict_datamodule)
    #   print(preds)
13 changes: 10 additions & 3 deletions zoobot/pytorch/training/finetune.py
Expand Up @@ -269,9 +269,16 @@ def __init__(
self.loss = partial(cross_entropy_loss,
weight=class_weights,
label_smoothing=self.label_smoothing)
self.train_acc = tm.Accuracy(task='binary', average="micro")
self.val_acc = tm.Accuracy(task='binary', average="micro")
self.test_acc = tm.Accuracy(task='binary', average="micro")
logging.info(f'num_classes: {num_classes}')
if num_classes == 2:
logging.info('Using binary classification')
task = 'binary'
else:
logging.info('Using multi-class classification')
task = 'multiclass'
self.train_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes)
self.val_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes)
self.test_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes)

def step_to_dict(self, y, y_pred, loss):
y_class_preds = torch.argmax(y_pred, axis=1)
Expand Down
7 changes: 4 additions & 3 deletions zoobot/shared/benchmark_datasets.py
Expand Up @@ -23,21 +23,22 @@ def get_gz_decals_dr5_benchmark_dataset(data_dir, random_state, download):
return schema, (train_catalog, val_catalog, test_catalog)


def get_gz_evo_benchmark_dataset(data_dir, random_state, download=False, debug=False, datasets=['gz_desi', 'gz_hubble', 'gz_candels', 'gz2', 'gz_rings', 'gz_cosmic_dawn']):
    """Build the GZ Evo benchmark schema and train/val/test catalogs.

    The train and validation catalogs returned by the (not yet public)
    ``foundation.datasets.mixed`` loader are concatenated and re-split here
    with ``train_test_split``: 2/3 train, then the remaining 1/3 is split
    1/3 validation and 2/3 test. The loader's own test catalog is ignored.

    Args:
        data_dir: passed through to the ``mixed`` loader.
        random_state: seed used for both ``train_test_split`` calls.
        download: passed through to the ``mixed`` loader.
        debug: passed through to the ``mixed`` loader.
        datasets: NOTE currently ignored - it is overwritten below with a
            hard-coded list, so the default (including 'gz_cosmic_dawn')
            has no effect.

    Returns:
        ``(schema, (train_catalog, val_catalog, test_catalog))``.

    Raises:
        ValueError: if the number of label columns reported by the loader
            differs from ``len(schema.label_cols)``.
    """
    from foundation.datasets import mixed  # not yet public. import will fail if you're not me.

    # temporarily, everything *but* hubble, for Ben
    # datasets = ['gz_desi', 'gz_candels', 'gz2', 'gz_rings']
    # NOTE: this overrides whatever the caller passed as `datasets`
    datasets = ['gz_desi', 'gz_candels', 'gz_hubble', 'gz2', 'gz_rings']

    # TODO set use_cache=False temporarily when the cached catalogs need remaking (currently cached)
    direct_label_cols, (temp_train_catalog, temp_val_catalog, _) = mixed.everything_all_dirichlet_with_rings(data_dir, debug, download=download, use_cache=True, datasets=datasets)
    canonical_train_catalog = pd.concat([temp_train_catalog, temp_val_catalog], axis=0)

    # here I'm going to ignore the test catalog
    train_catalog, hidden_catalog = train_test_split(canonical_train_catalog, test_size=1./3., random_state=random_state)
    val_catalog, test_catalog = train_test_split(hidden_catalog, test_size=2./3., random_state=random_state)

    schema = mixed.mixed_schema()
    # Explicit raise rather than `assert`: asserts are stripped under -O, and the
    # reported lengths now match the quantities actually being compared.
    n_direct = len(direct_label_cols)
    n_schema = len(schema.label_cols)
    if n_direct != n_schema:
        raise ValueError(
            'Loader returned {} label columns but schema has {}'.format(n_direct, n_schema)
        )
    logging.info('Schema: {}'.format(schema))
    return schema, (train_catalog, val_catalog, test_catalog)
7 changes: 6 additions & 1 deletion zoobot/shared/load_predictions.py
Expand Up @@ -93,7 +93,6 @@ def prediction_hdf5_to_summary_parquet(hdf5_loc: str, save_loc: str, schema: sch
"""
assert isinstance(hdf5_loc, str)

label_cols = schema.label_cols

# concentrations will be of (galaxy, question, model, forward_pass) after going through c_group
# may be only one model but will still have that dimension (e.g. 1000, 39, 1, 5)
Expand All @@ -105,6 +104,12 @@ def prediction_hdf5_to_summary_parquet(hdf5_loc: str, save_loc: str, schema: sch
galaxy_id_df = galaxy_id_df[:100000]
save_loc = save_loc.replace('.parquet', '_debug.parquet')

label_cols = schema.label_cols
# TODO optionally ignore all but a subset of columns, for models without finetuning
# hdf5_label_cols = label_cols
# valid_cols = [col for col in hdf5_label_cols if col in label_col_subset]
# concentrations = concentrations[:, valid_cols]

# applies to all questions at once
# hopefully also supports 3D concentrations (galaxy/question/model/pass)
logging.info('Concentrations: {}'.format(concentrations.shape))
Expand Down
3 changes: 3 additions & 0 deletions zoobot/shared/schemas.py
Expand Up @@ -268,6 +268,9 @@ def answers(self):
gz_candels_ortho_schema = Schema(label_metadata.candels_ortho_pairs, label_metadata.candels_ortho_dependencies)
gz_hubble_ortho_schema = Schema(label_metadata.hubble_ortho_pairs, label_metadata.hubble_ortho_dependencies)
cosmic_dawn_ortho_schema = Schema(label_metadata.cosmic_dawn_ortho_pairs , label_metadata.cosmic_dawn_ortho_dependencies)

# schemas without orthogonal question suffix (-cd, -dr8, etc)
cosmic_dawn_schema = Schema(label_metadata.cosmic_dawn_pairs , label_metadata.cosmic_dawn_dependencies)
gz_rings_schema = Schema(label_metadata.rings_pairs, label_metadata.rings_dependencies)
desi_schema = Schema(label_metadata.desi_pairs, label_metadata.desi_dependencies) # for DESI data release prediction users, not for ML training - no -dr5, -dr8, etc
# note that as this is a call to Schema (and Question and Answer), any logging within those will
Expand Down

0 comments on commit ad78db8

Please sign in to comment.