New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CodeCamp2023-527] Add pixel contrast cross entropy loss #3264
base: dev-1.x
Are you sure you want to change the base?
Changes from 6 commits
15f31ae
8eb061f
77661b1
0b363ab
4e0c5d4
cbbd55e
c326b40
0a7e01d
3bff55b
ea15ffa
49ba0b2
39d8da3
c14b152
d7a4bd3
1fe9521
cdc31bc
58c4881
52f266f
73d18de
2304c6d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# model settings
# Normalization layer config shared by the backbone and the decode head.
norm_cfg = dict(type='SyncBN', requires_grad=True)
# Input preprocessing: per-channel mean/std normalization, BGR->RGB
# conversion, and the padding values for images (0) and for segmentation
# maps (255).
data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255)
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    pretrained=None,
    # HRNet backbone with four stages; the last stage exposes four branches
    # with 18, 36, 72 and 144 channels.
    backbone=dict(
        type='HRNet',
        norm_cfg=norm_cfg,
        norm_eval=False,
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(18, 36)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(18, 36, 72)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(18, 36, 72, 144)))),
    # Contrast decode head: consumes all four backbone branches.
    decode_head=dict(
        type='HRNetContrastHead',
        in_channels=[18, 36, 72, 144],
        in_index=(0, 1, 2, 3),
        # Equals the sum of in_channels; presumably because 'resize_concat'
        # concatenates the selected branches along channels - TODO confirm.
        channels=sum([18, 36, 72, 144]),
        input_transform='resize_concat',
        # Length of the per-pixel projection vector.
        proj_n=256,
        proj_mode='convmlp',
        # Dropout probability used inside the head's segmentation branch.
        drop_p=0.1,
        # NOTE(review): presumably disables the base head's own dropout
        # layer so that only drop_p applies - confirm against the base class.
        dropout_ratio=-1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='PixelContrastCrossEntropyLoss',
            base_temperature=0.07,
            temperature=0.1)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Base configs this config is composed from: the HRNet contrast model, the
# Vaihingen dataset, the default runtime and the 80k schedule.
_base_ = [
    '../_base_/models/fcn_hrcontrast18.py', '../_base_/datasets/vaihingen.py',
    # NOTE(review): Might test it with the Cityscapes dataset.
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
# The same size is handed to the data preprocessor below.
crop_size = (512, 512)
data_preprocessor = dict(size=crop_size)
# Override the preprocessor size and use 6 classes for the decode head.
model = dict(
    data_preprocessor=data_preprocessor, decode_head=dict(num_classes=6))
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
# Originally from https://github.com/visual-attention-network/segnext | ||
# Licensed under the Apache License, Version 2.0 (the "License") | ||
from typing import List, Tuple | ||
|
||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from mmcv.cnn import ConvModule | ||
from torch import Tensor | ||
|
||
from mmseg.registry import MODELS | ||
from mmseg.utils import ConfigType, SampleList | ||
from ..losses import accuracy | ||
from .decode_head import BaseDecodeHead | ||
|
||
|
||
class ProjectionHead(nn.Module):
    """The projection head used by contrast learning.

    Projects a feature map to ``out_channels`` channels with 1x1
    convolutions and L2-normalizes the result along the channel dimension.

    Args:
        in_channels (int): The number of channels of the input features.
        out_channels (int, optional): The number of channels of the
            projected output. Default: 256.
        proj (str, optional): The type of projection head,
            only support 'linear' and 'convmlp'. Default: 'convmlp'.

    Raises:
        KeyError: If ``proj`` is neither 'linear' nor 'convmlp'.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int = 256,
                 proj: str = 'convmlp') -> None:
        super().__init__()
        if proj == 'linear':
            # A single 1x1 convolution.
            self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        elif proj == 'convmlp':
            # A channel-preserving 1x1 ConvModule followed by the 1x1
            # projection convolution.
            self.proj = nn.Sequential(
                ConvModule(in_channels, in_channels, kernel_size=1),
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
            )
        else:
            # Adjacent string literals keep the message free of the source
            # indentation that a backslash continuation would embed.
            raise KeyError(
                "The type of projection head only supports 'linear' and "
                f"'convmlp', but got {proj}.")

    def forward(self, x: Tensor) -> Tensor:
        """Project ``x`` and L2-normalize it along dimension 1."""
        return F.normalize(self.proj(x), p=2.0, dim=1)
|
||
|
||
class SegmentationHead(nn.Module):
    """The segmentation head used by contrast learning.

    Applies a channel-preserving 1x1 ``ConvModule``, 2D dropout and a final
    1x1 convolution that maps the features to ``out_channels`` channels.

    Args:
        in_channels (int): The number of channels of the input features.
        out_channels (int, optional): The number of output channels.
            Default: 19.
        drop_prob (float, optional): The probability passed to
            ``nn.Dropout2d``. Default: 0.1.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int = 19,
                 drop_prob: float = 0.1) -> None:
        super().__init__()

        self.seg = nn.Sequential(
            ConvModule(in_channels, in_channels, kernel_size=1),
            nn.Dropout2d(drop_prob),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return the output of the segmentation branch for ``x``."""
        return self.seg(x)
|
||
|
||
@MODELS.register_module() | ||
class HRNetContrastHead(BaseDecodeHead): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might rename it to |
||
"""The segmentation head used by contrast learning. | ||
|
||
Args: | ||
drop_p (float): | ||
The probability of dropout in segment head. | ||
proj_n (int): | ||
Each pixel will be projected into a vector with length of proj_n. | ||
proj_mode (str): | ||
The mode of the projection head, 'linear' or 'convmlp'. | ||
""" | ||
|
||
def __init__(self, drop_p=0.1, proj_n=256, proj_mode='convmlp', **kwargs):
    """Build the segmentation branch and the projection branch.

    Args:
        drop_p (float): The probability of dropout in the segmentation
            branch, must lie in [0, 1]. Default: 0.1.
        proj_n (int): Each pixel will be projected into a vector with
            length of proj_n, must be positive. Default: 256.
        proj_mode (str): The mode of the projection head, 'linear' or
            'convmlp'. Default: 'convmlp'.
        **kwargs: Forwarded unchanged to the base class constructor.

    Raises:
        KeyError: If ``proj_n`` is not positive or ``drop_p`` is not a
            real number in [0, 1]. The exception type is kept for
            backward compatibility.
    """
    super().__init__(**kwargs)
    if proj_n <= 0:
        raise KeyError(f'proj_n must be > 0, but got {proj_n}.')
    # Accept ints as well so the boundary values 0 and 1 are not rejected
    # merely for not being written as 0.0 / 1.0.
    if not isinstance(drop_p, (int, float)) or not 0 <= drop_p <= 1:
        raise KeyError(
            f'drop_p must be a number in [0, 1], but got {drop_p!r}.')
    self.proj_n = proj_n

    # NOTE(review): self.in_channels is set by the base class; presumably
    # it is the channel count after the input transform - confirm.
    self.seghead = SegmentationHead(
        in_channels=self.in_channels,
        out_channels=self.num_classes,
        drop_prob=drop_p)
    self.projhead = ProjectionHead(
        in_channels=self.in_channels, out_channels=proj_n, proj=proj_mode)
|
||
def forward(self, inputs):
    """Run both branches of the head on the transformed inputs.

    Args:
        inputs: The multi-level features handed to
            ``self._transform_inputs``.

    Returns:
        dict: ``'seg'`` holds the output of the segmentation branch and
        ``'proj'`` the output of the projection branch.
    """
    feats = self._transform_inputs(inputs)
    # Both branches consume the same transformed feature map.
    return {
        'seg': self.seghead(feats),
        'proj': self.projhead(feats),
    }
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The output should be a list. |
||
|
||
def loss_by_feat(self, seg_logits: dict,
                 batch_data_samples: SampleList) -> dict:
    """Compute segmentation loss.

    Args:
        seg_logits (dict): The output from decode head forward function,
            with the class logits under ``'seg'``.
        batch_data_samples (List[:obj:`SegDataSample`]): The seg
            data samples. It usually includes information such
            as `metainfo` and `gt_sem_seg`.

    Returns:
        dict[str, Tensor]: a dictionary of loss components
    """
    # NOTE(review): The seg_logits should be a list.

    seg_label = self._stack_batch_gt(batch_data_samples)

    # Bring the class logits to the label resolution once, so the pixel
    # sampler and the accuracy metric both see logits that match the
    # label spatially. Assumes seg_label is (N, 1, H, W), as implied by
    # the squeeze(1) below.
    seg_pred = F.interpolate(
        seg_logits['seg'],
        size=seg_label.shape[2:],
        mode='bilinear',
        align_corners=self.align_corners)

    if self.sampler is not None:
        seg_weight = self.sampler.sample(seg_pred, seg_label)
    else:
        seg_weight = None
    seg_label = seg_label.squeeze(1)

    if isinstance(self.loss_decode, nn.ModuleList):
        losses_decode = self.loss_decode
    else:
        losses_decode = [self.loss_decode]

    loss = dict()
    for loss_decode in losses_decode:
        # Every loss receives the full output dict (both 'seg' and
        # 'proj'), not just the class logits.
        # NOTE(review): Following the implementation of PaddleSeg, the
        # input of different loss methods should be different.
        loss_value = loss_decode(
            seg_logits,
            seg_label,
            weight=seg_weight,
            ignore_index=self.ignore_index)
        # Losses that share a name are accumulated into one entry.
        if loss_decode.loss_name in loss:
            loss[loss_decode.loss_name] += loss_value
        else:
            loss[loss_decode.loss_name] = loss_value

    loss['acc_seg'] = accuracy(
        seg_pred, seg_label, ignore_index=self.ignore_index)
    return loss
|
||
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
            test_cfg: ConfigType) -> Tensor:
    """Forward function for prediction.

    Args:
        inputs (Tuple[Tensor]): List of multi-level img features.
        batch_img_metas (dict): List Image info where each dict may also
            contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
            'ori_shape', and 'pad_shape'.
            For details on the values of these keys see
            `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
        test_cfg (dict): The testing config. It is not read by this
            method.

    Returns:
        Tensor: Outputs segmentation logits map.
    """
    head_outputs = self.forward(inputs)
    # Only the class logits are used for prediction; the projection
    # output is discarded here.
    return self.predict_by_feat(head_outputs['seg'], batch_img_metas)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Following the reference implementation, a CrossEntropyLoss should be added as well: https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.8/configs/hrnet_w48_contrast/HRNet_W48_contrast_cityscapes_1024x512_60k.yml#L13