
Commit

Merge branch 'narval-migration' of github.com:mwalmsley/zoobot into narval-migration
mwalmsley committed Jan 12, 2024
2 parents 80517d5 + 7db406c commit ce106f7
Showing 3 changed files with 42 additions and 41 deletions.
13 changes: 8 additions & 5 deletions zoobot/pytorch/datasets/webdatamodule.py
@@ -28,7 +28,8 @@ def __init__(
         crop_scale_bounds=(0.7, 0.8),
         crop_ratio_bounds=(0.9, 1.1),
         resize_after_crop=224,
-        transform: Callable=None
+        train_transform: Callable=None,
+        inference_transform: Callable=None
     ):
         super().__init__()
@@ -61,7 +62,8 @@ def __init__(
         self.crop_scale_bounds = crop_scale_bounds
         self.crop_ratio_bounds = crop_ratio_bounds

-        self.transform = transform
+        self.train_transform = train_transform
+        self.inference_transform = inference_transform

         for url_name in ['train', 'val', 'test', 'predict']:
             urls = getattr(self, f'{url_name}_urls')
@@ -101,12 +103,12 @@ def make_loader(self, urls, mode="train"):
             assert mode in ['val', 'test', 'predict'], mode
             shuffle = 0

-        if self.transform is None:
+        if self.train_transform is None:
             logging.info('Using default transform')
             transform_image = self.make_image_transform(mode=mode)
         else:
-            logging.info('Ignoring hparams and using directly-passed transform')
-            transform_image = self.transform
+            logging.info('Ignoring hparams and using directly-passed transforms')
+            transform_image = self.train_transform if mode == 'train' else self.inference_transform

         transform_label = dict_to_label_cols_factory(self.label_cols)

@@ -130,6 +132,7 @@ def make_loader(self, urls, mode="train"):
             logging.info('Will return id_str only')
             dataset = dataset.to_tuple('__key__')
         else:
+
             dataset = (
                 dataset.to_tuple('image.jpg', 'labels.json')
                 .map_tuple(transform_image, transform_label)
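Note: with this change, a directly-passed transform is selected by mode: train_transform when mode == 'train', inference_transform for val/test/predict. If both are None, the hparams-based default still applies. A minimal usage sketch follows; the train_urls/val_urls and label_cols arguments are inferred from the rest of the class, and the transform bodies are placeholders.

    from zoobot.pytorch.datasets.webdatamodule import WebDataModule

    def my_train_transform(image):
        # placeholder: a stochastic augmentation pipeline (random crop, flip, ...)
        return image

    def my_inference_transform(image):
        # placeholder: deterministic resize/normalise only
        return image

    datamodule = WebDataModule(
        train_urls=['train_shard_0.tar'],  # illustrative shard paths
        val_urls=['val_shard_0.tar'],
        label_cols=['label_a', 'label_b'],  # assumed argument, per self.label_cols
        train_transform=my_train_transform,          # applied when mode == 'train'
        inference_transform=my_inference_transform,  # applied for val/test/predict
    )
    # Leaving both transforms as None falls back to make_image_transform(mode=mode).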
60 changes: 28 additions & 32 deletions zoobot/pytorch/datasets/webdataset_utils.py
@@ -21,42 +21,19 @@

 import zoobot.pytorch.datasets.webdatamodule as webdatamodule

+def catalogs_to_webdataset(dataset_name, wds_dir, label_cols, train_catalog, test_catalog, sparse_label_df=None, divisor=2048, overwrite=False):
+    for (catalog_name, catalog) in [('train', train_catalog), ('test', test_catalog)]:
+        n_shards = len(catalog) // divisor
+        logging.info(n_shards)

-# def catalogs_to_webdataset(dataset_name, wds_dir, label_cols, train_catalog, test_catalog, sparse_label_df=None, divisor=2048, overwrite=False):
-#     for (catalog_name, catalog) in [('train', train_catalog), ('test', test_catalog)]:
-#         n_shards = len(catalog) // divisor
-#         logging.info(n_shards)
+        catalog = catalog[:n_shards*divisor]
+        logging.info(len(catalog))

-#         catalog = catalog[:n_shards*divisor]
-#         logging.info(len(catalog))

-#         # wds_dir e.g. /home/walml/data/wds

-#         save_loc = f"{wds_dir}/{dataset_name}/{dataset_name}_{catalog_name}.tar"  # .tar replace automatically
+        save_loc = f"{wds_dir}/{dataset_name}/{dataset_name}_{catalog_name}.tar"  # .tar replace automatically

-#         df_to_wds(catalog, label_cols, save_loc, n_shards=n_shards, sparse_label_df=sparse_label_df, overwrite=overwrite)
-#         # some tests, if you like
-#         # webdataset_utils.load_wds_directly(save_loc)
-#         # webdataset_utils.load_wds_with_augmentation(save_loc)
-#         # webdataset_utils.load_wds_with_webdatamodule([save_loc], label_cols)


-def make_mock_wds(save_dir: str, label_cols: list, n_shards: int, shard_size: int):
-    counter = 0
-    shards = [os.path.join(save_dir, f'mock_shard_{shard_n}_{shard_size}.tar') for shard_n in range(n_shards)]
-    for shard in shards:
-        sink = wds.TarWriter(shard)
-        for galaxy_n in range(shard_size):
-            data = {
-                "__key__": f'id_{galaxy_n}',
-                "image.jpg": (np.random.rand(424, 424)*255.).astype(np.uint8),
-                "labels.json": json.dumps(dict(zip(label_cols, [np.random.randint(low=0, high=10) for _ in range(len(label_cols))])))
-            }
-            sink.write(data)
-            counter += 1
-            print(counter)
-    return shards
+        df_to_wds(catalog, label_cols, save_loc, n_shards=n_shards, sparse_label_df=sparse_label_df, overwrite=overwrite)




 def df_to_wds(df: pd.DataFrame, label_cols, save_loc: str, n_shards: int, sparse_label_df=None, overwrite=False):
@@ -206,6 +183,25 @@ def identity(x):
     # no lambda to be pickleable
     return x


+
+def make_mock_wds(save_dir: str, label_cols: list, n_shards: int, shard_size: int):
+    counter = 0
+    shards = [os.path.join(save_dir, f'mock_shard_{shard_n}_{shard_size}.tar') for shard_n in range(n_shards)]
+    for shard in shards:
+        sink = wds.TarWriter(shard)
+        for galaxy_n in range(shard_size):
+            data = {
+                "__key__": f'id_{galaxy_n}',
+                "image.jpg": (np.random.rand(424, 424)*255.).astype(np.uint8),
+                "labels.json": json.dumps(dict(zip(label_cols, [np.random.randint(low=0, high=10) for _ in range(len(label_cols))])))
+            }
+            sink.write(data)
+            counter += 1
+            print(counter)
+    return shards
+
+
 if __name__ == '__main__':

     save_dir = '/home/walml/repos/temp'
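Note: catalogs_to_webdataset, now active rather than commented out, trims each catalog to n_shards * divisor rows before writing, so up to divisor - 1 rows are silently dropped per split. make_mock_wds is unchanged apart from moving down the file; it writes shards of random 424x424 uint8 images with random integer labels. A quick smoke-test sketch, with illustrative directory and label column:

    from zoobot.pytorch.datasets.webdataset_utils import make_mock_wds

    shards = make_mock_wds(
        save_dir='/tmp/mock_wds',  # hypothetical directory, must already exist
        label_cols=['smooth-or-featured_smooth'],  # illustrative column name
        n_shards=2,
        shard_size=32,
    )
    # shards == ['/tmp/mock_wds/mock_shard_0_32.tar', '/tmp/mock_wds/mock_shard_1_32.tar']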
10 changes: 6 additions & 4 deletions zoobot/pytorch/training/train_with_pytorch_lightning.py
@@ -248,9 +248,10 @@ def train_default_zoobot_from_scratch(

     #         }
     #     })
-    cfg = transforms.default_view_config()
-    cfg.output_size = resize_after_crop
-    transform = transforms.GalaxyViewTransform(cfg)
+    train_transform_cfg = transforms.default_view_config()
+    inference_transform_cfg = transforms.minimal_view_config()
+    train_transform_cfg.output_size = resize_after_crop
+    inference_transform_cfg.output_size = resize_after_crop

     datamodule = webdatamodule.WebDataModule(
         train_urls=train_urls,
@@ -263,7 +264,8 @@ def train_default_zoobot_from_scratch(
         prefetch_factor=prefetch_factor,
         cache_dir=cache_dir,
         # augmentation args
-        transform=transform,
+        train_transform=transforms.GalaxyViewTransform(train_transform_cfg),
+        inference_transform=transforms.GalaxyViewTransform(inference_transform_cfg),
         # color=color,
         # crop_scale_bounds=crop_scale_bounds,
         # crop_ratio_bounds=crop_ratio_bounds,
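Note: the design choice here is stochastic default_view_config augmentations for training and a deterministic minimal_view_config view for validation/test/predict, both resized to resize_after_crop. A sketch of the resulting pair, assuming transforms is the galaxy-datasets transforms module this file already imports:

    from galaxy_datasets import transforms  # assumed import path

    train_cfg = transforms.default_view_config()      # stochastic views for training
    inference_cfg = transforms.minimal_view_config()  # deterministic views otherwise
    train_cfg.output_size = 224       # i.e. resize_after_crop
    inference_cfg.output_size = 224

    train_transform = transforms.GalaxyViewTransform(train_cfg)
    inference_transform = transforms.GalaxyViewTransform(inference_cfg)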
