num_workers=1, move test_trainer
mwalmsley committed Jan 15, 2024
1 parent e0e56cd · commit 329a88c
Showing 1 changed file with 40 additions and 41 deletions.
zoobot/pytorch/training/train_with_pytorch_lightning.py (40 additions, 41 deletions)

@@ -331,51 +331,50 @@ def train_default_zoobot_from_scratch(

     trainer.fit(lightning_model, datamodule)  # uses batch size of datamodule
 
-    test_trainer = pl.Trainer(
-        accelerator=accelerator,
-        devices=1,
-        precision=precision,
-        logger=wandb_logger,
-        default_root_dir=save_dir
-    )
-
-    best_model_path = trainer.checkpoint_callback.best_model_path
-
-    if test_trainer.is_global_zero:
-        # can test as per the below, but note that datamodule must have a test dataset attribute as per pytorch lightning docs.
-        # also be careful not to test regularly, as this breaks train/val/test conceptual separation and may cause hparam overfitting
-        if datamodule.test_dataloader is not None:
-            logging.info(f'Testing on {checkpoint_callback.best_model_path} with single GPU. Be careful not to overfit your choices to the test data...')
-            # test_trainer.validate(
-            #     model=lightning_model,
-            #     datamodule=datamodule,
-            #     ckpt_path=checkpoint_callback.best_model_path  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
-            # )
-            test_datamodule = webdatamodule.WebDataModule(
-                train_urls=None,
-                val_urls=None,
-                test_urls=test_urls,
-                label_cols=schema.label_cols,
-                batch_size=batch_size,
-                num_workers=num_workers,
-                prefetch_factor=prefetch_factor,
-                cache_dir=cache_dir,
-                color=color,
-                crop_scale_bounds=crop_scale_bounds,
-                crop_ratio_bounds=crop_ratio_bounds,
-                resize_after_crop=resize_after_crop
-            )
-            test_datamodule.setup(stage='test')
-            test_trainer.test(
-                model=lightning_model,
-                datamodule=test_datamodule,
-                ckpt_path=checkpoint_callback.best_model_path  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
-            )
+    # TODO may need to remake on 1 gpu only
+    # can test as per the below, but note that datamodule must have a test dataset attribute as per pytorch lightning docs.
+    # also be careful not to test regularly, as this breaks train/val/test conceptual separation and may cause hparam overfitting
+    if datamodule.test_dataloader is not None:
+        logging.info(f'Testing on {checkpoint_callback.best_model_path} with single GPU. Be careful not to overfit your choices to the test data...')
+        # test_trainer.validate(
+        #     model=lightning_model,
+        #     datamodule=datamodule,
+        #     ckpt_path=checkpoint_callback.best_model_path  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
+        # )
+        test_trainer = pl.Trainer(
+            accelerator=accelerator,
+            devices=1,
+            precision=precision,
+            logger=wandb_logger,
+            default_root_dir=save_dir
+        )
+        if test_trainer.is_global_zero:
+            test_datamodule = webdatamodule.WebDataModule(
+                train_urls=None,
+                val_urls=None,
+                test_urls=test_urls,
+                label_cols=schema.label_cols,
+                batch_size=batch_size,
+                num_workers=1,  # 20 / 5 == 4, /2=2
+                prefetch_factor=prefetch_factor,
+                cache_dir=None,
+                color=color,
+                crop_scale_bounds=crop_scale_bounds,
+                crop_ratio_bounds=crop_ratio_bounds,
+                resize_after_crop=resize_after_crop
+            )
+            test_datamodule.setup(stage='test')
+            test_trainer.test(
+                model=lightning_model,
+                datamodule=test_datamodule,
+                ckpt_path=checkpoint_callback.best_model_path  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
+            )
         else:
-            logging.info('No test dataloader found, skipping test metrics')
+            logging.info('Not global zero, skipping test metrics')
     else:
-        logging.info('Not global zero, skipping test metrics')
+        logging.info('No test dataloader found, skipping test metrics')
 
 
     # explicitly update the model weights to the best checkpoint before returning
     # (assumes only one checkpoint callback, very likely in practice)
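For context, the pattern the added code follows is: fit on however many devices, then build a fresh single-device Trainer for testing, so test metrics are computed and logged once rather than per DDP process. A minimal standalone sketch of that pattern in plain PyTorch Lightning is below; it is illustrative only, not zoobot code, and the fit_then_test helper and its arguments are hypothetical.

    # illustrative sketch (hypothetical helper, not part of zoobot)
    import pytorch_lightning as pl

    def fit_then_test(lightning_model, datamodule, test_datamodule, accelerator='gpu'):
        # fit on any number of devices; checkpointing tracks the best validation epoch
        trainer = pl.Trainer(accelerator=accelerator, devices='auto')
        trainer.fit(lightning_model, datamodule)

        # a fresh devices=1 Trainer avoids each DDP process running its own test pass
        test_trainer = pl.Trainer(accelerator=accelerator, devices=1)
        if test_trainer.is_global_zero:  # extra guard, mirroring the commit
            test_trainer.test(
                model=lightning_model,
                datamodule=test_datamodule,  # trainer.test calls setup(stage='test') itself
                ckpt_path=trainer.checkpoint_callback.best_model_path,  # best, not last, weights
            )

Note the commit also rebuilds the test datamodule with num_workers=1 and cache_dir=None; a plausible reading is that the one-off test pass does not need multi-worker loading or a shared cache, but that rationale is inferred, not stated in the diff.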
