
Commit f48f5ec: unit interval

mwalmsley committed Feb 26, 2024 · 1 parent 164448c
Showing 1 changed file with 15 additions and 2 deletions.
zoobot/pytorch/training/finetune.py
@@ -421,15 +421,24 @@ class FinetuneableZoobotRegressor(FinetuneableZoobotAbstract):

     def __init__(
         self,
+        unit_interval=False,
         **super_kwargs) -> None:

         super().__init__(**super_kwargs)

+        self.unit_interval = unit_interval
+        if self.unit_interval:
+            logging.info('unit_interval=True, using sigmoid activation for finetuning head')
+            head_activation = torch.nn.functional.sigmoid
+        else:
+            head_activation = None
+
         logging.info('Using regression head and MSE loss')
         self.head = LinearHead(
             input_dim=self.encoder_dim,
             output_dim=1,
-            dropout_prob=self.dropout_prob
+            dropout_prob=self.dropout_prob,
+            activation=head_activation
         )
         self.loss = mse_loss
         # rmse metrics. loss is mse already.
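
The new flag is opt-in at construction time. A minimal usage sketch (not part of the commit): only unit_interval is introduced by this change; the import path follows the file shown above, and the other constructor arguments are assumed placeholders.

# Hypothetical usage sketch; checkpoint name and other kwargs are placeholders.
from zoobot.pytorch.training.finetune import FinetuneableZoobotRegressor

model = FinetuneableZoobotRegressor(
    name='hf_hub:mwalmsley/zoobot-encoder-some-model',  # placeholder checkpoint
    unit_interval=True,  # adds sigmoid to the head, so predictions land in (0, 1)
)
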
@@ -540,17 +549,21 @@ def upload_images_to_wandb(self, outputs, batch, batch_idx):

 # https://github.com/inigoval/byol/blob/1da1bba7dc5cabe2b47956f9d7c6277decd16cc7/byol_main/networks/models.py#L29
 class LinearHead(torch.nn.Module):
-    def __init__(self, input_dim, output_dim, dropout_prob=0.5):
+    def __init__(self, input_dim, output_dim, dropout_prob=0.5, activation=None):
         # input_dim is the representation dim, output_dim is the number of classes
         super(LinearHead, self).__init__()
         self.output_dim = output_dim
         self.dropout = torch.nn.Dropout(p=dropout_prob)
         self.linear = torch.nn.Linear(input_dim, output_dim)
+        # store even when None, so forward() can check it safely
+        self.activation = activation

     def forward(self, x):
         # returns logits by default, as recommended for CrossEntropy loss
         x = self.dropout(x)
         x = self.linear(x)
+        if self.activation is not None:
+            x = self.activation(x)
         if self.output_dim == 1:
             return x.squeeze()
         else:
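
To see what the sigmoid buys, a self-contained sketch in plain torch (built with Sequential rather than the LinearHead class above, but the same dropout -> linear structure): the raw head output is unbounded, while the activated output stays inside the unit interval.

# Standalone sketch, not the repo import; dims are arbitrary for illustration.
import torch

head = torch.nn.Sequential(
    torch.nn.Dropout(p=0.5),
    torch.nn.Linear(512, 1),
)
head.eval()  # disable dropout for inference

x = torch.randn(8, 512)
with torch.no_grad():
    logits = head(x).squeeze()     # unbounded real values
    preds = torch.sigmoid(logits)  # what unit_interval=True applies

assert preds.min() >= 0 and preds.max() <= 1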
