[CodeCamp2023-527] Add pixel contrast cross entropy loss #3264
base: dev-1.x
@@ -0,0 +1,37 @@
# Pixel contrast cross entropy loss

[Exploring Cross-Image Pixel Contrast for Semantic Segmentation](https://arxiv.org/pdf/2101.11939.pdf)
## Description

This is an implementation of **pixel contrast cross entropy loss**, together with a `ContrastHead` that pairs a segmentation branch with a pixel projection branch.

[Official Repo](https://github.com/tfzhou/ContrastiveSeg)
## Abstract

Current semantic segmentation methods focus only on mining “local” context, i.e., dependencies between pixels within individual images, by context-aggregation modules (e.g., dilated convolution, neural attention) or structure-aware optimization criteria (e.g., IoU-like loss). However, they ignore the “global” context of the training data, i.e., rich semantic relations between pixels across different images. Inspired by recent advances in unsupervised contrastive representation learning, we propose a pixel-wise contrastive framework for semantic segmentation in the fully supervised setting. The core idea is to enforce pixel embeddings belonging to the same semantic class to be more similar than embeddings from different classes. This raises a pixel-wise metric learning paradigm for semantic segmentation, by explicitly exploring the structures of labeled pixels, which have long been ignored in the field. Our method can be effortlessly incorporated into existing segmentation frameworks without extra overhead during testing.

We experimentally show that, with famous segmentation models (i.e., DeepLabV3, HRNet, OCR) and backbones (i.e., ResNet, HRNet), our method brings consistent performance improvements across diverse datasets (i.e., Cityscapes, PASCAL-Context, COCO-Stuff).
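At a glance, the loss is a supervised, pixel-level variant of InfoNCE. Below is a sketch of the per-anchor form, summarized from the paper in our own notation (not copied from this PR's code):

```latex
% Per-anchor pixel contrastive loss (supervised InfoNCE), after Wang et al., 2021.
% For an anchor embedding i, the positive set P_i holds pixels of the same
% ground-truth class, N_i holds pixels of other classes, and tau is the temperature.
\mathcal{L}_i =
  \frac{1}{|\mathcal{P}_i|}
  \sum_{\mathbf{i}^{+} \in \mathcal{P}_i}
  -\log
  \frac{\exp(\mathbf{i} \cdot \mathbf{i}^{+} / \tau)}
       {\exp(\mathbf{i} \cdot \mathbf{i}^{+} / \tau)
        + \sum_{\mathbf{i}^{-} \in \mathcal{N}_i} \exp(\mathbf{i} \cdot \mathbf{i}^{-} / \tau)}
```

In the configs in this PR, this contrastive term is added with `loss_weight=0.1` on top of a standard `CrossEntropyLoss` with `loss_weight=1.0`.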
## Usage

Configs for HRNet-W18 and HRNet-W48 with `pixel_contrast_cross_entropy_loss` on the Cityscapes dataset are provided here.

After placing the Cityscapes dataset in the `mmsegmentation/data/` directory, train the network with:
```shell
python tools/train.py projects/pixel_contrast_cross_entropy_loss/configs/fcn_hrcontrast18_4xb2-40k_cityscapes-512x1024.py
```
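Evaluation can use the standard MMSegmentation test entry point; for example (a hedged sketch: the checkpoint path below is a placeholder, not an artifact shipped with this PR):

```shell
python tools/test.py \
    projects/pixel_contrast_cross_entropy_loss/configs/fcn_hrcontrast18_4xb2-40k_cityscapes-512x1024.py \
    work_dirs/fcn_hrcontrast18_4xb2-40k_cityscapes-512x1024/iter_40000.pth
```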
## Citation

```bibtex
@inproceedings{Wang_2021_ICCV,
    author    = {Wang, Wenguan and Zhou, Tianfei and Yu, Fisher and Dai, Jifeng and Konukoglu, Ender and Van Gool, Luc},
    title     = {Exploring Cross-Image Pixel Contrast for Semantic Segmentation},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    year      = {2021},
    pages     = {7303-7313}
}
```
@@ -0,0 +1,4 @@
from .hrnetconstrast_head import ContrastHead
from .pixel_contrast_cross_entropy_loss import PixelContrastCrossEntropyLoss

__all__ = ['ContrastHead', 'PixelContrastCrossEntropyLoss']
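Importing this package registers both classes with the MMSegmentation registry, which is how `custom_imports` in the configs below picks them up. A minimal sanity check (a sketch assuming an installed `mmseg` dev-1.x environment and the project layout used by this PR):

```python
from mmseg.registry import MODELS
from mmseg.utils import register_all_modules

register_all_modules()  # register the built-in mmseg modules first

# Importing the project package runs the __init__.py above, which
# registers ContrastHead and PixelContrastCrossEntropyLoss.
import projects.pixel_contrast_cross_entropy_loss  # noqa: F401

print(MODELS.get('ContrastHead'))
print(MODELS.get('PixelContrastCrossEntropyLoss'))
```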
@@ -0,0 +1,72 @@
# model settings

custom_imports = dict(imports=['projects.pixel_contrast_cross_entropy_loss'])
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255)
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    pretrained=None,
    backbone=dict(
        type='HRNet',
        norm_cfg=norm_cfg,
        norm_eval=False,
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(18, 36)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(18, 36, 72)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(18, 36, 72, 144)))),
    decode_head=dict(
        type='ContrastHead',
        in_channels=[18, 36, 72, 144],
        in_index=(0, 1, 2, 3),
        channels=sum([18, 36, 72, 144]),
        input_transform='resize_concat',
        proj_n=256,
        proj_mode='convmlp',
        drop_p=0.1,
        dropout_ratio=-1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=[
            dict(
                type='PixelContrastCrossEntropyLoss',
                base_temperature=0.07,
                temperature=0.1,
                ignore_index=255,
                max_samples=1024,
                max_views=100,
                loss_weight=0.1),
            dict(type='CrossEntropyLoss', loss_weight=1.0)
        ]),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
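As a quick way to sanity-check the head wiring, the model can be built from this config through the registry. A minimal sketch, assuming a working `mmseg` dev-1.x install and the project layout used by this PR; the BN swap is only so the forward pass runs in a single CPU process despite the `SyncBN` setting above:

```python
import torch
from mmengine import Config

from mmseg.registry import MODELS
from mmseg.utils import register_all_modules

register_all_modules()
# Importing the project package registers ContrastHead and the loss.
import projects.pixel_contrast_cross_entropy_loss  # noqa: F401

cfg = Config.fromfile(
    'projects/pixel_contrast_cross_entropy_loss/configs/fcn_hrcontrast18.py')
# SyncBN needs a distributed process group; use plain BN for a local check.
cfg.model.backbone.norm_cfg = dict(type='BN', requires_grad=True)
cfg.model.decode_head.norm_cfg = dict(type='BN', requires_grad=True)

model = MODELS.build(cfg.model).eval()
with torch.no_grad():
    feats = model.extract_feat(torch.randn(1, 3, 512, 1024))
    seg_out, proj_out = model.decode_head(feats)
print(seg_out.shape)   # e.g. torch.Size([1, 19, 128, 256])
print(proj_out.shape)  # e.g. torch.Size([1, 256, 128, 256])
```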
@@ -0,0 +1,13 @@
_base_ = [
    './fcn_hrcontrast18.py', '../../../configs/_base_/datasets/cityscapes.py',
    '../../../configs/_base_/default_runtime.py',
    '../../../configs/_base_/schedules/schedule_40k.py'
]
data_root = 'data/cityscapes/'

train_dataloader = dict(dataset=dict(data_root=data_root))
val_dataloader = dict(dataset=dict(data_root=data_root))
test_dataloader = dict(dataset=dict(data_root=data_root))
crop_size = (512, 1024)
data_preprocessor = dict(size=crop_size)
model = dict(data_preprocessor=data_preprocessor)
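Per the MMSegmentation naming convention, `4xb2` in the file name means 4 GPUs with 2 images per GPU; a matching multi-GPU launch via the standard `dist_train.sh` helper would be:

```shell
bash tools/dist_train.sh \
    projects/pixel_contrast_cross_entropy_loss/configs/fcn_hrcontrast18_4xb2-40k_cityscapes-512x1024.py 4
```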
@@ -0,0 +1,13 @@
_base_ = './fcn_hrcontrast18_4xb2-40k_cityscapes-512x1024.py'
model = dict(
    pretrained='open-mmlab://msra/hrnetv2_w48',
    backbone=dict(
        extra=dict(
            stage2=dict(num_channels=(48, 96)),
            stage3=dict(num_channels=(48, 96, 192)),
            stage4=dict(num_channels=(48, 96, 192, 384)))),
    decode_head=dict(
        in_channels=[48, 96, 192, 384],
        channels=sum([48, 96, 192, 384]),
        proj_n=720,
        drop_p=0.1))
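For this W48 variant, note that `proj_n=720` equals the concatenated branch width (48 + 96 + 192 + 384 = 720), unlike the W18 config where `proj_n=256` is independent of its 270-channel input. Training is invoked the same way; the config file name below is hypothetical, since this diff does not show the W48 file's path:

```shell
# Hypothetical path: the W48 config's file name is not visible in this diff.
python tools/train.py projects/pixel_contrast_cross_entropy_loss/configs/fcn_hrcontrast48_4xb2-40k_cityscapes-512x1024.py
```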
@@ -0,0 +1,183 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
"""Modified from https://github.com/PaddlePaddle/PaddleSeg/
blob/2c8c35a8949fef74599f5ec557d340a14415f20d/
paddleseg/models/hrnet_contrast.py (Apache-2.0 License)"""

from typing import List, Tuple

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from torch import Tensor

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.registry import MODELS
from mmseg.utils import ConfigType, SampleList
class ProjectionHead(nn.Module):
    """The projection head used by contrastive learning.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int, optional): The output dimension of the
            projection head. Default: 256.
        proj (str, optional): The type of projection head, only 'linear'
            and 'convmlp' are supported. Default: 'convmlp'.
    """

    def __init__(self, in_channels: int, out_channels=256, proj='convmlp'):
        super().__init__()
        if proj == 'linear':
            self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        elif proj == 'convmlp':
            self.proj = nn.Sequential(
                ConvModule(in_channels, in_channels, kernel_size=1),
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
            )
        else:
            raise KeyError('The type of projection head only supports '
                           f"'linear' and 'convmlp', but got {proj}.")

    def forward(self, x: Tensor) -> Tensor:
        # Project and L2-normalize so embeddings lie on the unit hypersphere.
        return F.normalize(self.proj(x), p=2.0, dim=1)


class SegmentationHead(nn.Module):
    """The segmentation head used by contrastive learning.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int, optional): The number of output classes.
            Default: 19.
        drop_prob (float, optional): The dropout probability used before
            the classification layer. Default: 0.1.
    """

    def __init__(self, in_channels: int, out_channels=19, drop_prob=0.1):
        super().__init__()

        self.seg = nn.Sequential(
            ConvModule(in_channels, in_channels, kernel_size=1),
            nn.Dropout2d(drop_prob),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.seg(x)


@MODELS.register_module()
class ContrastHead(BaseDecodeHead):
    """The decode head used by contrastive learning.

    It runs a segmentation head and a projection head side by side on the
    transformed backbone features.

    Args:
        drop_p (float): The dropout probability in the segmentation head.
        proj_n (int): Each pixel is projected into a vector of length
            ``proj_n``.
        proj_mode (str): The mode of the projection head, 'linear' or
            'convmlp'.
    """

    def __init__(self, drop_p=0.1, proj_n=256, proj_mode='convmlp', **kwargs):
        super().__init__(**kwargs)
        if proj_n <= 0:
            raise ValueError(
                f'proj_n must be a positive integer, but got {proj_n}')
        if not isinstance(drop_p, float) or drop_p < 0 or drop_p > 1:
            raise ValueError(
                f'drop_p must be a float in [0, 1], but got {drop_p}')
        self.proj_n = proj_n

        self.seghead = SegmentationHead(
            in_channels=self.in_channels,
            out_channels=self.num_classes,
            drop_prob=drop_p)
        self.projhead = ProjectionHead(
            in_channels=self.in_channels, out_channels=proj_n, proj=proj_mode)

    def forward(self, inputs):
        inputs = self._transform_inputs(inputs)
        # output[0]: segmentation logits; output[1]: normalized embeddings.
        output = [self.seghead(inputs), self.projhead(inputs)]
        return output
    def loss_by_feat(self, seg_logits: List,
                     batch_data_samples: SampleList) -> dict:
        """Compute segmentation loss.

        Args:
            seg_logits (List): The outputs of the decode head forward
                function: ``seg_logits[0]`` comes from the segmentation
                head and ``seg_logits[1]`` from the projection head.
            batch_data_samples (List[:obj:`SegDataSample`]): The seg data
                samples. It usually includes information such as
                `metainfo` and `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_label = self._stack_batch_gt(batch_data_samples)
        loss = dict()

        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logits[0], seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)

        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode

        for loss_decode in losses_decode:
            if loss_decode.loss_name == 'loss_ce':
                # Cross-entropy is computed on the segmentation logits,
                # upsampled to the label resolution.
                pred = F.interpolate(
                    input=seg_logits[0],
                    size=seg_label.shape[-2:],
                    mode='bilinear',
                    align_corners=self.align_corners)
                loss[loss_decode.loss_name] = loss_decode(
                    pred,
                    seg_label,
                    weight=seg_weight,
                    ignore_index=self.ignore_index)
            elif loss_decode.loss_name == 'loss_pixel_contrast_cross_entropy':
                # The contrastive term consumes both the logits and the
                # projected pixel embeddings.
                loss[loss_decode.loss_name] = loss_decode(
                    seg_logits, seg_label)
            else:
                raise KeyError(
                    f'loss_name {loss_decode.loss_name} is not matched')

        loss['acc_seg'] = accuracy(
            F.interpolate(seg_logits[0], seg_label.shape[1:]),
            seg_label,
            ignore_index=self.ignore_index)
        return loss

    def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType) -> Tensor:
        """Forward function for prediction.

        Args:
            inputs (Tuple[Tensor]): List of multi-level img features.
            batch_img_metas (List[dict]): List of image meta info where
                each dict may also contain: 'img_shape', 'scale_factor',
                'flip', 'img_path', 'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Outputs segmentation logits map.
        """
        seg_logits = self.forward(inputs)

        # Only the segmentation branch is used at test time; the projection
        # branch exists purely for the training-time contrastive loss.
        return self.predict_by_feat(seg_logits[0], batch_img_metas)
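A small shape check for the head in isolation (a hypothetical sketch, not part of the PR; it assumes HRNet-W18-like feature maps and a plain cross-entropy `loss_decode` so the head can be built without the project config):

```python
import torch

from mmseg.utils import register_all_modules
from projects.pixel_contrast_cross_entropy_loss import ContrastHead

register_all_modules()

# Four feature levels shaped like HRNet-W18 outputs for a 512x1024 crop.
feats = [
    torch.randn(2, 18, 128, 256),
    torch.randn(2, 36, 64, 128),
    torch.randn(2, 72, 32, 64),
    torch.randn(2, 144, 16, 32),
]

head = ContrastHead(
    in_channels=[18, 36, 72, 144],
    in_index=(0, 1, 2, 3),
    channels=sum([18, 36, 72, 144]),
    input_transform='resize_concat',
    proj_n=256,
    proj_mode='convmlp',
    drop_p=0.1,
    dropout_ratio=-1,
    num_classes=19,
    align_corners=False,
    loss_decode=dict(type='CrossEntropyLoss', loss_weight=1.0))

seg_out, proj_out = head(feats)
print(seg_out.shape)   # torch.Size([2, 19, 128, 256])
print(proj_out.shape)  # torch.Size([2, 256, 128, 256])
```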