You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
File "C:\Users\JJH\anaconda3\envs\mmcv\lib\site-packages\mmengine\optim\optimizer\optimizer_wrapper.py", line 220, in backward
loss.backward(**kwargs)
This error occurs after I defined a custom loss function.
The loss is defined as follows:
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
from typing import Optional, Union
from .utils import get_class_weight
import torch.nn as nn
from torch import Tensor
from mmseg.registry import MODELS
from .utils import weight_reduce_loss
import torch
def ces_loss(pred: Tensor,
             target: Tensor,
             weight: Optional[Tensor] = None,
             reduction: Union[str, None] = 'none',
             avg_factor: Optional[int] = None,
             ignore_index: Optional[int] = 255,
             alpha=0.5,
             gamma=2) -> Tensor:
    """Focal-style loss built on per-pixel cross-entropy.

    For every pixel, ``logpt = -CE(pred, target)`` and ``pt = exp(logpt)``.
    The pixel loss is ``-(1 - pt) ** gamma * alpha * logpt``; the ``alpha``
    factor is skipped when ``alpha`` is ``None``. The pixel losses are then
    averaged into a scalar.

    Args:
        pred: Unnormalized logits of shape (N, C, H, W). They are resized
            bilinearly to the target's spatial size when H or W differs.
        target: Class indices of shape (N, H', W').
        weight: Optional weight. If given, it is passed (as float) together
            with the averaged loss to ``weight_reduce_loss``.
        reduction: Forwarded to ``weight_reduce_loss``; used only when
            ``weight`` is given.
        avg_factor: Forwarded to ``weight_reduce_loss``; used only when
            ``weight`` is given.
        ignore_index: Target value whose pixels contribute zero loss.
        alpha: Scale applied to the log-probability, or ``None`` to skip it.
        gamma: Focusing exponent applied to ``(1 - pt)``.

    Returns:
        The loss tensor. It is a 0-dim scalar when ``weight`` is ``None``.
    """
    _, c, h, w = pred.size()
    _, ht, wt = target.size()
    # Resize when EITHER spatial dim differs. Requiring both to differ let
    # a single-axis mismatch reach the flatten below, where the number of
    # logit rows no longer matches the number of labels.
    if h != ht or w != wt:
        pred = F.interpolate(pred, size=(ht, wt), mode="bilinear", align_corners=True)
    # (N, C, H, W) -> (N*H*W, C): one row of class logits per pixel.
    flat_logits = pred.permute(0, 2, 3, 1).reshape(-1, c)
    # reshape (not view) also accepts a non-contiguous target.
    flat_target = target.reshape(-1)
    # reduction='none' keeps one CE value per pixel. Ignored pixels get 0,
    # so their logpt is 0, pt is 1, and their focal term vanishes.
    logpt = -F.cross_entropy(flat_logits, flat_target,
                             ignore_index=ignore_index, reduction='none')
    pt = torch.exp(logpt)
    # Scale out of place. An in-place `*=` on a tensor inside the autograd
    # graph is fragile: it raises whenever an op saved that tensor for
    # backward, and here it buys nothing.
    if alpha is not None:
        logpt = logpt * alpha
    loss = -((1 - pt) ** gamma) * logpt
    # NOTE(review): mean() divides by every pixel, including ignored ones.
    loss = loss.mean()
    if weight is not None:
        # NOTE(review): the weight is applied to the already-averaged
        # scalar. A per-pixel weight map probably needs to be applied
        # before the mean; confirm against weight_reduce_loss's shape
        # requirements.
        loss = weight_reduce_loss(loss, weight.float(), reduction, avg_factor)
    return loss
@MODELS.register_module()
class MyLoss(nn.Module):
File "C:\Users\JJH\anaconda3\envs\mmcv\lib\site-packages\mmengine\optim\optimizer\optimizer_wrapper.py", line 220, in backward
loss.backward(**kwargs)
This error occurs after I defined a custom loss function.
The loss is defined as follows:
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
from typing import Optional, Union
from .utils import get_class_weight
import torch.nn as nn
from torch import Tensor
from mmseg.registry import MODELS
from .utils import weight_reduce_loss
import torch
def ces_loss(pred: Tensor,
             target: Tensor,
             weight: Optional[Tensor] = None,
             reduction: Union[str, None] = 'none',
             avg_factor: Optional[int] = None,
             ignore_index: Optional[int] = 255,
             alpha=0.5,
             gamma=2) -> Tensor:
    """Truncated copy of ces_loss: only the size check and resize are visible.

    NOTE(review): the visible body stops after the resize, so this function
    computes no loss and implicitly returns None, despite the ``-> Tensor``
    annotation. If this copy and the complete ``ces_loss`` defined earlier
    lived in one module, this later definition would shadow the complete
    one. Confirm it is a paste duplicate and remove it.
    """
    n, c, h, w = pred.size()
    nt, ht, wt = target.size()
    # NOTE(review): resizes only when BOTH H and W differ; a mismatch on a
    # single spatial axis is left unresized.
    if h != ht and w != wt:
        pred = F.interpolate(pred, size=(ht, wt), mode="bilinear", align_corners=True)
@MODELS.register_module()
class MyLoss(nn.Module):
The text was updated successfully, but these errors were encountered: