Skip to content

Commit

Permalink
try with aug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Jan 16, 2024
1 parent 53974da commit c8442db
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
20 changes: 17 additions & 3 deletions zoobot/pytorch/datasets/webdatamodule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from collections import defaultdict
from typing import Callable
import logging
import torch.utils.data
Expand Down Expand Up @@ -89,6 +90,8 @@ def make_image_transform(self, mode="train"):


def do_transform(img):
    """Augment a single channel-last image and return it channel-first as float32.

    Args:
        img: image whose ``shape[2]`` is the channel dimension (numpy/PIL HWC
            convention) with fewer than 4 channels.

    Returns:
        The ``"image"`` entry produced by ``augmentation_transform`` (taken from
        the enclosing scope), with axes reordered to (2, 0, 1) and cast to
        ``np.float32``.

    Raises:
        AssertionError: if ``img.shape[2]`` is 4 or more, which suggests the
            image arrived channel-first.
    """
    # 1 or 3 channels in shape[2] dim, i.e. numpy/pil HWC convention.
    # The hint lives in the assertion message so a failing run explains itself.
    assert img.shape[2] < 4, (
        f"expected HWC image with fewer than 4 channels in shape[2], "
        f"got shape {tuple(img.shape)}; check decode mode is 'rgb' not 'torchrgb'"
    )
    augmented = augmentation_transform(image=np.array(img))["image"]
    # HWC -> CHW
    return np.transpose(augmented, axes=[2, 0, 1]).astype(np.float32)
return do_transform

Expand All @@ -105,11 +108,14 @@ def make_loader(self, urls, mode="train"):

if self.train_transform is None:
logging.info('Using default transform')
decode_mode = 'rgb' # np.array, for albumentations
transform_image = self.make_image_transform(mode=mode)
else:
logging.info('Ignoring hparams and using directly-passed transforms')
decode_mode = 'torchrgb' # tensor, for torchvision
transform_image = self.train_transform if mode == 'train' else self.inference_transform


transform_label = dict_to_label_cols_factory(self.label_cols)

dataset = wds.WebDataset(urls, cache_dir=self.cache_dir, shardshuffle=shuffle>0, nodesplitter=nodesplitter_func)
Expand All @@ -119,8 +125,7 @@ def make_loader(self, urls, mode="train"):
if shuffle > 0:
dataset = dataset.shuffle(shuffle)

# dataset = dataset.decode("rgb") # np.array, for albumentations
dataset = dataset.decode("torchrgb") # tensor, for torchvision
dataset = dataset.decode(decode_mode)

if mode == 'predict':
if self.label_cols != ['id_str']:
Expand Down Expand Up @@ -222,9 +227,18 @@ def label_transform(label_dict):
return identity # do nothing

def dict_to_filled_dict_factory(label_cols):
    """Build a label transform that guarantees every ``label_cols`` key is present.

    Args:
        label_cols: iterable of keys each label dict must contain.

    Returns:
        A function that takes a label dict, sets any missing ``label_cols`` key
        to 0 in place (existing values are left untouched), and returns that
        same dict.
    """
    logging.info(f'label cols: {label_cols}')

    def label_transform(label_dict: dict):
        # A plain loop: filling defaults is a side effect, so there is no point
        # building a throwaway list, and one pass of setdefault is sufficient.
        for col in label_cols:
            label_dict.setdefault(col, 0)
        return label_dict

    return label_transform
3 changes: 2 additions & 1 deletion zoobot/pytorch/estimators/define_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def setup_metrics(self, nan_strategy='error'): # may sometimes want to ignore n


def forward(self, x):
    """Pass a batch through ``self.encoder`` and then ``self.head``.

    Args:
        x: batch whose ``shape[1]`` is the channel dimension (torchlike BCHW)
            with fewer than 4 channels.

    Returns:
        The output of ``self.head`` applied to the encoder output.

    Raises:
        AssertionError: if ``x.shape[1]`` is 4 or more, which suggests the
            batch is not channel-first.
    """
    # torchlike BCHW; report the offending shape so a layout mix-up is obvious
    assert x.shape[1] < 4, (
        f"expected torchlike BCHW input with fewer than 4 channels in shape[1], "
        f"got shape {tuple(x.shape)}"
    )
    encoded = self.encoder(x)
    return self.head(encoded)

Expand Down Expand Up @@ -142,7 +143,7 @@ def log_all_metrics(self, subset=None):
prog_bar = metric_collection == self.loss_metrics
for name, metric in metric_collection.items():
if subset in name:
logging.info(name)
# logging.info(name)
self.log(name, metric, on_epoch=True, on_step=False, prog_bar=prog_bar, logger=True)
else: # just log everything
self.log_dict(self.loss_metrics, on_epoch=True, on_step=False, prog_bar=True, logger=True)
Expand Down

0 comments on commit c8442db

Please sign in to comment.