Commit

tweaks for foundation
mwalmsley committed Dec 25, 2023
1 parent e8aa6b6 commit d94150e
Showing 3 changed files with 41 additions and 54 deletions.
48 changes: 27 additions & 21 deletions zoobot/pytorch/datasets/webdatamodule.py
@@ -87,14 +87,6 @@ def do_transform(img):
            return np.transpose(augmentation_transform(image=np.array(img))["image"], axes=[2, 0, 1]).astype(np.float32)
        return do_transform

-    def make_label_transform(self):
-        if self.label_cols is not None:
-            def label_transform(label_dict):
-                return torch.from_numpy(np.array([label_dict.get(col, 0) for col in self.label_cols])).double()
-            return label_transform
-        else:
-            return identity # do nothing
-

    def make_loader(self, urls, mode="train"):
        logging.info('Making loader with mode {}'.format(mode))
@@ -108,7 +100,7 @@ def make_loader(self, urls, mode="train"):

        transform_image = self.make_image_transform(mode=mode)

-        transform_label = self.make_label_transform()
+        transform_label = dict_to_label_cols_factory(self.label_cols)

        dataset = wds.WebDataset(urls, cache_dir=self.cache_dir, shardshuffle=shuffle>0, nodesplitter=nodesplitter_func)
        # https://webdataset.github.io/webdataset/multinode/
@@ -138,17 +130,6 @@ def make_loader(self, urls, mode="train"):
        # so use the torch collate instead
        dataset = dataset.batched(self.batch_size, torch.utils.data.default_collate, partial=False)

-        loader = wds.WebLoader(
-            dataset,
-            batch_size=None, # already batched
-            shuffle=False, # already shuffled
-            num_workers=self.num_workers,
-            pin_memory=True,
-            prefetch_factor=self.prefetch_factor
-        )
-
-        loader.length = dataset_size // self.batch_size
-
        # temp hack instead
        if mode in ['train', 'val']:
            assert dataset_size % self.batch_size == 0, (dataset_size, self.batch_size, dataset_size % self.batch_size)
@@ -159,6 +140,8 @@ def make_loader(self, urls, mode="train"):
        # loader = loader.ddp_equalize(dataset_size // self.batch_size)
        # print("# loader length", len(loader))

+        loader = webdataset_to_webloader(dataset, self.num_workers, self.prefetch_factor)
+
        return loader

    def train_dataloader(self):
@@ -197,4 +180,27 @@ def get_first(x):
def custom_collate(x):
    if isinstance(x, list) and len(x) == 1:
        x = x[0]
-    return torch.utils.data.default_collate(x)
+    return torch.utils.data.default_collate(x)
+
+
+def webdataset_to_webloader(dataset, num_workers, prefetch_factor):
+    loader = wds.WebLoader(
+        dataset,
+        batch_size=None, # already batched
+        shuffle=False, # already shuffled
+        num_workers=num_workers,
+        pin_memory=True,
+        prefetch_factor=prefetch_factor
+    )
+
+    # loader.length = dataset_size // batch_size
+    return loader
+
+
+def dict_to_label_cols_factory(label_cols=None):
+    if label_cols is not None:
+        def label_transform(label_dict):
+            return torch.from_numpy(np.array([label_dict.get(col, 0) for col in label_cols])).double() # gets cast to int in zoobot loss
+        return label_transform
+    else:
+        return identity # do nothing
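
For context, a minimal sketch (mine, not part of the commit) of how the two new module-level helpers might be exercised together. The label columns, shard URL pattern, and sample keys below are invented for illustration; only the helper signatures and the batched-collate call come from the diff above.

import numpy as np
import torch
import webdataset as wds

from zoobot.pytorch.datasets.webdatamodule import dict_to_label_cols_factory, webdataset_to_webloader

# hypothetical label columns; any column missing from a sample's label dict falls back to 0
label_cols = ['smooth-or-featured_smooth', 'smooth-or-featured_featured-or-disk']
transform_label = dict_to_label_cols_factory(label_cols)
print(transform_label({'smooth-or-featured_smooth': 12}))
# tensor([12., 0.], dtype=torch.float64), later cast to int in the zoobot loss
# with label_cols=None the factory returns identity, so labels pass through unchanged

# wrap an already-batched WebDataset, mirroring the end of make_loader above
dataset = (
    wds.WebDataset('shards-{000000..000009}.tar')  # made-up shard pattern
    .decode('rgb')  # decode images to float32 numpy arrays; json is decoded automatically
    .to_tuple('image.jpg', 'labels.json')  # assumed sample keys
    .map_tuple(lambda im: torch.from_numpy(np.transpose(im, [2, 0, 1])), transform_label)
    .batched(16, torch.utils.data.default_collate, partial=False)
)
loader = webdataset_to_webloader(dataset, num_workers=4, prefetch_factor=2)
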
28 changes: 14 additions & 14 deletions zoobot/pytorch/datasets/webdataset_utils.py
@@ -22,23 +22,23 @@
import zoobot.pytorch.datasets.webdatamodule as webdatamodule


-def catalogs_to_webdataset(dataset_name, wds_dir, label_cols, train_catalog, test_catalog, sparse_label_df=None, divisor=2048, overwrite=False):
-    for (catalog_name, catalog) in [('train', train_catalog), ('test', test_catalog)]:
-        n_shards = len(catalog) // divisor
-        logging.info(n_shards)
+# def catalogs_to_webdataset(dataset_name, wds_dir, label_cols, train_catalog, test_catalog, sparse_label_df=None, divisor=2048, overwrite=False):
+#     for (catalog_name, catalog) in [('train', train_catalog), ('test', test_catalog)]:
+#         n_shards = len(catalog) // divisor
+#         logging.info(n_shards)

-        catalog = catalog[:n_shards*divisor]
-        logging.info(len(catalog))
+#         catalog = catalog[:n_shards*divisor]
+#         logging.info(len(catalog))

-        # wds_dir e.g. /home/walml/data/wds
+#         # wds_dir e.g. /home/walml/data/wds

-        save_loc = f"{wds_dir}/{dataset_name}/{dataset_name}_{catalog_name}.tar" # .tar replace automatically
+#         save_loc = f"{wds_dir}/{dataset_name}/{dataset_name}_{catalog_name}.tar" # .tar replace automatically

-        df_to_wds(catalog, label_cols, save_loc, n_shards=n_shards, sparse_label_df=sparse_label_df, overwrite=overwrite)
-        # some tests, if you like
-        # webdataset_utils.load_wds_directly(save_loc)
-        # webdataset_utils.load_wds_with_augmentation(save_loc)
-        # webdataset_utils.load_wds_with_webdatamodule([save_loc], label_cols)
+#         df_to_wds(catalog, label_cols, save_loc, n_shards=n_shards, sparse_label_df=sparse_label_df, overwrite=overwrite)
+#         # some tests, if you like
+#         # webdataset_utils.load_wds_directly(save_loc)
+#         # webdataset_utils.load_wds_with_augmentation(save_loc)
+#         # webdataset_utils.load_wds_with_webdatamodule([save_loc], label_cols)


def make_mock_wds(save_dir: str, label_cols: List, n_shards: int, shard_size: int):
@@ -136,7 +136,7 @@ def galaxy_to_wds(galaxy: pd.Series, label_cols, transform=None):
    if transform is not None:
        im = transform(image=im)['image']

-    labels = json.dumps(galaxy[label_cols].astype(np.int32).to_dict())
+    labels = json.dumps(galaxy[label_cols].to_dict())
    id_str = str(galaxy['id_str'])
    return {
        "__key__": id_str,
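
Dropping the .astype(np.int32) cast means label values keep their native dtypes through JSON serialization, plausibly (my reading, not stated in the commit) so that non-integer labels such as vote fractions survive for the foundation-model training. A small illustration with made-up columns, assuming a pandas version whose Series.to_dict returns native Python scalars (1.3+):

import json
import pandas as pd

galaxy = pd.Series({'id_str': 'gal_1', 'smooth_fraction': 0.75, 'smooth_votes': 12})
label_cols = ['smooth_fraction', 'smooth_votes']  # hypothetical columns

print(json.dumps(galaxy[label_cols].to_dict()))
# {"smooth_fraction": 0.75, "smooth_votes": 12}  <- new behaviour, dtypes preserved

print(json.dumps(galaxy[label_cols].astype('int32').to_dict()))
# {"smooth_fraction": 0, "smooth_votes": 12}  <- old behaviour truncated fractions
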
19 changes: 0 additions & 19 deletions zoobot/pytorch/estimators/define_model.py
@@ -470,22 +470,3 @@ def get_pytorch_dirichlet_head(encoder_dim: int, output_dim: int, test_time_drop
def schema_to_campaigns(schema):
    # e.g. [gz2, dr12, ...]
    return [question.text.split('-')[-1] for question in schema.questions]
-
-
-# class ToyEncoder(pl.LightningModule):
-
-#     def __init__(self):
-#         super(ToyEncoder, self).__init__()
-
-#         self.conv1 = nn.Conv2d(3, 6, 5)
-#         self.pool = nn.MaxPool2d(2, 2)
-#         self.conv2 = nn.Conv2d(6, 16, 5)
-#         # pool again
-#         self.fc1 = nn.Linear(16 * 5 * 5, 1280) # dim 1280, like effnetb0
-
-#     def forward(self, x):
-#         x = self.pool(nn.functional.relu(self.conv1(x)))
-#         x = self.pool(nn.functional.relu(self.conv2(x)))
-#         x = x.view(-1, 16 * 5 * 5)
-#         x = nn.functional.relu(self.fc1(x))
-#         return x
