[v2 BUG]: Cross entropy loss is too big #845
Specifically, the problem starts here: `chemprop/chemprop/nn/predictors.py`, lines 256 to 283 in cc5b3c1, and continues here: `chemprop/chemprop/nn/predictors.py`, lines 108 to 131 in cc5b3c1. IMO it will probably be easier to fix this by abstracting the multiclass FFN into its own …
`chemprop.nn.MulticlassClassificationFFN` sets `n_tasks = n_tasks * n_classes`. This is a problem because the default task weights are then `task_weights = torch.ones(n_tasks)`. `F.cross_entropy` in `CrossEntropyLoss` will reduce the `n_tasks * n_classes` predicted values to `n_tasks` loss values, but then `L = L * self.task_weights.view(1, -1)` will broadcast the `n_tasks` loss values to `n_tasks * n_classes`, effectively multiplying the loss value by `n_classes` when we use `L.sum()`.
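A minimal standalone sketch of the shape mismatch described above, using a hypothetical single task with three classes (the tensor shapes follow the description in this issue, not chemprop's exact code):

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: one multiclass task with 3 classes, batch of 8.
n_tasks, n_classes, batch_size = 1, 3, 8

# The multiclass FFN overwrites n_tasks with n_tasks * n_classes,
# so the default task weights end up with n_tasks * n_classes entries.
task_weights = torch.ones(n_tasks * n_classes)

# Predictions: (batch, n_classes, n_tasks); targets: (batch, n_tasks).
preds = torch.randn(batch_size, n_classes, n_tasks)
targets = torch.randint(n_classes, (batch_size, n_tasks))

# F.cross_entropy reduces over the class dimension: L is (batch, n_tasks).
L = F.cross_entropy(preds, targets, reduction="none")

# (batch, n_tasks) * (1, n_tasks * n_classes) broadcasts L to
# (batch, n_tasks * n_classes), so the summed loss is n_classes times too big.
L_weighted = L * task_weights.view(1, -1)

print(L.sum().item())           # intended loss
print(L_weighted.sum().item())  # ~n_classes times larger
```

With the weights sized to the original `n_tasks` (here, `torch.ones(1)`), the element-wise product would keep the `(batch, n_tasks)` shape and the summed loss would be unchanged.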
This doesn't break anything, so I'll plan to fix it when we take another look at loss functions and metrics.