carefully start adding back
mwalmsley committed Mar 2, 2024
1 parent 3145542 commit f99dfd0
Showing 1 changed file with 7 additions and 7 deletions.
zoobot/pytorch/training/finetune.py (7 additions, 7 deletions)
@@ -220,9 +220,9 @@ def configure_optimizers(self):
         logging.info('blocks that will be tuned: {}'.format(self.n_blocks))
         blocks_to_tune = tuneable_blocks[:self.n_blocks]
         # optionally, can finetune batchnorm params in remaining layers
-        # remaining_blocks = tuneable_blocks[self.n_blocks:]
-        # logging.info('Remaining blocks: {}'.format(len(remaining_blocks)))
-        # assert not any([block in remaining_blocks for block in blocks_to_tune]), 'Some blocks are in both tuneable and remaining'
+        remaining_blocks = tuneable_blocks[self.n_blocks:]
+        logging.info('Remaining blocks: {}'.format(len(remaining_blocks)))
+        assert not any([block in remaining_blocks for block in blocks_to_tune]), 'Some blocks are in both tuneable and remaining'

         # Append parameters of layers for finetuning along with decayed learning rate
         for i, block in enumerate(blocks_to_tune):  # _ is the block name e.g. '3'
@@ -232,17 +232,17 @@ def configure_optimizers(self):
             })

         # optionally, for the remaining layers (not otherwise finetuned) you can choose to still FT the batchnorm layers
-        # for i, block in enumerate(remaining_blocks):
-        #     if self.always_train_batchnorm:
-        #         raise NotImplementedError
+        for i, block in enumerate(remaining_blocks):
+            if self.always_train_batchnorm:
+                raise NotImplementedError
                 # _, block_batch_norm_params = get_batch_norm_params_lighting(block)
                 # params.append({
                 #     "params": block_batch_norm_params,
                 #     "lr": lr * (self.lr_decay**i)
                 # })


-        # logging.info('param groups: {}'.format(len(params)))
+        logging.info('param groups: {}'.format(len(params)))
         # for param_group_n, param_group in enumerate(params):
         #     shapes_within_param_group = [p.shape for p in list(param_group['params'])]
         #     logging.debug('param group {}: {}'.format(param_group_n, shapes_within_param_group))
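
For readers of the diff, here is a minimal, self-contained sketch (not the zoobot implementation itself) of the pattern these hunks restore: slice the tuneable blocks into a finetuned set and a remaining set, give each finetuned block its own optimizer param group with a geometrically decayed learning rate, and log how many groups were built. The standalone function build_param_groups, the toy blocks, and the AdamW call are illustrative assumptions; in finetune.py this logic lives inside configure_optimizers.

# Illustrative sketch only, not the zoobot source: build_param_groups, the toy
# blocks and the AdamW call below are assumptions used to show the pattern.
import logging

import torch
from torch import nn


def build_param_groups(tuneable_blocks, n_blocks, lr, lr_decay):
    """Return optimizer param groups with a per-block learning-rate decay.

    Assumes tuneable_blocks is ordered so its first entries sit closest to the
    head, so block i trains at lr * lr_decay**i (deeper blocks learn more slowly).
    """
    blocks_to_tune = tuneable_blocks[:n_blocks]
    remaining_blocks = tuneable_blocks[n_blocks:]
    logging.info('Remaining blocks: {}'.format(len(remaining_blocks)))
    assert not any([block in remaining_blocks for block in blocks_to_tune]), \
        'Some blocks are in both tuneable and remaining'

    params = []
    for i, block in enumerate(blocks_to_tune):
        params.append({
            "params": block.parameters(),
            "lr": lr * (lr_decay ** i),  # geometric decay per finetuned block
        })
    logging.info('param groups: {}'.format(len(params)))
    return params


if __name__ == '__main__':
    # toy usage: three small blocks, finetune the first two with decayed lr
    blocks = [nn.Sequential(nn.Conv2d(3, 8, 3)),
              nn.Sequential(nn.Conv2d(8, 8, 3)),
              nn.Sequential(nn.Conv2d(8, 8, 3))]
    optimizer = torch.optim.AdamW(build_param_groups(blocks, n_blocks=2, lr=1e-4, lr_decay=0.75))
    print(len(optimizer.param_groups))  # 2

The assert restored in the first hunk is a cheap safety check that the two slices never overlap, which would otherwise hand the same block two conflicting learning rates.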
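The second hunk also restores the loop over remaining_blocks, but the always_train_batchnorm branch still raises NotImplementedError; the commented lines sketch the plan of giving only the BatchNorm parameters of otherwise frozen blocks their own decayed-lr group. The helper below is a hypothetical stand-in for the get_batch_norm_params_lighting call named in those comments, not its actual implementation.

# Hypothetical stand-in for get_batch_norm_params_lighting (named only in the
# commented-out lines); the real helper may differ. It collects just the
# affine parameters of BatchNorm layers inside a block.
from torch import nn


def batch_norm_params(block):
    """Yield the weight and bias of every BatchNorm layer inside block."""
    for module in block.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            yield from module.parameters()


# If always_train_batchnorm were implemented, each remaining block could then
# contribute a group such as:
# params.append({"params": list(batch_norm_params(block)), "lr": lr * (lr_decay ** i)})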
