diff --git a/mmseg/models/decode_heads/decode_head.py b/mmseg/models/decode_heads/decode_head.py index 179d871fd1..fd53afe22d 100644 --- a/mmseg/models/decode_heads/decode_head.py +++ b/mmseg/models/decode_heads/decode_head.py @@ -8,9 +8,9 @@ from mmengine.model import BaseModule from torch import Tensor +from mmseg.registry import MODELS from mmseg.structures import build_pixel_sampler from mmseg.utils import ConfigType, SampleList -from ..builder import build_loss from ..losses import accuracy from ..utils import resize @@ -140,11 +140,11 @@ def __init__(self, self.threshold = threshold if isinstance(loss_decode, dict): - self.loss_decode = build_loss(loss_decode) + self.loss_decode = MODELS.build(loss_decode) elif isinstance(loss_decode, (list, tuple)): self.loss_decode = nn.ModuleList() for loss in loss_decode: - self.loss_decode.append(build_loss(loss)) + self.loss_decode.append(MODELS.build(loss)) else: raise TypeError(f'loss_decode must be a dict or sequence of dict,\ but got {type(loss_decode)}') diff --git a/mmseg/models/decode_heads/enc_head.py b/mmseg/models/decode_heads/enc_head.py index ef48fb6995..2bba73b301 100644 --- a/mmseg/models/decode_heads/enc_head.py +++ b/mmseg/models/decode_heads/enc_head.py @@ -9,7 +9,6 @@ from mmseg.registry import MODELS from mmseg.utils import ConfigType, SampleList -from ..builder import build_loss from ..utils import Encoding, resize from .decode_head import BaseDecodeHead @@ -128,7 +127,7 @@ def __init__(self, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) if self.use_se_loss: - self.loss_se_decode = build_loss(loss_se_decode) + self.loss_se_decode = MODELS.build(loss_se_decode) self.se_layer = nn.Linear(self.channels, self.num_classes) def forward(self, inputs): diff --git a/mmseg/models/decode_heads/vpd_depth_head.py b/mmseg/models/decode_heads/vpd_depth_head.py index 0c54c2da1b..65bdfbd8d9 100644 --- a/mmseg/models/decode_heads/vpd_depth_head.py +++ b/mmseg/models/decode_heads/vpd_depth_head.py @@ -10,7 +10,6 @@ from mmseg.registry import MODELS from mmseg.utils import SampleList -from ..builder import build_loss from ..utils import resize from .decode_head import BaseDecodeHead @@ -184,11 +183,11 @@ def __init__( # build loss if isinstance(loss_decode, dict): - self.loss_decode = build_loss(loss_decode) + self.loss_decode = MODELS.build(loss_decode) elif isinstance(loss_decode, (list, tuple)): self.loss_decode = nn.ModuleList() for loss in loss_decode: - self.loss_decode.append(build_loss(loss)) + self.loss_decode.append(MODELS.build(loss)) else: raise TypeError(f'loss_decode must be a dict or sequence of dict,\ but got {type(loss_decode)}')