
[PyTorch][Feature Request] Label Smoothing for CrossEntropyLoss #7455

Closed · kaiyuyue opened this issue May 10, 2018 · 48 comments

Labels: enhancement (Not as big of a feature, but technically not a bug. Should be easy to fix), high priority, module: loss (Problem is related to loss function), module: nn (Related to torch.nn), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
@kaiyuyue
Contributor

kaiyuyue commented May 10, 2018

Solved 🎉

Starting from v1.10.0, torch.nn.CrossEntropyLoss() has a label_smoothing=0.0 argument (API link).
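
For anyone landing here, a minimal usage sketch (assuming PyTorch >= 1.10; the tensor shapes and smoothing value below are illustrative):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

logits = torch.randn(8, 5)            # (batch, num_classes)
targets = torch.randint(0, 5, (8,))   # hard class indices
loss = criterion(logits, targets)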


Hi, guys. The torch.LongTensor type of target hinders implementations like some of the methods in the reference. So would it be possible to add an arg label_smoothing to torch.nn.CrossEntropyLoss(), or maybe simply add docs showing how to convert the target into a one-hot vector to work with torch.nn.CrossEntropyLoss(), or any other simple way? Thanks.

cc @ezyang @gchanan @zou3519 @bdhirsh @albanD @mruberry

@whr94621

@kaiyuyue
For label_smoothing, you can look at the implementation of NJUNMT-pytorch.

In the class NMTCritierion

https://github.com/whr94621/NJUNMT-pytorch/blob/aff968c0da9273dc42eabbb8ac4e459f9195f6e4/src/modules/criterions.py#L131

@mdraw
Contributor

mdraw commented May 13, 2018

See https://discuss.pytorch.org/t/cross-entropy-with-one-hot-targets/13580/5. The cross_entropy() function that's shown there should work with smoothed labels that have the same dimension as the network outputs.

I don't think CrossEntropyLoss() should directly support a label_smoothing option, since label smoothing can be done in many different ways and the smoothing itself can be easily done manually by the user. But I agree it should at least be mentioned in the docs how to deal with targets that can't be represented by scalar values, or add support for passing (k-hot/smoothed) targets to CrossEntropyLoss.
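
For reference, a minimal soft-target cross entropy along the lines of that forum post (a sketch, not the exact code from the linked thread; soft_targets is assumed to have the same shape as the logits with rows summing to 1):

import torch
import torch.nn.functional as F

def soft_cross_entropy(logits, soft_targets):
    # Soft targets share the shape of logits and sum to 1 along the class dim.
    log_probs = F.log_softmax(logits, dim=-1)
    return -(soft_targets * log_probs).sum(dim=-1).mean()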

@zou3519 added the todo label (Not as important as medium or high priority tasks, but we will work on these.) on May 14, 2018
@Jiaming-Liu
Contributor

Maybe we need something like NonSparseCrossEntropy? (Well... it's hard to name.)

@PistonY

PistonY commented Jul 19, 2019

Here is my implementation:

import torch
import torch.nn as nn


class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # Smoothed target distribution: smoothing mass spread over the
            # non-target classes, remaining confidence on the target class.
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))

@cpuhrsch added the enhancement, high priority, triage review, and triaged labels and removed the high priority label on Jul 19, 2019
@PistonY

PistonY commented Jul 22, 2019

I agree with @mdraw.
A good choice is to do it in two steps:

  1. Use a function to get smooth label
def smooth_one_hot(true_labels: torch.Tensor, classes: int, smoothing=0.0):
    """
    if smoothing == 0, it's one-hot method
    if 0 < smoothing < 1, it's smooth method

    """
    assert 0 <= smoothing < 1
    confidence = 1.0 - smoothing
    label_shape = torch.Size((true_labels.size(0), classes))
    with torch.no_grad():
        true_dist = torch.empty(size=label_shape, device=true_labels.device)
        true_dist.fill_(smoothing / (classes - 1))
        true_dist.scatter_(1, true_labels.data.unsqueeze(1), confidence)
    return true_dist
  2. Make CrossEntropyLoss support k-hot/smoothed targets.

Then we can use it like

Loss = CrossEntropyLoss(NonSparse=True, ...)
. . .
data = ...
labels = ...

outputs = model(data)

smooth_label = smooth_one_hot(labels, ...)
loss = Loss(outputs, smooth_label)
...

By the way, I tested my implementation on ImageNet and it looks good:

| model | epochs | dtype | batch size* | gpus | lr | tricks | top1/top5 | improve |
|---|---|---|---|---|---|---|---|---|
| resnet50 | 120 | FP16 | 128 | 8 | 0.4 | - | 77.35/- | baseline |
| resnet50 | 120 | FP16 | 128 | 8 | 0.4 | Label smoothing | 77.78/93.80 | +0.43 |

@ezyang added the module: nn label and removed the triage review label on Jul 22, 2019
@ezyang
Contributor

ezyang commented Jul 22, 2019

I believe @zhangguanheng66 said that this is something he might be able to look at in the future.

@suanrong

suanrong commented Jul 29, 2019

Just use torch.nn.KLDivLoss. It is the same.


Update: it is not the same.

@sadikneipp

I believe this is similar to what the new Snorkel lib implemented:
https://snorkel.readthedocs.io/en/master/packages/_autosummary/classification/snorkel.classification.cross_entropy_with_probs.html

Just some extra info on how people are working around the issue.

@Data-drone

See https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/RN50v1.5 for how NVIDIA does it; that might help.

@steermomo

steermomo commented Oct 18, 2019

@suanrong Thanks a lot.

====
And maybe this is helpful for others who read this issue

Note that cross-entropy for non 0/1 labels is not symmetric, which could be an explanation for the poor performance.
https://discuss.pytorch.org/t/cross-entropy-for-soft-label/16093/2

@huanglianghua

Suggested implementation:

import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothLoss(nn.Module):

    def __init__(self, smoothing=0.0):
        super(LabelSmoothLoss, self).__init__()
        self.smoothing = smoothing

    def forward(self, input, target):
        log_prob = F.log_softmax(input, dim=-1)
        # Smoothed target distribution: smoothing mass on the non-target
        # classes, (1 - smoothing) on the target class.
        weight = input.new_ones(input.size()) * \
            self.smoothing / (input.size(-1) - 1.)
        weight.scatter_(-1, target.unsqueeze(-1), (1. - self.smoothing))
        loss = (-weight * log_prob).sum(dim=-1).mean()
        return loss

I have checked that:
(1) When smoothing=0.0, the output is the same as nn.CrossEntropyLoss within precision 1e-5.
(2) When smoothing>0.0, the sums of weights over different classes weight.sum(dim=-1) are always 1.

@hadaev8

hadaev8 commented Mar 9, 2020

The implementations here lack a class weights feature. :(

@alshahrani2030

Just use torch.nn.KLDivLoss. It is the same.

Can you please elaborate more?

@suanrong

Just use torch.nn.KLDivLoss. It is the same.

Can you please elaborate more?

Assuming you already have the smoothed labels, you can just use torch.nn.KLDivLoss, since the difference between the two losses is the entropy of the label distribution, which is a constant.
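
A sketch of what this looks like in practice (note the later discussion in this thread: the equivalence only holds up to a constant, and KLDivLoss must be fed log-probabilities, not raw logits; shapes and the smoothing value are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

kl = nn.KLDivLoss(reduction='batchmean')

logits = torch.randn(8, 5)
smoothed = torch.full((8, 5), 0.1 / 4)
smoothed.scatter_(1, torch.randint(0, 5, (8, 1)), 0.9)

# Input must be log-probabilities, target must be probabilities.
loss = kl(F.log_softmax(logits, dim=-1), smoothed)
# This differs from the soft-target cross entropy only by the (constant) entropy of `smoothed`.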

@jasstionzyf

jasstionzyf commented May 21, 2020

@PistonY why not do it this way? It's much simpler:

# Assumes `labels` is already a one-hot float tensor of shape (batch, classNum).
with torch.no_grad():
    confidence = 1.0 - smoothing_factor
    true_dist = torch.mul(labels, confidence)
    true_dist = torch.add(true_dist, smoothing_factor / (classNum - 1))
    print(true_dist)
return true_dist

@skull3r7

Implementations here lack of class weights feature.

Can I multiply the class weights into the smoothed label tensor?

@jphdotam

jphdotam commented Sep 15, 2020

def smooth_one_hot(true_labels: torch.Tensor, classes: int, smoothing=0.0):
    """
    if smoothing == 0, it's one-hot method
    if 0 < smoothing < 1, it's smooth method
    """
    assert 0 <= smoothing < 1
    confidence = 1.0 - smoothing
    label_shape = torch.Size((true_labels.size(0), classes))
    with torch.no_grad():
        true_dist = torch.empty(size=label_shape, device=true_labels.device)
        true_dist.fill_(smoothing / (classes - 1))
        true_dist.scatter_(1, true_labels.data.unsqueeze(1), confidence)
    return true_dist

The problem with this implementation is that it's very sensitive to the number of classes.

When n_classes is 2, any smoothing above 0.5 will reverse the labels, which I'm sure the person does not want; when n_classes is 3 it's any smoothing above 2/3, and 0.75 for 4 classes. So maybe:

assert 0 <= smoothing < (classes-1)/classes would catch this issue, but I feel the smoothing needs to take the number of classes into account?
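
A tiny illustration of the reversal with hypothetical numbers (2 classes, smoothing = 0.6):

import torch

smoothing, classes = 0.6, 2
target = torch.tensor([0])
dist = torch.full((1, classes), smoothing / (classes - 1))
dist.scatter_(1, target.unsqueeze(1), 1.0 - smoothing)
print(dist)  # tensor([[0.4000, 0.6000]]) -- the off-target class now has the larger weight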

@PistonY

PistonY commented Sep 15, 2020

def smooth_one_hot(true_labels: torch.Tensor, classes: int, smoothing=0.0):
    """
    if smoothing == 0, it's one-hot method
    if 0 < smoothing < 1, it's smooth method
    """
    assert 0 <= smoothing < 1
    confidence = 1.0 - smoothing
    label_shape = torch.Size((true_labels.size(0), classes))
    with torch.no_grad():
        true_dist = torch.empty(size=label_shape, device=true_labels.device)
        true_dist.fill_(smoothing / (classes - 1))
        true_dist.scatter_(1, true_labels.data.unsqueeze(1), confidence)
    return true_dist

The problem with this implementation is that it's very sensitive to the number of classes.

When n_classes is 2, any smoothing above 0.5 will reverse the labels, which I'm sure the person does not want; when n_classes is 3 it's any smoothing above 2/3, and 0.75 for 4 classes. So maybe:

assert 0 <= smoothing < (classes-1)/classes would catch this issue, but I feel the smoothing needs to take the number of classes into account?

It's a wise idea I think.

@antrec

antrec commented Oct 28, 2020

Thanks for the discussion. There are a few points that remain unclear and look like mistakes to me:

  • the weight tensor in @PistonY 's implementation
  • the equivalence between KL divergence and label-smoothing (@suanrong )

About the weights:

The label smoothing paper states y_k = smoothing / n_classes + (1 - smoothing) * y_{one hot}. So the value of the weight is smoothing / n_classes for indices other than the target, and it is smoothing / n_classes + (1 - smoothing) for the target class. Yet in @PistonY 's implementation, the function torch.scatter_ overwrites the value for the target to (1 - smoothing) (and the constant term disappears).
Moreover, I do not really understand why we use n_classes - 1 in the computation (?)
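
For concreteness, the two conventions side by side (a sketch with hypothetical numbers, n_classes = 4, smoothing = 0.1):

import torch

smoothing, n = 0.1, 4
target = torch.tensor([0])

# Paper convention: mix with the uniform distribution over ALL classes.
paper = torch.full((1, n), smoothing / n)
paper.scatter_add_(1, target.unsqueeze(1), torch.tensor([[1.0 - smoothing]]))
# -> [[0.9250, 0.0250, 0.0250, 0.0250]]

# scatter_-based convention: spread the smoothing over the non-target classes only.
scattered = torch.full((1, n), smoothing / (n - 1))
scattered.scatter_(1, target.unsqueeze(1), 1.0 - smoothing)
# -> [[0.9000, 0.0333, 0.0333, 0.0333]]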

About the equivalence between KL divergence and label-smoothing:

The label-smoothing cross-entropy loss reads, with y the weights mentioned above,

LS(x, y) = - sum_k {y[k] * log-prob(x)}
         = - sum_k {y[k] * log(exp(x[k]) / (sum_j exp(x[j])))}
         = - sum_k {y[k] * (x[k] - log-sum-exp(x))}
         = - sum_k {y[k] * x[k]} + log-sum-exp(x)

where the third to the fourth line uses the fact that sum_k y[k] = smoothing / n_classes * n_classes + (1 - smoothing) = 1.

The KL-divergence loss reads,

KL(x, y) = - sum_k {y[k] * x[k] - y[k] * log(y[k])}
         = - sum_k {y[k] * x[k]} + sum_k {y[k] * log(y[k])}
         = - sum_k {y[k] * x[k]} - Const.

So in the end we have LS(x, y) = KL(x, y) + log-sum-exp(x) + Const., where Const. is the constant term corresponding to the entropy of y, which is indeed constant in multiclass settings. But what about the log-sum-exp term ?

I did a few computations using a custom cross entropy function accepting soft targets, and it shows that it is indeed equal to the KLDiv loss plus log-sum-exp, up to the constant term corresponding to the entropy of y. Is there any assumption on the logits that makes it reasonable to drop this term?
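
For concreteness, a small numeric check of the relationship (a sketch; it verifies that the soft-target cross entropy equals nn.KLDivLoss applied to log-probabilities plus the entropy of y, which is the resolution given later in this thread):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 5)
y = torch.full((4, 5), 0.1 / 4)
y.scatter_(1, torch.randint(0, 5, (4, 1)), 0.9)

soft_ce = -(y * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
kl = F.kl_div(F.log_softmax(logits, dim=-1), y, reduction='batchmean')
entropy = -(y * y.log()).sum(dim=-1).mean()

print(torch.allclose(soft_ce, kl + entropy))  # expected: True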

Thanks a lot for the clarifications.
Cheers !

@suanrong

Thanks @antrec !

You are right. I ignored the logsoftmax function and made a mistake.

@zou3519
Contributor

zou3519 commented Jan 22, 2021

Btw, this issue seems really similar to another popular issue: #11959. Maybe we can use the details of that to inform which option we wanted to take here? @jbschlosser

@garyhlai

garyhlai commented Feb 12, 2021

criterion = nn.CrossEntropyLossWithProbs()
...
loss = criterion(output, F.smooth_labels(target, eps=0.1))

This code example you provided @jbschlosser would fall under option 1, right? That seems like the option most aligned with the single-responsibility principle. Given the prediction and target, CrossEntropyLossProbs() would output the loss and that's it - it doesn't smooth/change the target inside it.

The free-standing function F.smooth_labels also sounds sweet; however, is there any reason why we can't extend the existing nn.CrossEntropyLoss to support smoothed vectors as #11959 mentioned (was wondering what the current status is on that since someone last made a comment in May 2020 but received no response @zou3519)? Why do we need a new nn.CrossEntropyLossWithProbs class?

@hadaev8

hadaev8 commented Feb 25, 2021

Since we're here, should gamma from focal loss be added?

@lukasfolle

Any progress on this issue?

@jbschlosser
Contributor

This code example you provided @jbschlosser would fall under option 1, right? That seems like the option most aligned with the single-responsibility principle. Given the prediction and target, CrossEntropyLossProbs() would output the loss and that's it - it doesn't smooth/change the target inside it.

Agreed, it's nice and composable :)

The free-standing function F.smooth_labels also sounds sweet; however, is there any reason why we can't extend the existing nn.CrossEntropyLoss to support smoothed vectors as #11959 mentioned (was wondering what the current status is on that since someone last made a comment in May 2020 but received no response @zou3519)? Why do we need a new nn.CrossEntropyLossWithProbs class?

It's possible to support soft labels directly in nn.CrossEntropyLoss / nn.NLLLoss. Note that each of these would have to support 3 target types: hard labels of shape N, hard labels of shape (N, d1, d2, ..., dk), and soft labels of shape (N, C). The (N, C) and (N, d1) for K=1 cases could be disambiguated only by checking the target dtype. Switching behavior based solely on the dtype of the target would not follow PyTorch precedent and could easily result in unexpected behavior / confusion. Example: accidentally passing the same labels as floating-point instead of integral could invisibly cause a non-negligible performance drop. This sort of specialization also wouldn't play nice with FX.

To get around this, one suggestion from @zou3519 is to add a soft flag (default False) to nn.CrossEntropyLoss / nn.NLLLoss. Setting soft=True would explicitly indicate that soft labels are desired, addressing the above issues without needing e.g. a new nn.CrossEntropyLossWithProbs class.

@thomasjpfan thomasjpfan added this to Needs Triage in torch.nn via automation Jun 2, 2021
@thomasjpfan thomasjpfan moved this from Needs Triage to In Discussion in torch.nn Jun 2, 2021
@itsnamgyu

Here is my implementation:

class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # true_dist = pred.data.clone()
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))

Huge thanks to @PistonY for the example code. Just one tidbit: in label smoothing, epsilon is canonically distributed among all classes, not just the non-target classes. Here is a minor fix that follows the canonical version:

class LabelSmoothingLossCanonical(nn.Module):
    def __init__(self, smoothing=0.0, dim=-1):
        super(LabelSmoothingLossCanonical, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.dim = dim

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # true_dist = pred.data.clone()
            true_dist = torch.zeros_like(pred)
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
            true_dist += self.smoothing / pred.size(self.dim)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))

Test Code

pred = torch.tensor([
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.7, 0.1, 0.1],
    [0.1, 0.1, 0.7, 0.1],
    [0.1, 0.1, 0.1, 0.7],
])
target = torch.Tensor([
    0, 1, 2, 3
]).long()

loss = LabelSmoothingLoss(classes=4, smoothing=0.1)
print(loss(pred, target))

loss = LabelSmoothingLossCanonical(smoothing=0.1)
print(loss(pred, target))

Test Output

tensor(1.0332)
tensor(1.0182)

@PistonY

PistonY commented Jun 28, 2021

@itsnamgyu Thanks for fixing my version; I'm glad to test it if possible.
@antrec Sorry for the late response, I've been quite busy. I checked the official TF implementation and I did indeed miss something here.

@aiot-tech

aiot-tech commented Jul 8, 2021

amazing! Why not write it out!

@jbschlosser
Contributor

Update: I propose a 2-part solution to this issue to get both performance and flexibility:

  1. Support soft labels for cross-entropy loss (see [feature request] Support soft target distribution in cross entropy loss #11959) - allows for arbitrary label smoothing techniques like mixup and cutmix
  2. Support label_smoothing=0.0 arg in current CrossEntropyLoss - provides performant canonical label smoothing in terms of existing loss as done in [PyTorch][Feature Request] Label Smoothing for CrossEntropyLoss #7455 (comment)

Part 1 is in progress; check out the discussion at #11959 (comment) if interested.
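
A brief sketch of how both parts look with a recent PyTorch (>= 1.10), where cross entropy accepts either class indices or class probabilities as the target (shapes and the smoothing value are illustrative):

import torch
import torch.nn.functional as F

logits = torch.randn(8, 5)
hard_targets = torch.randint(0, 5, (8,))

# Part 2: canonical label smoothing built into the existing loss.
loss_a = F.cross_entropy(logits, hard_targets, label_smoothing=0.1)

# Part 1: arbitrary soft targets (e.g. from mixup/cutmix) passed directly.
soft_targets = torch.full((8, 5), 0.1 / 5)
soft_targets.scatter_add_(1, hard_targets.unsqueeze(1), torch.full((8, 1), 0.9))
loss_b = F.cross_entropy(logits, soft_targets)
# loss_a and loss_b should match, since label_smoothing=0.1 produces this exact target mix.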

@vigi30

vigi30 commented Aug 5, 2021

Added an ignore_index to the solution suggested by [wangleioffical] in #7455 (comment):

import torch.nn as nn
import torch.nn.functional as F


def linear_combination(x, y, epsilon):
    return epsilon * x + (1 - epsilon) * y


def reduce_loss(loss, reduction='mean'):
    return loss.mean() if reduction == 'mean' else loss.sum() if reduction == 'sum' else loss


# Implementation of label smoothing with CrossEntropy and ignore_index
class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, epsilon: float = 0.1, reduction='mean', ignore_index=-100):
        super().__init__()
        self.epsilon = epsilon
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, preds, target):
        n = preds.size()[-1]
        log_preds = F.log_softmax(preds, dim=-1)
        loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
        nll = F.nll_loss(log_preds, target, reduction=self.reduction, ignore_index=self.ignore_index)
        return linear_combination(loss / n, nll, self.epsilon)


# Implementation of label smoothing with NLLLoss and ignore_index
class LabelSmoothingNLLLoss(nn.Module):
    def __init__(self, epsilon: float = 0.1, reduction='mean', ignore_index=-100):
        super().__init__()
        self.epsilon = epsilon
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, preds, target):
        n = preds.size()[-1]
        loss = reduce_loss(-preds.sum(dim=-1), self.reduction)
        nll = F.nll_loss(preds, target, reduction=self.reduction, ignore_index=self.ignore_index)
        return linear_combination(loss / n, nll, self.epsilon)

@QuanticDisaster

Hello, I see that this issue is closed, but I can't find it in the torch documentation. The CrossEntropyLoss I want to use with weights keeps saying that it can't accept non-integer values. Could someone point me to the soft-label equivalent in the documentation?

@seyeeet

seyeeet commented Sep 23, 2021

@itsnamgyu what does Canonical mean in your reply?

@itsnamgyu

@itsnamgyu what does Canonical mean in your reply?

The canonical version follows the exact details used in the original paper, InceptionV2: https://arxiv.org/pdf/1512.00567.pdf

@seyeeet

seyeeet commented Oct 6, 2021

would it be possible to have label smoothing CE loss for segmentation task?

@jbschlosser
Contributor

@QuanticDisaster Label smoothing is available for nn.CrossEntropyLoss in PyTorch 1.10 (see docs here).

@jbschlosser
Contributor

would it be possible to have label smoothing CE loss for segmentation task?

Hey @seyeeet, great question - do you mind opening a separate issue for discussion to help us determine the popularity / impact of that request?

@arisliang

arisliang commented Dec 3, 2021

Thanks for all the good ideas and pointers. What is the conclusion for this request, since it is closed?

@seyeeet

seyeeet commented Dec 3, 2021

@arisliang #67863
