
Commit

check if it is logging or wds
mwalmsley committed Jan 15, 2024
1 parent 329a88c commit de7562b
Showing 2 changed files with 41 additions and 41 deletions.
4 changes: 2 additions & 2 deletions zoobot/pytorch/estimators/define_model.py
@@ -127,7 +127,7 @@ def on_validation_epoch_end(self) -> None:
 
     def on_test_epoch_end(self) -> None:
         logging.info('start test epoch end')
-        self.log_all_metrics(subset='test')
+        # self.log_all_metrics(subset='test')
         logging.info('end test epoch end')
 
     def calculate_loss_and_update_loss_metrics(self, predictions, labels, step_name):
@@ -140,7 +140,7 @@ def log_all_metrics(self, subset=None):
         if subset is not None:
             for name, metric in self.loss_metrics.items():
                 if subset in name:
-                    print('logging', name)
+                    logging.info(name)
                     self.log(name, metric, on_epoch=True, on_step=False, prog_bar=True, logger=True)
         else: # just log everything
             self.log_dict(self.loss_metrics, on_epoch=True, on_step=False, prog_bar=True, logger=True)
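Aside (not part of the commit): a minimal sketch of the epoch-end metric logging pattern this change toggles, assuming a torchmetrics-backed loss_metrics dict along the lines of the file above; the class name TinyModule, the single metric key and the placeholder loss value are illustrative, not taken from zoobot.

import logging

import pytorch_lightning as pl
import torch
import torchmetrics


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # one running-mean metric per subset, mirroring a loss_metrics-style dict (keys assumed)
        self.loss_metrics = torch.nn.ModuleDict({'test/supervised_loss': torchmetrics.MeanMetric()})

    def test_step(self, batch, batch_idx):
        loss = torch.tensor(0.0)  # placeholder for the real loss computation
        self.loss_metrics['test/supervised_loss'].update(loss)

    def on_test_epoch_end(self) -> None:
        logging.info('start test epoch end')
        self.log_all_metrics(subset='test')
        logging.info('end test epoch end')

    def log_all_metrics(self, subset=None):
        for name, metric in self.loss_metrics.items():
            if subset is None or subset in name:
                logging.info(name)
                # passing the Metric object lets Lightning compute, sync and reset it at epoch end
                self.log(name, metric, on_epoch=True, on_step=False, prog_bar=True, logger=True)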
78 changes: 39 additions & 39 deletions zoobot/pytorch/training/train_with_pytorch_lightning.py
@@ -335,45 +335,45 @@ def train_default_zoobot_from_scratch(
 
     # can test as per the below, but note that datamodule must have a test dataset attribute as per pytorch lightning docs.
     # also be careful not to test regularly, as this breaks train/val/test conceptual separation and may cause hparam overfitting
-    if datamodule.test_dataloader is not None:
-        logging.info(f'Testing on {checkpoint_callback.best_model_path} with single GPU. Be careful not to overfit your choices to the test data...')
-        # test_trainer.validate(
-        # model=lightning_model,
-        # datamodule=datamodule,
-        # ckpt_path=checkpoint_callback.best_model_path # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
-        # )
-        test_trainer = pl.Trainer(
-            accelerator=accelerator,
-            devices=1,
-            precision=precision,
-            logger=wandb_logger,
-            default_root_dir=save_dir
-        )
-        if test_trainer.is_global_zero:
-            test_datamodule = webdatamodule.WebDataModule(
-                train_urls=None,
-                val_urls=None,
-                test_urls=test_urls,
-                label_cols=schema.label_cols,
-                batch_size=batch_size,
-                num_workers=1, # 20 / 5 == 4, /2=2
-                prefetch_factor=prefetch_factor,
-                cache_dir=None,
-                color=color,
-                crop_scale_bounds=crop_scale_bounds,
-                crop_ratio_bounds=crop_ratio_bounds,
-                resize_after_crop=resize_after_crop
-            )
-            test_datamodule.setup(stage='test')
-            test_trainer.test(
-                model=lightning_model,
-                datamodule=test_datamodule,
-                ckpt_path=checkpoint_callback.best_model_path # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
-            )
-        else:
-            logging.info('Not global zero, skipping test metrics')
-    else:
-        logging.info('No test dataloader found, skipping test metrics')
+    # if datamodule.test_dataloader is not None:
+    # logging.info(f'Testing on {checkpoint_callback.best_model_path} with single GPU. Be careful not to overfit your choices to the test data...')
+    # # test_trainer.validate(
+    # # model=lightning_model,
+    # # datamodule=datamodule,
+    # # ckpt_path=checkpoint_callback.best_model_path # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
+    # # )
+    # test_trainer = pl.Trainer(
+    # accelerator=accelerator,
+    # devices=1,
+    # precision=precision,
+    # logger=wandb_logger,
+    # default_root_dir=save_dir
+    # )
+    # if test_trainer.is_global_zero:
+    # test_datamodule = webdatamodule.WebDataModule(
+    # train_urls=None,
+    # val_urls=None,
+    # test_urls=test_urls,
+    # label_cols=schema.label_cols,
+    # batch_size=batch_size,
+    # num_workers=1, # 20 / 5 == 4, /2=2
+    # prefetch_factor=prefetch_factor,
+    # cache_dir=None,
+    # color=color,
+    # crop_scale_bounds=crop_scale_bounds,
+    # crop_ratio_bounds=crop_ratio_bounds,
+    # resize_after_crop=resize_after_crop
+    # )
+    datamodule.setup(stage='test')
+    trainer.test(
+        model=lightning_model,
+        datamodule=datamodule,
+        ckpt_path=checkpoint_callback.best_model_path # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
+    )
+    # else:
+    # logging.info('Not global zero, skipping test metrics')
+    # else:
+    # logging.info('No test dataloader found, skipping test metrics')
 
 
     # explicitly update the model weights to the best checkpoint before returning
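Aside (not part of the commit): the commit message suggests the dedicated single-GPU test setup is being swapped for the existing trainer and datamodule to check whether the problem lies in the metric logging or in the webdataset ('wds') loading. Below is a minimal sketch of that single-process test pattern, assuming a trained LightningModule, a LightningDataModule with a test split, and a best-checkpoint path from a ModelCheckpoint callback; run_test and its arguments are illustrative placeholders, not zoobot functions.

import pytorch_lightning as pl


def run_test(model: pl.LightningModule, datamodule: pl.LightningDataModule, best_ckpt_path: str, save_dir: str):
    # a dedicated single-device Trainer avoids duplicating the test set across
    # distributed processes, which would inflate the reported metrics
    test_trainer = pl.Trainer(accelerator='auto', devices=1, default_root_dir=save_dir)
    datamodule.setup(stage='test')  # build only the test dataset/dataloader
    # ckpt_path loads the saved best weights before running the test loop
    return test_trainer.test(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path)

Passing ckpt_path to trainer.test ties the reported metrics to the checkpoint selected during training, rather than to whatever weights happen to be in memory when testing starts.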

