
Commit

Merge branch 'narval-migration' of https://github.com/mwalmsley/zoobot into narval-migration
mwalmsley committed Jan 4, 2024
2 parents ee608cf + d99d193 commit 7d1f379
Showing 7 changed files with 229 additions and 187 deletions.
51 changes: 29 additions & 22 deletions zoobot/pytorch/datasets/webdatamodule.py
@@ -87,14 +87,6 @@ def do_transform(img):
            return np.transpose(augmentation_transform(image=np.array(img))["image"], axes=[2, 0, 1]).astype(np.float32)
        return do_transform

    def make_label_transform(self):
        if self.label_cols is not None:
            def label_transform(label_dict):
                return torch.from_numpy(np.array([label_dict.get(col, 0) for col in self.label_cols])).double()
            return label_transform
        else:
            return identity  # do nothing


    def make_loader(self, urls, mode="train"):
        logging.info('Making loader with mode {}'.format(mode))
@@ -108,7 +100,7 @@ def make_loader(self, urls, mode="train"):

        transform_image = self.make_image_transform(mode=mode)

        transform_label = self.make_label_transform()
        transform_label = dict_to_label_cols_factory(self.label_cols)

        dataset = wds.WebDataset(urls, cache_dir=self.cache_dir, shardshuffle=shuffle>0, nodesplitter=nodesplitter_func)
        # https://webdataset.github.io/webdataset/multinode/
@@ -138,17 +130,6 @@ def make_loader(self, urls, mode="train"):
        # so use the torch collate instead
        dataset = dataset.batched(self.batch_size, torch.utils.data.default_collate, partial=False)

        loader = wds.WebLoader(
            dataset,
            batch_size=None,  # already batched
            shuffle=False,  # already shuffled
            num_workers=self.num_workers,
            pin_memory=True,
            prefetch_factor=self.prefetch_factor
        )

        loader.length = dataset_size // self.batch_size

        # temp hack instead
        if mode in ['train', 'val']:
            assert dataset_size % self.batch_size == 0, (dataset_size, self.batch_size, dataset_size % self.batch_size)
@@ -159,6 +140,8 @@ def make_loader(self, urls, mode="train"):
        # loader = loader.ddp_equalize(dataset_size // self.batch_size)
        # print("# loader length", len(loader))

        loader = webdataset_to_webloader(dataset, self.num_workers, self.prefetch_factor)

        return loader

    def train_dataloader(self):
@@ -176,7 +159,7 @@ def identity(x):
def nodesplitter_func(urls):
    urls_to_use = list(wds.split_by_node(urls))  # rely on WDS for the hard work
    rank, world_size, worker, num_workers = wds.utils.pytorch_worker_info()
    logging.info(
    logging.debug(
        f'''
        Splitting urls within webdatamodule with WORLD_SIZE:
        {world_size}, RANK: {rank}, WORKER: {worker} of {num_workers}\n
@@ -186,6 +169,7 @@ def nodesplitter_func(urls):
    return urls_to_use

def interpret_shard_size_from_url(url):
    assert isinstance(url, str), TypeError(url)
    return int(url.rstrip('.tar').split('_')[-1])
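    # e.g. a shard named 'train_shard_3_2048.tar' (hypothetical filename) would yield 2048:
    # the trailing '_<size>' token records how many examples the shard contains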

def interpret_dataset_size_from_urls(urls):
@@ -197,4 +181,27 @@ def get_first(x):
def custom_collate(x):
    if isinstance(x, list) and len(x) == 1:
        x = x[0]
    return torch.utils.data.default_collate(x)


def webdataset_to_webloader(dataset, num_workers, prefetch_factor):
    loader = wds.WebLoader(
        dataset,
        batch_size=None,  # already batched
        shuffle=False,  # already shuffled
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=prefetch_factor
    )

    # loader.length = dataset_size // batch_size
    return loader
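# possible usage sketch of webdataset_to_webloader (shard name, batch size and worker counts are illustrative, not part of this commit):
#   dataset = wds.WebDataset(['train_0_2048.tar']).batched(128, torch.utils.data.default_collate, partial=False)
#   loader = webdataset_to_webloader(dataset, num_workers=4, prefetch_factor=4)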


def dict_to_label_cols_factory(label_cols=None):
    if label_cols is not None:
        def label_transform(label_dict):
            return torch.from_numpy(np.array([label_dict.get(col, 0) for col in label_cols])).double()  # gets cast to int in zoobot loss
        return label_transform
    else:
        return identity  # do nothing
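
dict_to_label_cols_factory replaces the old make_label_transform method with a free function, so the same label handling can be reused outside the DataModule. A minimal sketch of the returned transform, assuming illustrative label columns (not taken from this commit):

    transform_label = dict_to_label_cols_factory(['smooth-or-featured_smooth', 'smooth-or-featured_featured-or-disk'])
    # decoded labels.json entries are plain dicts of column -> vote count; missing columns default to 0
    label_tensor = transform_label({'smooth-or-featured_smooth': 12})  # tensor([12., 0.], dtype=torch.float64)
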
72 changes: 37 additions & 35 deletions zoobot/pytorch/datasets/webdataset_utils.py
@@ -1,5 +1,5 @@
import logging
from typing import List
from typing import Union, Callable
import os
import cv2
import json
@@ -22,26 +22,26 @@
import zoobot.pytorch.datasets.webdatamodule as webdatamodule


def catalogs_to_webdataset(dataset_name, wds_dir, label_cols, train_catalog, test_catalog, sparse_label_df=None, divisor=2048, overwrite=False):
    for (catalog_name, catalog) in [('train', train_catalog), ('test', test_catalog)]:
        n_shards = len(catalog) // divisor
        logging.info(n_shards)
# def catalogs_to_webdataset(dataset_name, wds_dir, label_cols, train_catalog, test_catalog, sparse_label_df=None, divisor=2048, overwrite=False):
# for (catalog_name, catalog) in [('train', train_catalog), ('test', test_catalog)]:
# n_shards = len(catalog) // divisor
# logging.info(n_shards)

        catalog = catalog[:n_shards*divisor]
        logging.info(len(catalog))
# catalog = catalog[:n_shards*divisor]
# logging.info(len(catalog))

        # wds_dir e.g. /home/walml/data/wds
# # wds_dir e.g. /home/walml/data/wds

save_loc = f"{wds_dir}/{dataset_name}/{dataset_name}_{catalog_name}.tar" # .tar replace automatically
# save_loc = f"{wds_dir}/{dataset_name}/{dataset_name}_{catalog_name}.tar" # .tar replace automatically

        df_to_wds(catalog, label_cols, save_loc, n_shards=n_shards, sparse_label_df=sparse_label_df, overwrite=overwrite)
        # some tests, if you like
        # webdataset_utils.load_wds_directly(save_loc)
        # webdataset_utils.load_wds_with_augmentation(save_loc)
        # webdataset_utils.load_wds_with_webdatamodule([save_loc], label_cols)
# df_to_wds(catalog, label_cols, save_loc, n_shards=n_shards, sparse_label_df=sparse_label_df, overwrite=overwrite)
# # some tests, if you like
# # webdataset_utils.load_wds_directly(save_loc)
# # webdataset_utils.load_wds_with_augmentation(save_loc)
# # webdataset_utils.load_wds_with_webdatamodule([save_loc], label_cols)


def make_mock_wds(save_dir: str, label_cols: List, n_shards: int, shard_size: int):
def make_mock_wds(save_dir: str, label_cols: list, n_shards: int, shard_size: int):
    counter = 0
    shards = [os.path.join(save_dir, f'mock_shard_{shard_n}_{shard_size}.tar') for shard_n in range(n_shards)]
    for shard in shards:
@@ -103,47 +103,49 @@ def df_to_wds(df: pd.DataFrame, label_cols, save_loc: str, n_shards: int, sparse
    for shard_n, shard_df in tqdm.tqdm(enumerate(shard_dfs), total=len(shard_dfs)):
        shard_save_loc = save_loc.replace('.tar', f'_{shard_n}_{len(shard_df)}.tar')
        if overwrite or not(os.path.isfile(shard_save_loc)):

            if sparse_label_df is not None:
                shard_df = pd.merge(shard_df, sparse_label_df, how='left', validate='one_to_one', suffixes=('', '_badlabelmerge'))  # auto-merge
                shard_df = pd.merge(shard_df, sparse_label_df, how='left', validate='one_to_one', suffixes=('', '_badlabelmerge'))  # type: ignore # auto-merge

            assert not any(shard_df[label_cols].isna().max())
            assert not any(shard_df[label_cols].isna().max())  # type: ignore

            # logging.info(shard_save_loc)
            sink = wds.TarWriter(shard_save_loc)
            for _, galaxy in shard_df.iterrows():
            for _, galaxy in shard_df.iterrows():  # type: ignore
                sink.write(galaxy_to_wds(galaxy, label_cols, transform=transform))
            sink.close()


def galaxy_to_wds(galaxy: pd.Series, label_cols, transform=None):
def galaxy_to_wds(galaxy: pd.Series, label_cols: Union[list[str],None]=None, metadata_cols: Union[list, None]=None, transform: Union[Callable, None]=None):

    assert os.path.isfile(galaxy['file_loc']), galaxy['file_loc']
    im = cv2.imread(galaxy['file_loc'])
    # cv2 loads BGR for 'history', fix
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    assert not np.any(np.isnan(np.array(im))), galaxy['file_loc']
    # if central_crop is not None:
    #     width, height, _ = im.shape
    #     # assert width == height, (width, height)
    #     mid = int(width/2)
    #     half_central_crop = int(central_crop/2)
    #     low_edge, high_edge = mid - half_central_crop, mid + half_central_crop
    #     im = im[low_edge:high_edge, low_edge:high_edge]
    #     assert im.shape == (central_crop, central_crop, 3)

    # apply albumentations
    if transform is not None:
        im = transform(image=im)['image']

    labels = json.dumps(galaxy[label_cols].astype(np.int32).to_dict())
    id_str = str(galaxy['id_str'])

    if transform is not None:
        im = transform(image=im)['image']

    if label_cols is None:
        labels = json.dumps({})
    else:
        labels = json.dumps(galaxy[label_cols].to_dict())

    if metadata_cols is None:
        metadata = json.dumps({})
    else:
        metadata = json.dumps(galaxy[metadata_cols].to_dict())

    return {
        "__key__": id_str,
        "__key__": id_str,  # silly wds bug where if __key__ ends .jpg, all keys get jpg. prepended?! use id_str instead
        "image.jpg": im,
        "labels.json": labels
        "labels.json": labels,
        "metadata.json": metadata
    }


# just for debugging
def load_wds_directly(wds_loc, max_to_load=3):

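
galaxy_to_wds now treats both labels and metadata as optional JSON payloads. A rough sketch of writing a single-galaxy shard with the new signature, assuming a catalogue row with 'file_loc' and 'id_str' plus illustrative label and metadata columns (not taken from this commit):

    galaxy = catalog.iloc[0]  # pd.Series; 'catalog' and the column names below are assumed, not from this commit
    sink = wds.TarWriter('/tmp/example_shard_0_1.tar')
    sink.write(galaxy_to_wds(galaxy, label_cols=['smooth-or-featured_smooth'], metadata_cols=['ra', 'dec']))
    sink.close()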

