Skip to content

Commit

Permalink
try on colab
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Mar 19, 2024
1 parent e0fd96d commit bb6c403
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
7 changes: 6 additions & 1 deletion tests/test_from_hub.py
Expand Up @@ -16,7 +16,7 @@ def test_get_finetuned():
from huggingface_hub import hf_hub_download

REPO_ID = "mwalmsley/zoobot-finetuned-is_tidal"
FILENAME = "4.ckpt"
FILENAME = "FinetuneableZoobotClassifier.ckpt"

downloaded_loc = hf_hub_download(
repo_id=REPO_ID,
Expand All @@ -26,7 +26,12 @@ def test_get_finetuned():
model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(downloaded_loc, map_location='cpu') # hub_name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2)

def test_get_finetuned_class_method():

from zoobot.pytorch.training import finetune

model = finetune.FinetuneableZoobotClassifier.load_from_name('mwalmsley/zoobot-finetuned-is_tidal', map_location='cpu')
assert model(torch.rand(1, 3, 224, 224)).shape == (1, 2)

# def test_get_finetuned_from_local():
# # checkpoint_loc = '/home/walml/repos/zoobot/tests/convnext_nano_finetuned_linear_is-lsb.ckpt'
Expand Down
24 changes: 24 additions & 0 deletions zoobot/pytorch/training/finetune.py
Expand Up @@ -377,6 +377,13 @@ def on_test_batch_end(self, outputs: dict, batch, batch_idx: int, dataloader_idx

def upload_images_to_wandb(self, outputs, batch, batch_idx):
raise NotImplementedError('Must be subclassed')

@classmethod
def load_from_name(cls, name: str, **kwargs):
downloaded_loc = download_from_name(cls.__name__, name, **kwargs)
return cls.load_from_checkpoint(downloaded_loc, **kwargs) # trained on GPU, may need map_location='cpu' if you get a device error





Expand All @@ -396,6 +403,8 @@ class FinetuneableZoobotClassifier(FinetuneableZoobotAbstract):
"""



def __init__(
self,
num_classes: int,
Expand Down Expand Up @@ -762,3 +771,18 @@ def get_trainer(
)

return trainer


def download_from_name(class_name: str, hub_name: str, **kwargs):
from huggingface_hub import hf_hub_download

if hub_name.startswith('hf_hub:'):
logging.info('Passed name with hf_hub: prefix, dropping prefix')
repo_id = hub_name.split('hf_hub:')[1]
else:
repo_id = hub_name
downloaded_loc = hf_hub_download(
repo_id=repo_id,
filename=f"{class_name}.ckpt"
)
return downloaded_loc

0 comments on commit bb6c403

Please sign in to comment.