Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeCamp2023-527] Add pixel contrast cross entropy loss #3264

Open
wants to merge 20 commits into
base: dev-1.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 37 additions & 0 deletions projects/pixel_contrast_cross_entropy_loss/README.md
@@ -0,0 +1,37 @@
# Pixel contrast cross entropy loss

[Exploring Cross-Image Pixel Contrast for Semantic Segmentation](https://arxiv.org/pdf/2101.11939.pdf)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
[Exploring Cross-Image Pixel Contrast for Semantic Segmentation](https://arxiv.org/pdf/2101.11939.pdf)
> [Exploring Cross-Image Pixel Contrast for Semantic Segmentation](https://arxiv.org/pdf/2101.11939.pdf)


## Description

This is an implementation of **pixel contrast cross entropy loss**

[Official Repo](https://github.com/tfzhou/ContrastiveSeg)

## Abstract

Current semantic segmentation methods focus only on mining “local” context, i.e., dependencies between pixels within individual images, by context-aggregation modules (e.g., dilated convolution, neural attention) or structureaware optimization criteria (e.g., IoU-like loss). However, they ignore “global” context of the training data, i.e., rich semantic relations between pixels across different images. Inspired by the recent advance in unsupervised contrastive representation learning, we propose a pixel-wise contrastive framework for semantic segmentation in the fully supervised setting. The core idea is to enforce pixel embeddings belonging to a same semantic class to be more similar than embeddings from different classes. It raises a pixel-wise metric learning paradigm for semantic segmentation, by explicitly exploring the structures of labeled pixels, which are long ignored in the field. Our method can be effortlessly incorporated into existing segmentation frameworks without extra overhead during testing.

We experimentally show that, with famous segmentation models (i.e., DeepLabV3, HRNet, OCR) and backbones (i.e., ResNet, HRNet), our method brings consistent performance improvements across diverse datasets (i.e., Cityscapes, PASCALContext, COCO-Stuff).

## Usage

Here the configs for HRNet-W18 and HRNet-W48 with pixel_contrast_cross_entropy_loss on cityscapes dataset are provided.

After putting Cityscapes dataset into "mmsegmentation/data/" dir, train the network by:

```python
python tools/train.py projects/pixel_contrast_cross_entropy_loss/configs/fcn_hrcontrast18_4xb2-40k_cityscapes-512x1024.py
```

## Citation

```bibtex
@inproceedings{Wang_2021_ICCV,
author = {Wang, Wenguan and Zhou, Tianfei and Yu, Fisher and Dai, Jifeng and Konukoglu, Ender and Van Gool, Luc},
title = {Exploring Cross-Image Pixel Contrast for Semantic Segmentation},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
year = {2021},
pages = {7303-7313}
}
```
4 changes: 4 additions & 0 deletions projects/pixel_contrast_cross_entropy_loss/__init__.py
@@ -0,0 +1,4 @@
from .hrnetconstrast_head import ContrastHead
from .pixel_contrast_cross_entropy_loss import PixelContrastCrossEntropyLoss

__all__ = ['ContrastHead', 'PixelContrastCrossEntropyLoss']
@@ -0,0 +1,72 @@
# model settings

custom_imports = dict(imports=['projects.pixel_contrast_cross_entropy_loss'])
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255)
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
pretrained=None,
backbone=dict(
type='HRNet',
norm_cfg=norm_cfg,
norm_eval=False,
extra=dict(
stage1=dict(
num_modules=1,
num_branches=1,
block='BOTTLENECK',
num_blocks=(4, ),
num_channels=(64, )),
stage2=dict(
num_modules=1,
num_branches=2,
block='BASIC',
num_blocks=(4, 4),
num_channels=(18, 36)),
stage3=dict(
num_modules=4,
num_branches=3,
block='BASIC',
num_blocks=(4, 4, 4),
num_channels=(18, 36, 72)),
stage4=dict(
num_modules=3,
num_branches=4,
block='BASIC',
num_blocks=(4, 4, 4, 4),
num_channels=(18, 36, 72, 144)))),
decode_head=dict(
type='ContrastHead',
in_channels=[18, 36, 72, 144],
in_index=(0, 1, 2, 3),
channels=sum([18, 36, 72, 144]),
input_transform='resize_concat',
proj_n=256,
proj_mode='convmlp',
drop_p=0.1,
dropout_ratio=-1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=[
dict(
type='PixelContrastCrossEntropyLoss',
base_temperature=0.07,
temperature=0.1,
ignore_index=255,
max_samples=1024,
max_views=100,
loss_weight=0.1),
dict(type='CrossEntropyLoss', loss_weight=1.0)
]),

# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
@@ -0,0 +1,13 @@
_base_ = [
'./fcn_hrcontrast18.py', '../../../configs/_base_/datasets/cityscapes.py',
'../../../configs/_base_/default_runtime.py',
'../../../configs/_base_/schedules/schedule_40k.py'
]
data_root = 'data/cityscapes/'

train_dataloader = dict(dataset=dict(data_root=data_root))
val_dataloader = dict(dataset=dict(data_root=data_root))
test_dataloader = dict(dataset=dict(data_root=data_root))
crop_size = (512, 1024)
data_preprocessor = dict(size=crop_size)
model = dict(data_preprocessor=data_preprocessor)
@@ -0,0 +1,13 @@
_base_ = './fcn_hrcontrast18_4xb2-40k_cityscapes-512x1024.py'
model = dict(
pretrained='open-mmlab://msra/hrnetv2_w48',
backbone=dict(
extra=dict(
stage2=dict(num_channels=(48, 96)),
stage3=dict(num_channels=(48, 96, 192)),
stage4=dict(num_channels=(48, 96, 192, 384)))),
decode_head=dict(
in_channels=[48, 96, 192, 384],
channels=sum([48, 96, 192, 384]),
proj_n=720,
drop_p=0.1))
183 changes: 183 additions & 0 deletions projects/pixel_contrast_cross_entropy_loss/hrnetconstrast_head.py
@@ -0,0 +1,183 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
"""Modified from https://github.com/PaddlePaddle/PaddleSeg/
blob/2c8c35a8949fef74599f5ec557d340a14415f20d/
paddleseg/models/hrnet_contrast.py(Apache-2.0 License)"""

from typing import List, Tuple

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from torch import Tensor

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.registry import MODELS
from mmseg.utils import ConfigType, SampleList


class ProjectionHead(nn.Module):
"""The projection head used by contrast learning.

Args:
dim_in (int):
The dimensions of input features.
proj_dim (int, optional):
The output dimensions of projection head. Default: 256.
proj (str, optional): The type of projection head,
only support 'linear' and 'convmlp'. Default: 'convmlp'.
"""

def __init__(self, in_channels: int, out_channels=256, proj='convmlp'):
super().__init__()
if proj == 'linear':
self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1)
elif proj == 'convmlp':
self.proj = nn.Sequential(
ConvModule(in_channels, in_channels, kernel_size=1),
nn.Conv2d(in_channels, out_channels, kernel_size=1),
)
else:
raise KeyError("The type of project head only support 'linear' \
and 'convmlp', but got {}.".format(proj))

def forward(self, x: Tensor) -> Tensor:
return F.normalize(self.proj(x), p=2.0, dim=1)


class SegmentationHead(nn.Module):
"""The segmentation head used by contrast learning.

Args:
dim_in (int):
The dimensions of input features.
proj_dim (int, optional):
The output dimensions of projection head. Default: 256.
proj (str, optional):
The type of projection head,
only support 'linear' and 'convmlp'. Default: 'convmlp'.
"""

def __init__(self, in_channels: int, out_channels=19, drop_prob=0.1):
super().__init__()

self.seg = nn.Sequential(
ConvModule(in_channels, in_channels, kernel_size=1),
nn.Dropout2d(drop_prob),
nn.Conv2d(in_channels, out_channels, kernel_size=1),
)

def forward(self, x: Tensor) -> Tensor:
return self.seg(x)


@MODELS.register_module()
class ContrastHead(BaseDecodeHead):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The conv_seg module in BaseDecodeHead is not used and it might cause a distributed training error. It suggests to inherit from nn.Module.

"""The segmentation head used by contrast learning.

Args:
drop_p (float):
The probability of dropout in segment head.
proj_n (int):
Each pixel will be projected into a vector with length of proj_n.
proj_mode (str):
The mode for project head ,'linear' or 'convmlp'.
"""

def __init__(self, drop_p=0.1, proj_n=256, proj_mode='convmlp', **kwargs):
super().__init__(**kwargs)
if proj_n <= 0:
raise KeyError('proj_n must >0')
if drop_p < 0 or drop_p > 1 or not isinstance(drop_p, float):
raise KeyError('drop_p must be a float >=0')
self.proj_n = proj_n

self.seghead = SegmentationHead(
in_channels=self.in_channels,
out_channels=self.num_classes,
drop_prob=drop_p)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might use MODEL.build() and the SegmentationHead head is not necessary.

self.projhead = ProjectionHead(
in_channels=self.in_channels, out_channels=proj_n, proj=proj_mode)

def forward(self, inputs):
inputs = self._transform_inputs(inputs)
output = []
output.append(self.seghead(inputs))
output.append(self.projhead(inputs))

return output

def loss_by_feat(self, seg_logits: List,
batch_data_samples: SampleList) -> dict:
"""Compute segmentation loss.

Args:
seg_logits (List): The output from decode head forward function.
seg_logits[0] is the output of seghead
seg_logits[1] is the output of projhead
batch_data_samples (List[:obj:`SegDataSample`]): The seg
data samples. It usually includes information such
as `metainfo` and `gt_sem_seg`.

Returns:
dict[str, Tensor]: a dictionary of loss components
"""

seg_label = self._stack_batch_gt(batch_data_samples)
loss = dict()

if self.sampler is not None:
seg_weight = self.sampler.sample(seg_logits[0], seg_label)
else:
seg_weight = None
seg_label = seg_label.squeeze(1)

if not isinstance(self.loss_decode, nn.ModuleList):
losses_decode = [self.loss_decode]
else:
losses_decode = self.loss_decode

for loss_decode in losses_decode:
if loss_decode.loss_name == 'loss_ce':
pred = F.interpolate(
input=seg_logits[0],
size=seg_label.shape[-2:],
mode='bilinear',
align_corners=True)
loss[loss_decode.loss_name] = loss_decode(
pred,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
elif loss_decode.loss_name == 'loss_pixel_contrast_cross_entropy':
loss[loss_decode.loss_name] = loss_decode(
seg_logits, seg_label)
else:
raise KeyError('loss_name not matched')

loss['acc_seg'] = accuracy(
F.interpolate(seg_logits[0], seg_label.shape[1:]),
seg_label,
ignore_index=self.ignore_index)
return loss

def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
test_cfg: ConfigType) -> Tensor:
"""Forward function for prediction.

Args:
inputs (Tuple[Tensor]): List of multi-level img features.
batch_img_metas (dict): List Image info where each dict may also
contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
'ori_shape', and 'pad_shape'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
test_cfg (dict): The testing config.

Returns:
Tensor: Outputs segmentation logits map.
"""
seg_logits = self.forward(inputs)

return self.predict_by_feat(seg_logits, batch_img_metas)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might overwrite the predict_by_feat method