Skip to content

Commit

Permalink
revert datamodule
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Jan 15, 2024
1 parent d3b4398 commit 372bdb1
Showing 1 changed file with 13 additions and 22 deletions.
35 changes: 13 additions & 22 deletions zoobot/pytorch/training/train_with_pytorch_lightning.py
Expand Up @@ -237,21 +237,11 @@ def train_default_zoobot_from_scratch(
# this branch will use WebDataModule to load premade webdatasets

# temporary: use SSL-like transform
from foundation.models import transforms
# from omegaconf import DictConfig
# cfg = DictConfig({
# 'aug': {
# 'global_transform_0': {
# 'interpolation': 'bilinear',
# 'random_affine': {} # etc
# }

# }
# })
train_transform_cfg = transforms.default_view_config()
inference_transform_cfg = transforms.minimal_view_config()
train_transform_cfg.output_size = resize_after_crop
inference_transform_cfg.output_size = resize_after_crop
# from foundation.models import transforms
# train_transform_cfg = transforms.default_view_config()
# inference_transform_cfg = transforms.minimal_view_config()
# train_transform_cfg.output_size = resize_after_crop
# inference_transform_cfg.output_size = resize_after_crop

datamodule = webdatamodule.WebDataModule(
train_urls=train_urls,
Expand All @@ -264,12 +254,13 @@ def train_default_zoobot_from_scratch(
prefetch_factor=prefetch_factor,
cache_dir=cache_dir,
# augmentation args
train_transform=transforms.GalaxyViewTransform(train_transform_cfg),
inference_transform=transforms.GalaxyViewTransform(inference_transform_cfg),
# color=color,
# crop_scale_bounds=crop_scale_bounds,
# crop_ratio_bounds=crop_ratio_bounds,
# resize_after_crop=resize_after_crop
color=color,
crop_scale_bounds=crop_scale_bounds,
crop_ratio_bounds=crop_ratio_bounds,
resize_after_crop=resize_after_crop,
# temporary: use SSL-like transform
# train_transform=transforms.GalaxyViewTransform(train_transform_cfg),
# inference_transform=transforms.GalaxyViewTransform(inference_transform_cfg),
)

datamodule.setup(stage='fit')
Expand Down Expand Up @@ -352,7 +343,7 @@ def train_default_zoobot_from_scratch(

# can test as per the below, but note that datamodule must have a test dataset attribute as per pytorch lightning docs.
# also be careful not to test regularly, as this breaks train/val/test conceptual separation and may cause hparam overfitting
if test_catalog is not None:
if datamodule.test_dataloader is not None:
logging.info(f'Testing on {checkpoint_callback.best_model_path} with single GPU. Be careful not to overfit your choices to the test data...')
test_trainer.validate(
model=lightning_model,
Expand Down

0 comments on commit 372bdb1

Please sign in to comment.