From bb6c40311366cba82eb7a1eeed593638411ecf29 Mon Sep 17 00:00:00 2001 From: Mike Walmsley Date: Tue, 19 Mar 2024 19:33:15 -0400 Subject: [PATCH] try on colab --- tests/test_from_hub.py | 7 ++++++- zoobot/pytorch/training/finetune.py | 24 ++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/tests/test_from_hub.py b/tests/test_from_hub.py index 9bc1c115..159f22c5 100644 --- a/tests/test_from_hub.py +++ b/tests/test_from_hub.py @@ -16,7 +16,7 @@ def test_get_finetuned(): from huggingface_hub import hf_hub_download REPO_ID = "mwalmsley/zoobot-finetuned-is_tidal" - FILENAME = "4.ckpt" + FILENAME = "FinetuneableZoobotClassifier.ckpt" downloaded_loc = hf_hub_download( repo_id=REPO_ID, @@ -26,7 +26,12 @@ def test_get_finetuned(): model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(downloaded_loc, map_location='cpu') # hub_name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2) +def test_get_finetuned_class_method(): + from zoobot.pytorch.training import finetune + + model = finetune.FinetuneableZoobotClassifier.load_from_name('mwalmsley/zoobot-finetuned-is_tidal', map_location='cpu') + assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2) # def test_get_finetuned_from_local(): # # checkpoint_loc = '/home/walml/repos/zoobot/tests/convnext_nano_finetuned_linear_is-lsb.ckpt' diff --git a/zoobot/pytorch/training/finetune.py b/zoobot/pytorch/training/finetune.py index 0ef638fb..dd7a66a3 100644 --- a/zoobot/pytorch/training/finetune.py +++ b/zoobot/pytorch/training/finetune.py @@ -377,6 +377,13 @@ def on_test_batch_end(self, outputs: dict, batch, batch_idx: int, dataloader_idx def upload_images_to_wandb(self, outputs, batch, batch_idx): raise NotImplementedError('Must be subclassed') + + @classmethod + def load_from_name(cls, name: str, **kwargs): + downloaded_loc = download_from_name(cls.__name__, name, **kwargs) + return cls.load_from_checkpoint(downloaded_loc, **kwargs) # trained on GPU, may need map_location='cpu' if you get a device error + + @@ -396,6 +403,8 @@ class FinetuneableZoobotClassifier(FinetuneableZoobotAbstract): """ + + def __init__( self, num_classes: int, @@ -762,3 +771,18 @@ def get_trainer( ) return trainer + + +def download_from_name(class_name: str, hub_name: str, **kwargs): + from huggingface_hub import hf_hub_download + + if hub_name.startswith('hf_hub:'): + logging.info('Passed name with hf_hub: prefix, dropping prefix') + repo_id = hub_name.split('hf_hub:')[1] + else: + repo_id = hub_name + downloaded_loc = hf_hub_download( + repo_id=repo_id, + filename=f"{class_name}.ckpt" + ) + return downloaded_loc \ No newline at end of file