Skip to content

Commit

Permalink
debug test metrics not appearing
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Jan 15, 2024
1 parent c0b2e41 commit d7386be
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions (net +4 lines).
2 changes: 1 addition & 1 deletion zoobot/pytorch/datasets/webdatamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def val_dataloader(self):
return self.make_loader(self.val_urls, mode="val")

def test_dataloader(self):
    """Return the DataLoader for the test split.

    Builds the loader from ``self.test_urls`` (not ``self.val_urls``) so that
    test metrics are computed on held-out test shards. The diff render had left
    the stale pre-commit ``val_urls`` return in place, which would have made the
    test split silently reuse validation data.
    """
    # mode="test" selects test-time behavior in make_loader (e.g. no shuffling/augmentation)
    return self.make_loader(self.test_urls, mode="test")

def predict_dataloader(self):
    """Return the DataLoader used for inference.

    Wraps ``self.predict_urls`` via :meth:`make_loader` with ``mode="predict"``.
    """
    loader = self.make_loader(self.predict_urls, mode="predict")
    return loader
Expand Down
14 changes: 9 additions & 5 deletions zoobot/pytorch/training/train_with_pytorch_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,16 +345,20 @@ def train_default_zoobot_from_scratch(
# also be careful not to test regularly, as this breaks train/val/test conceptual separation and may cause hparam overfitting
if datamodule.test_dataloader is not None:
logging.info(f'Testing on {checkpoint_callback.best_model_path} with single GPU. Be careful not to overfit your choices to the test data...')
test_trainer.validate(
model=lightning_model,
datamodule=datamodule,
ckpt_path=checkpoint_callback.best_model_path # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
)
# test_trainer.validate(
# model=lightning_model,
# datamodule=datamodule,
# ckpt_path=checkpoint_callback.best_model_path # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
# )
datamodule.setup(stage='test')
# temp
print(datamodule.test_urls)
test_trainer.test(
model=lightning_model,
datamodule=datamodule,
ckpt_path=checkpoint_callback.best_model_path # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
)
# TODO may need to remake on 1 gpu only

# explicitly update the model weights to the best checkpoint before returning
# (assumes only one checkpoint callback, very likely in practice)
Expand Down

0 comments on commit d7386be

Please sign in to comment.