[CodeCamp2023-527] Add pixel contrast cross entropy loss #3264
base: dev-1.x
Conversation
Force-pushed from 0b363ab to 8eb061f
Hi @BLUE-coconut,
Thanks for your contribution! Could you move your code to the project folder? We'd really appreciate it!
```python
loss_decode=dict(
    type='PixelContrastCrossEntropyLoss',
    base_temperature=0.07,
    temperature=0.1)),
```
In the reference implementation, a CrossEntropyLoss should also be added: https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.8/configs/hrnet_w48_contrast/HRNet_W48_contrast_cityscapes_1024x512_60k.yml#L13
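For illustration, a hedged sketch of what the combined loss config could look like; the parameter values mirror the PaddleSeg reference, and the type names assume mmseg's registry conventions:

```python
# Hypothetical config sketch (not the PR's actual code): pair the
# contrastive loss with a plain CrossEntropyLoss, as in PaddleSeg.
loss_decode = [
    dict(
        type='PixelContrastCrossEntropyLoss',
        base_temperature=0.07,
        temperature=0.1,
        loss_weight=0.1),
    dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
]
```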
```python
_base_ = [
    '../_base_/models/fcn_hrcontrast18.py', '../_base_/datasets/vaihingen.py',
```
Might test it with the Cityscapes dataset.
```python
@MODELS.register_module()
class HRNetContrastHead(BaseDecodeHead):
```
Might rename it to `ContrastHead`, since the segmentation head could come from other methods.
```python
output['seg'] = self.seghead(inputs)

output['proj'] = self.projhead(inputs)

return output
```
The output should be a list.
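A minimal, self-contained sketch of that suggestion (class and attribute names are hypothetical, and the heads are plain callables here): return the two branch outputs as a list instead of a dict.

```python
class ContrastHeadSketch:
    """Toy stand-in for the decode head; not the PR's actual class."""

    def __init__(self, seghead, projhead):
        self.seghead = seghead
        self.projhead = projhead

    def forward(self, inputs):
        # Return a list [seg, proj] rather than a dict, so downstream
        # code can consume the outputs positionally.
        return [self.seghead(inputs), self.projhead(inputs)]
```

With stub heads, `ContrastHeadSketch(lambda x: x + 1, lambda x: x * 2).forward(3)` yields `[4, 6]`.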
```python
def loss_by_feat(self, seg_logits: dict,
                 batch_data_samples: SampleList) -> dict:
```
The seg_logits should be a list.
```python
for loss_decode in losses_decode:
    if loss_decode.loss_name not in loss:
        loss[loss_decode.loss_name] = loss_decode(
            seg_logits,
            seg_label,
            weight=seg_weight,
            ignore_index=self.ignore_index)
    else:
        loss[loss_decode.loss_name] += loss_decode(
            seg_logits,
            seg_label,
            weight=seg_weight,
            ignore_index=self.ignore_index)
```
Following the implementation of PaddleSeg, the input of different loss methods should be different.
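A self-contained sketch of that idea (names are illustrative, not the PR's actual API): dispatch on the loss so the cross-entropy term sees only the segmentation logits, while the contrastive term also receives the projection embeddings.

```python
def loss_by_feat_sketch(seg_logits, seg_label, losses_decode):
    """Hypothetical dispatch: each loss gets the input it needs."""
    seg, proj = seg_logits  # [segmentation logits, projection embeddings]
    loss = {}
    for loss_decode in losses_decode:
        if loss_decode.loss_name == 'loss_ce':
            # Plain cross-entropy only consumes the segmentation branch.
            value = loss_decode(seg, seg_label)
        else:
            # The contrastive loss consumes both branches.
            value = loss_decode((seg, proj), seg_label)
        loss[loss_decode.loss_name] = loss.get(loss_decode.loss_name, 0.0) + value
    return loss
```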
```python
from mmseg.registry import MODELS


def hard_anchor_sampling(X, y_hat, y, ignore_index, max_views, max_samples):
```
We should add type hints.
```python
loss_name='loss_pixel_contrast_cross_entropy',
temperature=0.1,
base_temperature=0.07,
ignore_index=255,
max_samples=1024,
max_views=100):
```
We should add type hints.
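One possible shape for the requested hints, using the defaults from the diff above (the class body is elided; this is a sketch, not the PR's code):

```python
class PixelContrastCrossEntropyLossSketch:
    """Constructor carrying the type hints the review asks for."""

    def __init__(self,
                 loss_name: str = 'loss_pixel_contrast_cross_entropy',
                 temperature: float = 0.1,
                 base_temperature: float = 0.07,
                 ignore_index: int = 255,
                 max_samples: int = 1024,
                 max_views: int = 100) -> None:
        self.loss_name = loss_name
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.ignore_index = ignore_index
        self.max_samples = max_samples
        self.max_views = max_views
```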
```markdown
# Pixel contrast cross entropy loss

[Exploring Cross-Image Pixel Contrast for Semantic Segmentation](https://arxiv.org/pdf/2101.11939.pdf)
```
```suggestion
> [Exploring Cross-Image Pixel Contrast for Semantic Segmentation](https://arxiv.org/pdf/2101.11939.pdf)
```
""" | ||
seg_logits = self.forward(inputs) | ||
|
||
return self.predict_by_feat(seg_logits, batch_img_metas) |
Might overwrite the `predict_by_feat` method.
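One way such an override could look, sketched with a stub base class; real code would call `BaseDecodeHead.predict_by_feat`, which resizes the logits.

```python
class StubBaseHead:
    """Stand-in for BaseDecodeHead: pretend postprocessing step."""

    def predict_by_feat(self, seg_logits, batch_img_metas):
        return ('resized', seg_logits)


class ContrastHeadPredictSketch(StubBaseHead):
    def predict_by_feat(self, seg_logits, batch_img_metas):
        # forward() returns [seg, proj]; only the segmentation branch is
        # needed at inference, so unpack it before the base postprocessing.
        seg, _proj = seg_logits
        return super().predict_by_feat(seg, batch_img_metas)
```

Here `ContrastHeadPredictSketch().predict_by_feat(['seg', 'proj'], [])` returns `('resized', 'seg')`.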
```python
self.seghead = SegmentationHead(
    in_channels=self.in_channels,
    out_channels=self.num_classes,
    drop_prob=drop_p)
```
Might use `MODELS.build()`, and the `SegmentationHead` is not necessary.
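A self-contained toy registry illustrating the suggested pattern; mmseg's real `MODELS` registry from `mmseg.registry` works analogously, in that `MODELS.build(cfg)` looks up `cfg['type']` and instantiates it with the remaining keys.

```python
class ToyRegistry:
    """Minimal registry: build objects from config dicts by 'type' key."""

    def __init__(self):
        self._modules = {}

    def register_module(self):
        def _register(cls):
            self._modules[cls.__name__] = cls
            return cls
        return _register

    def build(self, cfg):
        cfg = dict(cfg)  # copy so the caller's config is not mutated
        return self._modules[cfg.pop('type')](**cfg)


MODELS = ToyRegistry()


@MODELS.register_module()
class FCNHead:
    def __init__(self, num_classes=19):
        self.num_classes = num_classes


# Building the seg head from config means any registered head can be used.
head = MODELS.build(dict(type='FCNHead', num_classes=19))
```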
```python
@MODELS.register_module()
class ContrastHead(BaseDecodeHead):
```
The `conv_seg` module in `BaseDecodeHead` is not used, which might cause a distributed training error. It is suggested to inherit from `nn.Module` instead.
```python
from mmcv.cnn import ConvModule
from torch import Tensor

from mmseg.models.builder import HEADS, build_head
```
Might remove this line.
```python
drop_p=0.1,
dropout_ratio=-1,
num_classes=19,
norm_cfg=norm_cfg,
```
Can't find the definition of `norm_cfg`.
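In mmseg model configs, `norm_cfg` is conventionally defined at the top of the file before being referenced, e.g. (illustrative):

```python
# SyncBN is the usual choice for multi-GPU training; plain BN otherwise.
norm_cfg = dict(type='SyncBN', requires_grad=True)
```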
```python
drop_p=0.1,
dropout_ratio=-1,
num_classes=19,
norm_cfg=norm_cfg,
```
```suggestion
norm_cfg=dict(type='BN'),
```
This parameter does not seem to be in use.
```python
fcn_head=dict(
    type='FCNHead',
    in_channels=2048,
    in_index=3,
    channels=512,
    num_convs=2,
    concat_input=True,
    dropout_ratio=0.1,
    num_classes=150,
    norm_cfg=norm_cfg,
    align_corners=False,
    loss_decode=dict(
        type='CrossEntropyLoss',
        use_sigmoid=False,
        loss_weight=1.0)),
**kwargs):
```
The decode head config might be moved to the config file.
```python
loss_decode=[
    dict(
        type='PixelContrastCrossEntropyLoss',
        base_temperature=0.07,
        temperature=0.1,
        ignore_index=255,
        max_samples=1024,
        max_views=100,
        loss_weight=0.1),
    dict(type='CrossEntropyLoss', loss_weight=1.0)
],
```
Same as the decode head: this might be moved to the config file.
Thanks for your contribution; we appreciate it a lot. The following instructions will help make your pull request healthier and get feedback more easily. If you do not understand some items, don't worry: just make the pull request and seek help from the maintainers.
Motivation
Please describe the motivation of this PR and the goal you want to achieve through this PR.
Modification
Please briefly describe what modification is made in this PR.
BC-breaking (Optional)
Does the modification introduce changes that break the backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
Use cases (Optional)
If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.
Checklist