Skip to content

Commit

Permalink
maybe it's the datamodule
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Jan 15, 2024
1 parent 600ed3e commit e0e56cd
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
2 changes: 2 additions & 0 deletions zoobot/pytorch/estimators/define_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def on_validation_epoch_end(self) -> None:
self.log_all_metrics(subset='validation')

def on_test_epoch_end(self) -> None:
logging.info('start test epoch end')
self.log_all_metrics(subset='test')
logging.info('end test epoch end')

def calculate_loss_and_update_loss_metrics(self, predictions, labels, step_name):
raise NotImplementedError('Must be subclassed')
Expand Down
18 changes: 16 additions & 2 deletions zoobot/pytorch/training/train_with_pytorch_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,10 +351,24 @@ def train_default_zoobot_from_scratch(
# datamodule=datamodule,
# ckpt_path=checkpoint_callback.best_model_path # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
# )
datamodule.setup(stage='test')
test_datamodule = webdatamodule.WebDataModule(
train_urls=None,
val_urls=None,
test_urls=test_urls,
label_cols=schema.label_cols,
batch_size=batch_size,
num_workers=num_workers,
prefetch_factor=prefetch_factor,
cache_dir=cache_dir,
color=color,
crop_scale_bounds=crop_scale_bounds,
crop_ratio_bounds=crop_ratio_bounds,
resize_after_crop=resize_after_crop
)
test_datamodule.setup(stage='test')
test_trainer.test(
model=lightning_model,
datamodule=datamodule,
datamodule=test_datamodule,
ckpt_path=checkpoint_callback.best_model_path # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
)
# TODO may need to remake on 1 gpu only
Expand Down

0 comments on commit e0e56cd

Please sign in to comment.