Commit
seems to block on multi-g
mwalmsley committed Jan 15, 2024
1 parent 9a00271 commit 600ed3e
1 changed file: zoobot/pytorch/training/train_with_pytorch_lightning.py (21 additions, 18 deletions)
@@ -341,24 +341,27 @@ def train_default_zoobot_from_scratch(

 
     best_model_path = trainer.checkpoint_callback.best_model_path
 
-    # can test as per the below, but note that datamodule must have a test dataset attribute as per pytorch lightning docs.
-    # also be careful not to test regularly, as this breaks train/val/test conceptual separation and may cause hparam overfitting
-    if datamodule.test_dataloader is not None:
-        logging.info(f'Testing on {checkpoint_callback.best_model_path} with single GPU. Be careful not to overfit your choices to the test data...')
-        # test_trainer.validate(
-        #     model=lightning_model,
-        #     datamodule=datamodule,
-        #     ckpt_path=checkpoint_callback.best_model_path  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
-        # )
-        datamodule.setup(stage='test')
-        # temp
-        print(datamodule.test_urls)
-        test_trainer.test(
-            model=lightning_model,
-            datamodule=datamodule,
-            ckpt_path=checkpoint_callback.best_model_path  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
-        )
-    # TODO may need to remake on 1 gpu only
+    if test_trainer.is_global_zero:
+        # can test as per the below, but note that datamodule must have a test dataset attribute as per pytorch lightning docs.
+        # also be careful not to test regularly, as this breaks train/val/test conceptual separation and may cause hparam overfitting
+        if datamodule.test_dataloader is not None:
+            logging.info(f'Testing on {checkpoint_callback.best_model_path} with single GPU. Be careful not to overfit your choices to the test data...')
+            # test_trainer.validate(
+            #     model=lightning_model,
+            #     datamodule=datamodule,
+            #     ckpt_path=checkpoint_callback.best_model_path  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
+            # )
+            datamodule.setup(stage='test')
+            test_trainer.test(
+                model=lightning_model,
+                datamodule=datamodule,
+                ckpt_path=checkpoint_callback.best_model_path  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
+            )
+            # TODO may need to remake on 1 gpu only
+        else:
+            logging.info('No test dataloader found, skipping test metrics')
+    else:
+        logging.info('Not global zero, skipping test metrics')
 
     # explicitly update the model weights to the best checkpoint before returning
     # (assumes only one checkpoint callback, very likely in practice)
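For readers hitting the same hang: below is a minimal sketch, not from this commit, of the rank-zero guard the new code introduces. It assumes the fit trainer and a best checkpoint path already exist, as in train_default_zoobot_from_scratch; the helper name run_test_on_rank_zero and its arguments are illustrative, and the commit itself builds its test_trainer elsewhere in this file.

# Hypothetical sketch of the pattern this commit adds: only the global-zero rank
# runs .test(), so the remaining DDP ranks don't block waiting on collective ops.
# lightning_model, datamodule and best_model_path are assumed to exist as in
# train_default_zoobot_from_scratch above.
import logging

import pytorch_lightning as pl


def run_test_on_rank_zero(fit_trainer, lightning_model, datamodule, best_model_path):
    if not fit_trainer.is_global_zero:
        logging.info('Not global zero, skipping test metrics')
        return None
    # fresh single-device Trainer: re-using the multi-GPU fit trainer for .test()
    # seems to block, since only rank zero reaches its collective calls
    test_trainer = pl.Trainer(devices=1, num_nodes=1)
    datamodule.setup(stage='test')  # explicit here, as in the commit; .test() would also trigger setup
    return test_trainer.test(
        model=lightning_model,
        datamodule=datamodule,
        ckpt_path=best_model_path  # evaluate the best checkpoint, not the final weights
    )

Whether rank zero can safely construct a second Trainer while the other ranks carry on depends on the distributed strategy in use; the TODO in the diff suggests this was still being verified.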
