Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeCamp2023-527] Add pixel contrast cross entropy loss #3264

Open
wants to merge 20 commits into
base: dev-1.x
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
113 changes: 107 additions & 6 deletions projects/pixel_contrast_cross_entropy_loss/hrnetconstrast_head.py
Expand Up @@ -11,6 +11,7 @@
from mmcv.cnn import ConvModule
from torch import Tensor

from mmseg.models.builder import HEADS, build_head
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line can probably be removed.

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.registry import MODELS
Expand Down Expand Up @@ -73,7 +74,7 @@ def forward(self, x: Tensor) -> Tensor:


@MODELS.register_module()
class ContrastHead(BaseDecodeHead):
class ContrastHead(nn.Module):
"""The segmentation head used by contrast learning.

Args:
Expand All @@ -85,21 +86,83 @@ class ContrastHead(BaseDecodeHead):
The mode for project head ,'linear' or 'convmlp'.
"""

def __init__(self, drop_p=0.1, proj_n=256, proj_mode='convmlp', **kwargs):
def __init__(self,
type='ContrastHead',
in_channels=[18, 36, 72, 144],
in_index=(0, 1, 2, 3),
channels=sum([18, 36, 72, 144]),
input_transform='resize_concat',
proj_n=256,
proj_mode='convmlp',
drop_p=0.1,
dropout_ratio=-1,
num_classes=19,
norm_cfg=norm_cfg,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't find a definition of `norm_cfg` anywhere.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
norm_cfg=norm_cfg,
norm_cfg=dict(type='BN'),

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This parameter does not seem to be used.

align_corners=False,
loss_decode=[
dict(
type='PixelContrastCrossEntropyLoss',
base_temperature=0.07,
temperature=0.1,
ignore_index=255,
max_samples=1024,
max_views=100,
loss_weight=0.1),
dict(type='CrossEntropyLoss', loss_weight=1.0)
],
Comment on lines +105 to +115
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as for the decode head: this could be moved to the config file.

fcn_head=dict(
type='FCNHead',
in_channels=2048,
in_index=3,
channels=512,
num_convs=2,
concat_input=True,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0))**kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The decode head config could be moved to the config file.

super().__init__(**kwargs)
if proj_n <= 0:
raise KeyError('proj_n must >0')
if drop_p < 0 or drop_p > 1 or not isinstance(drop_p, float):
raise KeyError('drop_p must be a float >=0')
self.proj_n = proj_n

self.seghead = SegmentationHead(
in_channels=self.in_channels,
out_channels=self.num_classes,
drop_prob=drop_p)
self.seghead = MODELS.build(fcn_head)
self.projhead = ProjectionHead(
in_channels=self.in_channels, out_channels=proj_n, proj=proj_mode)

def _transform_inputs(self, inputs):
    """Pick the feature levels this head consumes and optionally fuse them.

    The behaviour is driven by ``self.input_transform``:

    - ``'resize_concat'``: the levels listed in ``self.in_index`` are each
      resized (bilinear) to the ``shape[2:]`` of the first selected level
      and concatenated along ``dim=1``.
    - ``'multiple_select'``: the levels listed in ``self.in_index`` are
      returned as a list, untouched.
    - anything else: ``inputs`` is indexed directly with ``self.in_index``.

    Args:
        inputs (list[Tensor]): List of multi-level img features.

    Returns:
        Tensor | list[Tensor]: The transformed inputs.
    """
    mode = self.input_transform

    # Fallback branch: plain indexing, no list is built.
    if mode != 'resize_concat' and mode != 'multiple_select':
        return inputs[self.in_index]

    selected = [inputs[level] for level in self.in_index]
    if mode == 'multiple_select':
        return selected

    # 'resize_concat': bring every selected level to the size of the first
    # one before stacking them along dim 1 (presumably the channel axis).
    resized = []
    for feat in selected:
        resized.append(
            resize(
                input=feat,
                size=selected[0].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners))
    return torch.cat(resized, dim=1)

def forward(self, inputs):
inputs = self._transform_inputs(inputs)
output = []
Expand Down Expand Up @@ -162,6 +225,44 @@ def loss_by_feat(self, seg_logits: List,
ignore_index=self.ignore_index)
return loss

def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
         train_cfg: ConfigType) -> dict:
    """Run the head on ``inputs`` and compute the training losses.

    The features are passed through ``self.forward`` and the result is
    handed, together with the data samples, to ``self.loss_by_feat``.

    Args:
        inputs (Tuple[Tensor]): List of multi-level img features.
        batch_data_samples (list[:obj:`SegDataSample`]): The seg
            data samples. It usually includes information such
            as `img_metas` or `gt_semantic_seg`.
        train_cfg (dict): The training config. Accepted for interface
            compatibility; it is not read by this method.

    Returns:
        dict[str, Tensor]: a dictionary of loss components
    """
    return self.loss_by_feat(self.forward(inputs), batch_data_samples)

def predict_by_feat(self, seg_logits: Tensor,
                    batch_img_metas: List[dict]) -> Tensor:
    """Resize a batch of seg logits to the recorded image shape.

    The target size is read from the ``'img_shape'`` entry of the first
    element of ``batch_img_metas``; the same size is applied to the whole
    batch.

    Args:
        seg_logits (Tensor): The output from decode head forward function.
        batch_img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.

    Returns:
        Tensor: Outputs segmentation logits map.
    """
    # Only the first sample's meta is consulted for the output size.
    target_size = batch_img_metas[0]['img_shape']
    return resize(
        input=seg_logits,
        size=target_size,
        mode='bilinear',
        align_corners=self.align_corners)

def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
test_cfg: ConfigType) -> Tensor:
"""Forward function for prediction.
Expand Down