Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeCamp2023-527] Add pixel contrast cross entropy loss #3264

Open
wants to merge 20 commits into
base: dev-1.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
63 changes: 63 additions & 0 deletions configs/_base_/models/fcn_hrcontrast18.py
@@ -0,0 +1,63 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255)
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
pretrained=None,
backbone=dict(
type='HRNet',
norm_cfg=norm_cfg,
norm_eval=False,
extra=dict(
stage1=dict(
num_modules=1,
num_branches=1,
block='BOTTLENECK',
num_blocks=(4, ),
num_channels=(64, )),
stage2=dict(
num_modules=1,
num_branches=2,
block='BASIC',
num_blocks=(4, 4),
num_channels=(18, 36)),
stage3=dict(
num_modules=4,
num_branches=3,
block='BASIC',
num_blocks=(4, 4, 4),
num_channels=(18, 36, 72)),
stage4=dict(
num_modules=3,
num_branches=4,
block='BASIC',
num_blocks=(4, 4, 4, 4),
num_channels=(18, 36, 72, 144)))),
decode_head=dict(
type='HRNetContrastHead',
in_channels=[18, 36, 72, 144],
in_index=(0, 1, 2, 3),
channels=sum([18, 36, 72, 144]),
input_transform='resize_concat',
proj_n=256,
proj_mode='convmlp',
drop_p=0.1,
dropout_ratio=-1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='PixelContrastCrossEntropyLoss',
base_temperature=0.07,
temperature=0.1)),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
8 changes: 8 additions & 0 deletions configs/hrnet/fcn_hrcontrast18_4xb4-80k_vaihingen-512x512.py
@@ -0,0 +1,8 @@
_base_ = [
'../_base_/models/fcn_hrcontrast18.py', '../_base_/datasets/vaihingen.py',
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
crop_size = (512, 512)
data_preprocessor = dict(size=crop_size)
model = dict(
data_preprocessor=data_preprocessor, decode_head=dict(num_classes=6))
3 changes: 2 additions & 1 deletion mmseg/models/decode_heads/__init__.py
Expand Up @@ -14,6 +14,7 @@
from .fpn_head import FPNHead
from .gc_head import GCHead
from .ham_head import LightHamHead
from .hrnetconstrast_head import HRNetContrastHead
from .isa_head import ISAHead
from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator
from .lraspp_head import LRASPPHead
Expand Down Expand Up @@ -42,5 +43,5 @@
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead',
'LightHamHead', 'PIDHead', 'DDRHead'
'LightHamHead', 'PIDHead', 'DDRHead', 'HRNetContrastHead'
]
174 changes: 174 additions & 0 deletions mmseg/models/decode_heads/hrnetconstrast_head.py
@@ -0,0 +1,174 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Originally from https://github.com/visual-attention-network/segnext
# Licensed under the Apache License, Version 2.0 (the "License")
from typing import List, Tuple

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.utils import ConfigType, SampleList
from ..losses import accuracy
from .decode_head import BaseDecodeHead


class ProjectionHead(nn.Module):
"""The projection head used by contrast learning.

Args:
dim_in (int):
The dimensions of input features.
proj_dim (int, optional):
The output dimensions of projection head. Default: 256.
proj (str, optional): The type of projection head,
only support 'linear' and 'convmlp'. Default: 'convmlp'.
"""

def __init__(self, in_channels, out_channels=256, proj='convmlp'):
super().__init__()
if proj == 'linear':
self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1)
elif proj == 'convmlp':
self.proj = nn.Sequential(
ConvModule(in_channels, in_channels, kernel_size=1),
nn.Conv2d(in_channels, out_channels, kernel_size=1),
)
else:
raise KeyError("The type of project head only support 'linear' \
and 'convmlp', but got {}.".format(proj))

def forward(self, x):
return F.normalize(self.proj(x), p=2.0, dim=1)


class SegmentationHead(nn.Module):
"""The segmentation head used by contrast learning.

Args:
dim_in (int):
The dimensions of input features.
proj_dim (int, optional):
The output dimensions of projection head. Default: 256.
proj (str, optional):
The type of projection head,
only support 'linear' and 'convmlp'. Default: 'convmlp'.
"""

def __init__(self, in_channels, out_channels=19, drop_prob=0.1):
super().__init__()

self.seg = nn.Sequential(
ConvModule(in_channels, in_channels, kernel_size=1),
nn.Dropout2d(drop_prob),
nn.Conv2d(in_channels, out_channels, kernel_size=1),
)

def forward(self, x):
return self.seg(x)


@MODELS.register_module()
class HRNetContrastHead(BaseDecodeHead):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might rename it to ContrastHead. Since the segmentation head could be from other methods.

"""The segmentation head used by contrast learning.

Args:
drop_p (float):
The probability of dropout in segment head.
proj_n (int):
Each pixel will be projected into a vector with length of proj_n.
proj_mode (str):
The mode for project head ,'linear' or 'convmlp'.
"""

def __init__(self, drop_p=0.1, proj_n=256, proj_mode='convmlp', **kwargs):
super().__init__(**kwargs)
if proj_n <= 0:
raise KeyError('proj_n must >0')
if drop_p < 0 or drop_p > 1 or not isinstance(drop_p, float):
raise KeyError('drop_p must be a float >=0')
self.proj_n = proj_n

self.seghead = SegmentationHead(
in_channels=self.in_channels,
out_channels=self.num_classes,
drop_prob=drop_p)
self.projhead = ProjectionHead(
in_channels=self.in_channels, out_channels=proj_n, proj=proj_mode)

def forward(self, inputs):
inputs = self._transform_inputs(inputs)
output = {}
output['seg'] = self.seghead(inputs)

output['proj'] = self.projhead(inputs)

return output
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The output should be a list.


def loss_by_feat(self, seg_logits: dict,
batch_data_samples: SampleList) -> dict:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The seg_logits should be a list.

"""Compute segmentation loss.

Args:
seg_logits (dict): The output from decode head forward function.
batch_data_samples (List[:obj:`SegDataSample`]): The seg
data samples. It usually includes information such
as `metainfo` and `gt_sem_seg`.

Returns:
dict[str, Tensor]: a dictionary of loss components
"""

seg_label = self._stack_batch_gt(batch_data_samples)
loss = dict()

if self.sampler is not None:
seg_weight = self.sampler.sample(seg_logits['seg'], seg_label)
else:
seg_weight = None
seg_label = seg_label.squeeze(1)

if not isinstance(self.loss_decode, nn.ModuleList):
losses_decode = [self.loss_decode]
else:
losses_decode = self.loss_decode
for loss_decode in losses_decode:
if loss_decode.loss_name not in loss:
loss[loss_decode.loss_name] = loss_decode(
seg_logits,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
else:
loss[loss_decode.loss_name] += loss_decode(
seg_logits,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following the implementation of PaddleSeg, the input of different loss methods should be different.

https://github.com/PaddlePaddle/PaddleSeg/blob/2c8c35a8949fef74599f5ec557d340a14415f20d/paddleseg/core/train.py#L39


loss['acc_seg'] = accuracy(
F.interpolate(seg_logits['seg'], seg_label.shape[1:]),
seg_label,
ignore_index=self.ignore_index)
return loss

def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
test_cfg: ConfigType) -> Tensor:
"""Forward function for prediction.

Args:
inputs (Tuple[Tensor]): List of multi-level img features.
batch_img_metas (dict): List Image info where each dict may also
contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
'ori_shape', and 'pad_shape'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
test_cfg (dict): The testing config.

Returns:
Tensor: Outputs segmentation logits map.
"""
seg_logits = self.forward(inputs)['seg']

return self.predict_by_feat(seg_logits, batch_img_metas)
3 changes: 2 additions & 1 deletion mmseg/models/losses/__init__.py
Expand Up @@ -8,6 +8,7 @@
from .huasdorff_distance_loss import HuasdorffDisstanceLoss
from .lovasz_loss import LovaszLoss
from .ohem_cross_entropy_loss import OhemCrossEntropy
from .pixel_contrast_cross_entropy_loss import PixelContrastCrossEntropyLoss
from .tversky_loss import TverskyLoss
from .utils import reduce_loss, weight_reduce_loss, weighted_loss

Expand All @@ -16,5 +17,5 @@
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss',
'HuasdorffDisstanceLoss'
'HuasdorffDisstanceLoss', 'PixelContrastCrossEntropyLoss'
]