[PyTorch][Feature Request] Label Smoothing for CrossEntropyLoss #7455
@kaiyuyue In the class …

See https://discuss.pytorch.org/t/cross-entropy-with-one-hot-targets/13580/5. … I don't think …

Maybe we need something like …
Here is my implementation:

```python
import torch
import torch.nn as nn

class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # Smoothed target: smoothing mass split over the non-target
            # classes, confidence = 1 - smoothing on the target class.
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
```
I agree with @mdraw:

```python
import torch

def smooth_one_hot(true_labels: torch.Tensor, classes: int, smoothing=0.0):
    """
    If smoothing == 0, this is the plain one-hot encoding;
    if 0 < smoothing < 1, the remaining mass is spread over the other classes.
    """
    assert 0 <= smoothing < 1
    confidence = 1.0 - smoothing
    label_shape = torch.Size((true_labels.size(0), classes))
    with torch.no_grad():
        true_dist = torch.empty(size=label_shape, device=true_labels.device)
        true_dist.fill_(smoothing / (classes - 1))
        true_dist.scatter_(1, true_labels.data.unsqueeze(1), confidence)
    return true_dist
```
Then we can use it like:

```python
Loss = CrossEntropyLoss(NonSparse=True, ...)  # hypothetical non-sparse variant
...
data = ...
labels = ...

outputs = model(data)

smooth_label = smooth_one_hot(labels, ...)
loss = Loss(outputs, smooth_label)
...
```

By the way, I tested my implementation on ImageNet and it looks good.
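A runnable version of that sketch with a hand-written soft-target cross entropy (a minimal sketch; the shapes and the smoothing value are placeholders):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 10)           # stand-in for outputs = model(data)
labels = torch.randint(0, 10, (8,))   # stand-in for the integer class labels

smooth_label = smooth_one_hot(labels, classes=10, smoothing=0.1)
# Soft-target cross entropy: -sum_k q_k * log p_k, averaged over the batch.
loss = torch.mean(torch.sum(-smooth_label * F.log_softmax(logits, dim=-1), dim=-1))
```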
I believe @zhangguanheng66 said that this is something he might be able to look at in the future.
Just use torch.nn.KLDivLoss. It is the same. Update: it is not the same.
I believe this is similar to what the new Snorkel lib implemented. Just some extra info on how people are working around the issue.
See https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/RN50v1.5 for how NVIDIA does it; that might help.
@suanrong Thanks a lot.
Suggested implementation:

```python
import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothLoss(nn.Module):
    def __init__(self, smoothing=0.0):
        super(LabelSmoothLoss, self).__init__()
        self.smoothing = smoothing

    def forward(self, input, target):
        log_prob = F.log_softmax(input, dim=-1)
        # Start from smoothing / (C - 1) everywhere, then place
        # 1 - smoothing on the target class.
        weight = input.new_ones(input.size()) * \
            self.smoothing / (input.size(-1) - 1.)
        weight.scatter_(-1, target.unsqueeze(-1), (1. - self.smoothing))
        loss = (-weight * log_prob).sum(dim=-1).mean()
        return loss
```

I have checked that: …
The implementations here lack a class-weights feature.
Can you please elaborate more?
Assuming you already have the smoothed label, you can just use torch.nn.KLDivLoss, since the difference between the two losses is the entropy of the label, which is a constant.
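A sketch of that, reusing the smooth_one_hot helper from earlier (the shapes and smoothing value are arbitrary):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 5, requires_grad=True)
labels = torch.randint(0, 5, (8,))
smooth = smooth_one_hot(labels, classes=5, smoothing=0.1)

# KLDivLoss expects log-probabilities as input and probabilities as target;
# 'batchmean' averages the per-sample KL over the batch.
loss = F.kl_div(F.log_softmax(logits, dim=-1), smooth, reduction='batchmean')
loss.backward()  # gradients match those of the soft-target cross entropy
```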
@PistonY why not do it this way? Much simpler: …
Can I multiply the class weights onto the smoothed label tensor?
The problem with this implementation is that it's very sensitive to the number of classes. When n_classes is 2, any smoothing above 0.5 will reverse the labels, which I'm sure the person does not want; when n_classes is 3 it's any smoothing above 2/3, and 0.75 for 4 classes. So maybe something class-count-aware is needed (see the sketch below):
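The threshold follows from smoothing / (n_classes - 1) >= 1 - smoothing, which flips exactly at smoothing = (n_classes - 1) / n_classes. A minimal guard along those lines (the helper name is hypothetical, not the original commenter's code):

```python
def max_valid_smoothing(n_classes: int) -> float:
    # Beyond this value, each non-target class receives more probability
    # mass than the target class itself, reversing the labels.
    return (n_classes - 1) / n_classes

assert max_valid_smoothing(2) == 0.5
assert abs(max_valid_smoothing(3) - 2 / 3) < 1e-9
assert max_valid_smoothing(4) == 0.75
```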
It's a wise idea, I think.
Thanks for the discussion. There are a few points that remain unclear and look like mistakes to me.

About the weights: the label smoothing paper states that the smoothed target is $q'(k) = (1-\epsilon)\,\delta_{k,y} + \epsilon\, u(k)$, where $u$ is a fixed distribution over the $K$ classes (uniform, $u(k) = 1/K$, in the paper).

About the equivalence between KL divergence and label smoothing: with $p = \mathrm{softmax}(z)$ the model's predicted distribution, the label-smoothing cross-entropy loss reads

$$\mathrm{CE}(q', p) = -\sum_{k=1}^{K} q'_k \log p_k,$$

while the KL-divergence loss reads

$$\mathrm{KL}(q' \,\|\, p) = \sum_{k=1}^{K} q'_k \log \frac{q'_k}{p_k} = -\sum_{k=1}^{K} q'_k \log p_k + \sum_{k=1}^{K} q'_k \log q'_k = \mathrm{CE}(q', p) - H(q'),$$

where $H(q') = -\sum_k q'_k \log q'_k$ is the entropy of the smoothed label, a constant with respect to the model parameters. So in the end we have $\mathrm{CE}(q', p) = \mathrm{KL}(q' \,\|\, p) + H(q')$: the two losses differ by a constant and yield identical gradients. I did a few computations using a custom cross-entropy function accepting soft targets, and it shows that it is indeed equal to the KL-divergence loss up to this constant. Thanks a lot for the clarifications.
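A quick numeric check of this identity (a sketch; the shapes, seed, and 0.1 smoothing are arbitrary choices):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4, 1))

# Smoothed targets: strictly positive, each row sums to 1.
q = torch.full((4, 10), 0.1 / 9)
q.scatter_(1, target, 0.9)

log_p = F.log_softmax(logits, dim=-1)
ce = -(q * log_p).sum(dim=-1).mean()            # soft-target cross entropy
kl = F.kl_div(log_p, q, reduction='batchmean')  # KL-divergence loss
entropy = -(q * q.log()).sum(dim=-1).mean()     # H(q'), constant w.r.t. the model
print(torch.allclose(ce, kl + entropy))         # True
```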
Thanks @antrec! You are right. I overlooked the log-softmax function and made a mistake.
Btw, this issue seems really similar to another popular issue: #11959. Maybe we can use the details of that one to inform which option we want to take here? @jbschlosser
This code example you provided @jbschlosser would fall under option 1, right? That seems like the option most aligned with the single-responsibility principle. Given the prediction and target, CrossEntropyLossProbs() would output the loss and that's it; it doesn't smooth/change the target inside it. The free-standing function …
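As a rough sketch of what that option-1 shape could look like (the free-standing name cross_entropy_with_probs is illustrative, not an actual PyTorch API):

```python
import torch
import torch.nn.functional as F

def cross_entropy_with_probs(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Cross entropy where each target row is a full probability distribution."""
    return -(target * F.log_softmax(input, dim=-1)).sum(dim=-1).mean()

# Smoothing stays outside the loss, keeping the two responsibilities separate:
# smooth_target = smooth_one_hot(labels, classes=10, smoothing=0.1)
# loss = cross_entropy_with_probs(logits, smooth_target)
```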
Since we're here, should gamma from focal loss be added?
Any progress on this issue?
Agreed, it's nice and composable :)

It's possible to support soft labels directly in … To get around this, one suggestion from @zou3519 is to add a …
Huge thanks to @PistonY for the example code. Just one tidbit: in label smoothing, epsilon is canonically distributed among all classes, not just the non-target classes. Here is a minor fix that follows the canonical version:

```python
import torch
import torch.nn as nn

class LabelSmoothingLossCanonical(nn.Module):
    def __init__(self, smoothing=0.0, dim=-1):
        super(LabelSmoothingLossCanonical, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.dim = dim

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            true_dist = torch.zeros_like(pred)
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
            # Spread the smoothing mass over *all* classes, target included.
            true_dist += self.smoothing / pred.size(self.dim)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
```

Test code:

```python
pred = torch.tensor([
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.7, 0.1, 0.1],
    [0.1, 0.1, 0.7, 0.1],
    [0.1, 0.1, 0.1, 0.7],
])
target = torch.tensor([0, 1, 2, 3]).long()

loss = LabelSmoothingLoss(classes=4, smoothing=0.1)
print(loss(pred, target))
loss = LabelSmoothingLossCanonical(smoothing=0.1)
print(loss(pred, target))
```

Test output:

```
tensor(1.0332)
tensor(1.0182)
```
@itsnamgyu Thanks for fixing my version; I'd be glad to test it if possible.
Amazing! Why not write it out!
Update: I propose a 2-part solution to this issue to get both performance and flexibility. Part 1 is in progress; check out the discussion at #11959 (comment) if interested.
Added an ignore_index to the solution suggested by wangleioffical in #7455 (comment); a sketch follows below.
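A minimal sketch of such a variant, building on the LabelSmoothLoss above (the class name and masking details are assumptions, not necessarily the original code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothLossIgnore(nn.Module):
    def __init__(self, smoothing=0.0, ignore_index=-100):
        super().__init__()
        self.smoothing = smoothing
        self.ignore_index = ignore_index

    def forward(self, input, target):
        log_prob = F.log_softmax(input, dim=-1)
        weight = input.new_ones(input.size()) * \
            self.smoothing / (input.size(-1) - 1.)
        ignore_mask = target == self.ignore_index
        safe_target = target.masked_fill(ignore_mask, 0)  # any valid class index
        weight.scatter_(-1, safe_target.unsqueeze(-1), 1. - self.smoothing)
        loss = (-weight * log_prob).sum(dim=-1)
        # Zero out ignored positions and average over the remaining ones.
        loss = loss.masked_fill(ignore_mask, 0.)
        return loss.sum() / (~ignore_mask).sum().clamp(min=1)
```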
Hello, I see that this issue is closed, but I can't find it in the torch documentation. The CrossEntropyLoss I want to use with weights keeps saying that it can't accept non-integer values. Could someone point me to the soft-label equivalent in the documentation?
@itsnamgyu what does "canonical" refer to here?
The canonical version follows the exact details used in the original paper, InceptionV2: https://arxiv.org/pdf/1512.00567.pdf
Would it be possible to have a label-smoothing CE loss for segmentation tasks?
@QuanticDisaster Label smoothing is available for torch.nn.CrossEntropyLoss via the label_smoothing argument since v1.10.0.
Hey @seyeeet, great question - do you mind opening a separate issue for discussion to help us determine the popularity / impact of that request?
Thanks for a lot of good ideas and pointers. What is the conclusion for this request, since it is closed?
Solved 🎉 Starting from v1.10.0, torch.nn.CrossEntropyLoss has an arg label_smoothing=0.0 (API link).

Hi, guys. The torch.LongTensor type of target will hinder implementations like some methods in the reference. So is it possible to add an arg label_smoothing for torch.nn.CrossEntropyLoss(), or maybe simply add docs showing how to convert the target into a one-hot vector to work with torch.nn.CrossEntropyLoss(), or any other simple way? Thanks.

cc @ezyang @gchanan @zou3519 @bdhirsh @albanD @mruberry
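A minimal usage example of the built-in argument (the shapes and the 0.1 value are arbitrary):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

logits = torch.randn(8, 5)
target = torch.randint(0, 5, (8,))
loss = criterion(logits, target)  # smoothing is applied internally
```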