Commit 63defef
Trial for BCE loss
ravinkohli committed Mar 14, 2023
1 parent 628f283 commit 63defef
Showing 5 changed files with 23 additions and 6 deletions.
autoPyTorch/datasets/base_dataset.py (3 changes: 2 additions & 1 deletion)
@@ -137,7 +137,8 @@ def __init__(
         self.output_type: str = type_of_target(self.train_tensors[1])

         if STRING_TO_OUTPUT_TYPES[self.output_type] in CLASSIFICATION_OUTPUTS:
-            self.output_shape = len(np.unique(self.train_tensors[1]))
+            n_classes = len(np.unique(self.train_tensors[1]))
+            self.output_shape = n_classes if n_classes > 2 else 1
         else:
             self.output_shape = self.train_tensors[1].shape[-1] if self.train_tensors[1].ndim > 1 else 1

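In effect, a binary task now gets a single output unit while multiclass keeps one unit per class. A standalone sketch of the rule (it calls sklearn's type_of_target the same way the dataset class does; infer_output_shape is an illustrative name, not part of the codebase):

    import numpy as np
    from sklearn.utils.multiclass import type_of_target

    def infer_output_shape(y: np.ndarray) -> int:
        # Binary targets collapse to a single output unit (one logit for BCE);
        # multiclass keeps one unit per class (logits for softmax + cross-entropy).
        if type_of_target(y) in ("binary", "multiclass"):
            n_classes = len(np.unique(y))
            return n_classes if n_classes > 2 else 1
        return y.shape[-1] if y.ndim > 1 else 1

    print(infer_output_shape(np.array([0, 1, 1, 0])))  # 1 (binary)
    print(infer_output_shape(np.array([0, 1, 2, 1])))  # 3 (multiclass)
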
@@ -61,7 +61,10 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchTrainingComponent:
         self.to(self.device)

         if STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']] in CLASSIFICATION_TASKS:
-            self.final_activation = nn.Softmax(dim=1)
+            if X['dataset_properties']['output_shape'] > 1:
+                self.final_activation = nn.Softmax(dim=1)
+            else:
+                self.final_activation = nn.Sigmoid()

         self.is_fitted_ = True
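Downstream, the network's final activation follows the same convention. A hedged sketch of the resulting inference path (a toy head standing in for the real network component; the layer size is made up):

    import torch
    import torch.nn as nn

    output_shape = 1  # a binary task under the new convention above
    head = nn.Linear(8, output_shape)
    final_activation = nn.Softmax(dim=1) if output_shape > 1 else nn.Sigmoid()

    logits = head(torch.randn(4, 8))  # shape (4, 1): one logit per sample
    probs = final_activation(logits)  # sigmoid maps each logit to P(class 1)
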
autoPyTorch/pipeline/components/training/losses.py (4 changes: 2 additions & 2 deletions)
@@ -26,7 +26,7 @@

 losses = dict(classification=dict(
     CrossEntropyLoss=dict(
-        module=CrossEntropyLoss, supported_output_types=[MULTICLASS, BINARY]),
+        module=CrossEntropyLoss, supported_output_types=[MULTICLASS]),
     BCEWithLogitsLoss=dict(
         module=BCEWithLogitsLoss, supported_output_types=[BINARY])),
     regression=dict(
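
Why CrossEntropyLoss no longer advertises BINARY: the two losses disagree on target shape and dtype. A self-contained comparison (toy tensors):

    import torch
    from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss

    logits_bin = torch.randn(4, 1)                        # one output unit
    targets_bin = torch.tensor([[0.], [1.], [1.], [0.]])  # float, same shape as logits
    print(BCEWithLogitsLoss()(logits_bin, targets_bin))

    logits_mc = torch.randn(4, 3)            # one unit per class
    targets_mc = torch.tensor([0, 2, 1, 2])  # long class indices, shape (N,)
    print(CrossEntropyLoss()(logits_mc, targets_mc))
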
@@ -110,6 +110,6 @@ def get_loss(dataset_properties: Dict[str, Any], name: Optional[str] = None) -> Type[Loss]:
         else:
             loss = supported_losses[name]
     else:
-        loss = get_default(task)
+        loss = list(supported_losses.values())[0]  # get_default(task)

     return loss
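
With get_default(task) bypassed, the default loss is simply whatever supported_losses yields first, so the insertion order of the losses dict above becomes load-bearing (Python dicts preserve insertion order since 3.7). A minimal sketch of that assumption, with a toy dict standing in for the real loss modules:

    # For a binary task, BCEWithLogitsLoss is now the only supported
    # classification loss, so "first entry" picks it.
    supported_losses = {"BCEWithLogitsLoss": object()}
    loss = list(supported_losses.values())[0]  # first entry wins
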
autoPyTorch/pipeline/components/training/trainer/base_trainer.py (15 changes: 14 additions & 1 deletion)
@@ -415,7 +415,14 @@ def cast_targets(self, targets: torch.Tensor) -> torch.Tensor:
             if targets.ndim == 1:
                 targets = targets.unsqueeze(1)
         else:
-            targets = targets.long().to(self.device)
+            # make sure that targets will have same shape as outputs (really important for mse loss for example)
+            if targets.ndim == 1:
+                targets = targets.unsqueeze(1)
+                # BCE requires target to be float.
+                targets = targets.float().to(self.device)
+            else:
+                targets = targets.long().to(self.device)

         return targets

     def train_step(self, data: np.ndarray, targets: np.ndarray) -> Tuple[float, torch.Tensor]:
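
What the new branch is after, as a self-contained sketch (the binary flag here makes the binary/multiclass decision explicit, whereas the committed code keys it off the target's dimensionality; the function name is illustrative):

    import torch

    def cast_classification_targets(targets: torch.Tensor, binary: bool) -> torch.Tensor:
        # BCEWithLogitsLoss wants float targets shaped like the (N, 1) outputs;
        # CrossEntropyLoss wants long class indices of shape (N,).
        if binary:
            if targets.ndim == 1:
                targets = targets.unsqueeze(1)
            return targets.float()
        return targets.long()

    print(cast_classification_targets(torch.tensor([0, 1, 1]), binary=True).shape)   # torch.Size([3, 1])
    print(cast_classification_targets(torch.tensor([0, 2, 1]), binary=False).shape)  # torch.Size([3])
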
@@ -440,6 +447,9 @@ def train_step(self, data: np.ndarray, targets: np.ndarray) -> Tuple[float, torch.Tensor]:
         self.optimizer.zero_grad()
         outputs = self.model(data)
         loss_func = self.criterion_preparation(**criterion_kwargs)
+        if len(outputs.size()) == 1:
+            outputs = outputs.unsqueeze(1)
+
         loss = loss_func(self.criterion, outputs)
         loss.backward()
         self.optimizer.step()
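
The guard exists because a single-unit head can emit a flat (N,) tensor, while cast_targets now produces (N, 1) targets, and BCEWithLogitsLoss raises on a size mismatch; the same guard is added to evaluate below. A tiny repro:

    import torch
    from torch.nn import BCEWithLogitsLoss

    outputs = torch.randn(4)     # flat (N,) logits from a 1-unit head
    targets = torch.zeros(4, 1)  # (N, 1) float targets from cast_targets
    # BCEWithLogitsLoss()(outputs, targets) would raise a size-mismatch
    # ValueError; unsqueezing restores the expected (N, 1) shape first.
    if outputs.ndim == 1:
        outputs = outputs.unsqueeze(1)
    print(BCEWithLogitsLoss()(outputs, targets))
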
@@ -476,6 +486,9 @@ def evaluate(self, test_loader: torch.utils.data.DataLoader, epoch: int,

             outputs = self.model(data)

+            if len(outputs.size()) == 1:
+                outputs = outputs.unsqueeze(1)
+
             loss = self.criterion(outputs, targets)
             loss_sum += loss.item() * batch_size
             N += batch_size
@@ -340,7 +340,7 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoice':
             numerical_columns=X['dataset_properties']['numerical_columns'] if 'numerical_columns' in X[
                 'dataset_properties'] else None
         )
-        self.logger.debug(f"Running on device: {get_device_from_fit_dictionary(X)}")
+        self.logger.debug(f"Running on device: {get_device_from_fit_dictionary(X)} with loss function: {self.choice.criterion}")
         total_parameter_count, trainable_parameter_count = self.count_parameters(X['network'])
         self.run_summary = RunSummary(
             total_parameter_count,
