New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CodeCamp2023-527] Add pixel contrast cross entropy loss #3264
base: dev-1.x
Are you sure you want to change the base?
Changes from 1 commit
15f31ae
8eb061f
77661b1
0b363ab
4e0c5d4
cbbd55e
c326b40
0a7e01d
3bff55b
ea15ffa
49ba0b2
39d8da3
c14b152
d7a4bd3
1fe9521
cdc31bc
58c4881
52f266f
73d18de
2304c6d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -11,6 +11,7 @@ | |||||
from mmcv.cnn import ConvModule | ||||||
from torch import Tensor | ||||||
|
||||||
from mmseg.models.builder import HEADS, build_head | ||||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead | ||||||
from mmseg.models.losses import accuracy | ||||||
from mmseg.registry import MODELS | ||||||
|
@@ -73,7 +74,7 @@ def forward(self, x: Tensor) -> Tensor: | |||||
|
||||||
|
||||||
@MODELS.register_module() | ||||||
class ContrastHead(BaseDecodeHead): | ||||||
class ContrastHead(nn.Module): | ||||||
"""The segmentation head used by contrast learning. | ||||||
|
||||||
Args: | ||||||
|
@@ -85,21 +86,83 @@ class ContrastHead(BaseDecodeHead): | |||||
The mode for project head ,'linear' or 'convmlp'. | ||||||
""" | ||||||
|
||||||
def __init__(self, drop_p=0.1, proj_n=256, proj_mode='convmlp', **kwargs): | ||||||
def __init__(self, | ||||||
type='ContrastHead', | ||||||
in_channels=[18, 36, 72, 144], | ||||||
in_index=(0, 1, 2, 3), | ||||||
channels=sum([18, 36, 72, 144]), | ||||||
input_transform='resize_concat', | ||||||
proj_n=256, | ||||||
proj_mode='convmlp', | ||||||
drop_p=0.1, | ||||||
dropout_ratio=-1, | ||||||
num_classes=19, | ||||||
norm_cfg=norm_cfg, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I can't find a definition of `norm_cfg` in this module. If it is not defined at module level, this default argument will raise a NameError as soon as the class body is executed. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This parameter does not appear to be used anywhere. |
||||||
align_corners=False, | ||||||
loss_decode=[ | ||||||
dict( | ||||||
type='PixelContrastCrossEntropyLoss', | ||||||
base_temperature=0.07, | ||||||
temperature=0.1, | ||||||
ignore_index=255, | ||||||
max_samples=1024, | ||||||
max_views=100, | ||||||
loss_weight=0.1), | ||||||
dict(type='CrossEntropyLoss', loss_weight=1.0) | ||||||
], | ||||||
Comment on lines
+105
to
+115
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As with the decode head config, this loss configuration should probably be moved to the config file. |
||||||
fcn_head=dict( | ||||||
type='FCNHead', | ||||||
in_channels=2048, | ||||||
in_index=3, | ||||||
channels=512, | ||||||
num_convs=2, | ||||||
concat_input=True, | ||||||
dropout_ratio=0.1, | ||||||
num_classes=150, | ||||||
norm_cfg=norm_cfg, | ||||||
align_corners=False, | ||||||
loss_decode=dict( | ||||||
type='CrossEntropyLoss', | ||||||
use_sigmoid=False, | ||||||
loss_weight=1.0))**kwargs): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The decode head config should probably be moved to the config file. |
||||||
super().__init__(**kwargs) | ||||||
if proj_n <= 0: | ||||||
raise KeyError('proj_n must >0') | ||||||
if drop_p < 0 or drop_p > 1 or not isinstance(drop_p, float): | ||||||
raise KeyError('drop_p must be a float >=0') | ||||||
self.proj_n = proj_n | ||||||
|
||||||
self.seghead = SegmentationHead( | ||||||
in_channels=self.in_channels, | ||||||
out_channels=self.num_classes, | ||||||
drop_prob=drop_p) | ||||||
self.seghead = MODELS.build(fcn_head) | ||||||
self.projhead = ProjectionHead( | ||||||
in_channels=self.in_channels, out_channels=proj_n, proj=proj_mode) | ||||||
|
||||||
def _transform_inputs(self, inputs):
    """Pick the feature maps this head consumes and optionally fuse them.

    The behaviour is driven by ``self.input_transform``:

    * ``'resize_concat'``: the maps listed in ``self.in_index`` are each
      resized (bilinear) to ``shape[2:]`` of the first selected map and
      concatenated along ``dim=1``.
    * ``'multiple_select'``: the maps listed in ``self.in_index`` are
      returned as a list, untouched.
    * anything else: ``inputs[self.in_index]`` is returned directly.

    Args:
        inputs (list[Tensor]): List of multi-level img features.

    Returns:
        Tensor | list[Tensor]: The transformed inputs.
    """
    mode = self.input_transform

    # Fallback: plain indexing with in_index, no list is built.
    if mode not in ('resize_concat', 'multiple_select'):
        return inputs[self.in_index]

    selected = [inputs[level] for level in self.in_index]
    if mode == 'multiple_select':
        return selected

    # 'resize_concat': bring every selected map to the size of the first
    # selected map before concatenating.
    resized = []
    for feat in selected:
        resized.append(
            resize(
                input=feat,
                size=selected[0].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners))
    return torch.cat(resized, dim=1)
|
||||||
def forward(self, inputs): | ||||||
inputs = self._transform_inputs(inputs) | ||||||
output = [] | ||||||
|
@@ -162,6 +225,44 @@ def loss_by_feat(self, seg_logits: List, | |||||
ignore_index=self.ignore_index) | ||||||
return loss | ||||||
|
||||||
def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
         train_cfg: ConfigType) -> dict:
    """Run the head on ``inputs`` and compute the training losses.

    The features are passed through ``self.forward`` and the result is
    handed, together with the data samples, to ``self.loss_by_feat``.

    Args:
        inputs (Tuple[Tensor]): List of multi-level img features.
        batch_data_samples (list[:obj:`SegDataSample`]): The seg
            data samples. It usually includes information such
            as `img_metas` or `gt_semantic_seg`.
        train_cfg (dict): The training config. It is accepted as part
            of the signature but is not read by this method.

    Returns:
        dict[str, Tensor]: a dictionary of loss components
    """
    return self.loss_by_feat(self.forward(inputs), batch_data_samples)
|
||||||
def predict_by_feat(self, seg_logits: Tensor,
                    batch_img_metas: List[dict]) -> Tensor:
    """Resize a batch of seg_logits to the recorded image shape.

    The target size is taken from the ``'img_shape'`` entry of the first
    element of ``batch_img_metas`` and applied to the whole batch.

    Args:
        seg_logits (Tensor): The output from decode head forward function.
        batch_img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.

    Returns:
        Tensor: Outputs segmentation logits map.
    """
    # Only the first meta dict is consulted for the output size.
    target_shape = batch_img_metas[0]['img_shape']
    return resize(
        input=seg_logits,
        size=target_shape,
        mode='bilinear',
        align_corners=self.align_corners)
|
||||||
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict], | ||||||
test_cfg: ConfigType) -> Tensor: | ||||||
"""Forward function for prediction. | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider removing this line.