Commit

pre-commit modify

BLUE-coconut committed Aug 15, 2023
1 parent 15f31ae commit 8eb061f
Showing 6 changed files with 198 additions and 149 deletions.
32 changes: 16 additions & 16 deletions configs/_base_/models/fcn_hrcontrast18.py
@@ -41,23 +41,23 @@
                 num_blocks=(4, 4, 4, 4),
                 num_channels=(18, 36, 72, 144)))),
     decode_head=dict(
-        type='HRNetContrastHead',
-        in_channels=[18, 36, 72, 144],
-        in_index=(0, 1, 2, 3),
-        channels=sum([18, 36, 72, 144]),
-
-        input_transform='resize_concat',
-        proj_n=256,
-        proj_mode="convmlp",
-        drop_p=0.1,
-        dropout_ratio=-1,
-        num_classes=19,
-        norm_cfg=norm_cfg,
-        align_corners=False,
-
-        loss_decode=dict(
-            type='PixelContrastCrossEntropyLoss', base_temperature=0.07,temperature=0.1)),
+        type='HRNetContrastHead',
+        in_channels=[18, 36, 72, 144],
+        in_index=(0, 1, 2, 3),
+        channels=sum([18, 36, 72, 144]),
+        input_transform='resize_concat',
+        proj_n=256,
+        proj_mode='convmlp',
+        drop_p=0.1,
+        dropout_ratio=-1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='PixelContrastCrossEntropyLoss',
+            base_temperature=0.07,
+            temperature=0.1)),
 
     # model training and testing settings
     train_cfg=dict(),
     test_cfg=dict(mode='whole'))
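
The reformatted config can be exercised end to end with the usual mmengine workflow. The following is a hypothetical sketch, not part of this commit; it assumes this branch of mmsegmentation is installed so that HRNetContrastHead and PixelContrastCrossEntropyLoss are actually registered.

# Hypothetical usage sketch (not part of this commit): build the model
# described by the config above. Assumes this branch of mmsegmentation is
# installed so the new head and loss are registered.
from mmengine.config import Config

from mmseg.registry import MODELS
from mmseg.utils import register_all_modules

register_all_modules()  # populate the MODELS registry with mmseg components
cfg = Config.fromfile('configs/_base_/models/fcn_hrcontrast18.py')
model = MODELS.build(cfg.model)  # the segmentor wrapping this decode head
print(type(model).__name__)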
4 changes: 2 additions & 2 deletions mmseg/models/decode_heads/__init__.py
@@ -14,6 +14,7 @@
 from .fpn_head import FPNHead
 from .gc_head import GCHead
 from .ham_head import LightHamHead
+from .hrnetconstrast_head import HRNetContrastHead
 from .isa_head import ISAHead
 from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator
 from .lraspp_head import LRASPPHead
@@ -33,7 +34,6 @@
 from .setr_up_head import SETRUPHead
 from .stdc_head import STDCHead
 from .uper_head import UPerHead
-from .hrnetconstrast_head import HRNetContrastHead
 
 __all__ = [
     'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
@@ -43,5 +43,5 @@
     'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
     'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
     'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead',
-    'LightHamHead', 'PIDHead', 'DDRHead','HRNetContrastHead'
+    'LightHamHead', 'PIDHead', 'DDRHead', 'HRNetContrastHead'
 ]
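
With the import moved to its alphabetical position and the class still listed in __all__, the head remains importable from the package. A minimal smoke check, assuming this branch is installed, could be:

# Minimal import smoke check (assumes this branch of mmsegmentation is
# installed).
from mmseg.models import decode_heads

assert 'HRNetContrastHead' in decode_heads.__all__
print(decode_heads.HRNetContrastHead)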
110 changes: 59 additions & 51 deletions mmseg/models/decode_heads/hrnetconstrast_head.py
@@ -1,103 +1,111 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 # Originally from https://github.com/visual-attention-network/segnext
 # Licensed under the Apache License, Version 2.0 (the "License")
-import torch
+from typing import List, Tuple
+
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import List,Tuple
 from mmcv.cnn import ConvModule
 from mmengine.device import get_device
-
-from mmseg.registry import MODELS
 from torch import Tensor
+
+from mmseg.registry import MODELS
 from mmseg.utils import ConfigType, SampleList
 from ..utils import resize
-from .decode_head import BaseDecodeHead
 from ..losses import accuracy
+from .decode_head import BaseDecodeHead
 
 
 class ProjectionHead(nn.Module):
-    """
-    The projection head used by contrast learning.
+    """The projection head used by contrast learning.
     Args:
-        dim_in (int): The dimensions of input features.
-        proj_dim (int, optional): The output dimensions of projection head. Default: 256.
-        proj (str, optional): The type of projection head, only support 'linear' and 'convmlp'. Default: 'convmlp'.
+        dim_in (int):
+            The dimensions of input features.
+        proj_dim (int, optional):
+            The output dimensions of projection head. Default: 256.
+        proj (str, optional): The type of projection head,
+            only support 'linear' and 'convmlp'. Default: 'convmlp'.
     """
 
     def __init__(self, in_channels, out_channels=256, proj='convmlp'):
-        super(ProjectionHead, self).__init__()
+        super().__init__()
         if proj == 'linear':
             self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1)
         elif proj == 'convmlp':
             self.proj = nn.Sequential(
                 ConvModule(in_channels, in_channels, kernel_size=1),
-                nn.Conv2d(in_channels, out_channels, kernel_size=1),
+                nn.Conv2d(in_channels, out_channels, kernel_size=1),
             )
         else:
-            raise KeyError(
-                "The type of project head only support 'linear' and 'convmlp', but got {}."
-                .format(proj))
+            raise KeyError("The type of project head only support 'linear' \
+                and 'convmlp', but got {}.".format(proj))
 
     def forward(self, x):
         return F.normalize(self.proj(x), p=2.0, dim=1)
 
-
 
 class SegmentationHead(nn.Module):
-    """
-    The segmentation head used by contrast learning.
+    """The segmentation head used by contrast learning.
     Args:
-        dim_in (int): The dimensions of input features.
-        proj_dim (int, optional): The output dimensions of projection head. Default: 256.
-        proj (str, optional): The type of projection head, only support 'linear' and 'convmlp'. Default: 'convmlp'.
+        dim_in (int):
+            The dimensions of input features.
+        proj_dim (int, optional):
+            The output dimensions of projection head. Default: 256.
+        proj (str, optional):
+            The type of projection head,
+            only support 'linear' and 'convmlp'. Default: 'convmlp'.
     """
-    def __init__(self, in_channels, out_channels=19,drop_prob=0.1):
-        super(SegmentationHead, self).__init__()
+
+    def __init__(self, in_channels, out_channels=19, drop_prob=0.1):
+        super().__init__()
 
         self.seg = nn.Sequential(
             ConvModule(in_channels, in_channels, kernel_size=1),
             nn.Dropout2d(drop_prob),
-            nn.Conv2d(in_channels, out_channels, kernel_size=1),
+            nn.Conv2d(in_channels, out_channels, kernel_size=1),
         )
 
     def forward(self, x):
         return self.seg(x)
 
 
-
-
-
 @MODELS.register_module()
 class HRNetContrastHead(BaseDecodeHead):
-    """
-    The segmentation head used by contrast learning.
+    """The segmentation head used by contrast learning.
     Args:
-        drop_p (float): The probability of dropout in segment head.
-        proj_n (int): Each pixel will be projected into a vector with length of proj_n.
-        proj_mode (str): The mode for project head ,'linear' or 'convmlp'.
+        drop_p (float):
+            The probability of dropout in segment head.
+        proj_n (int):
+            Each pixel will be projected into a vector with length of proj_n.
+        proj_mode (str):
+            The mode for project head ,'linear' or 'convmlp'.
     """
-    def __init__(self,drop_p=0.1,proj_n=256,proj_mode='convmlp',**kwargs):
-        super(HRNetContrastHead,self).__init__(**kwargs)
+
+    def __init__(self, drop_p=0.1, proj_n=256, proj_mode='convmlp', **kwargs):
+        super().__init__(**kwargs)
         if proj_n <= 0:
-            raise KeyError("proj_n must >0")
-        if drop_p<0 or drop_p>1 or not isinstance(drop_p,float):
-            raise KeyError("drop_p must be a float >=0")
+            raise KeyError('proj_n must >0')
+        if drop_p < 0 or drop_p > 1 or not isinstance(drop_p, float):
+            raise KeyError('drop_p must be a float >=0')
         self.proj_n = proj_n
 
-        self.seghead = SegmentationHead(in_channels=self.in_channels,out_channels=self.num_classes,drop_prob=drop_p)
-        self.projhead = ProjectionHead(in_channels=self.in_channels,out_channels=proj_n,proj=proj_mode)
-
-
+        self.seghead = SegmentationHead(
+            in_channels=self.in_channels,
+            out_channels=self.num_classes,
+            drop_prob=drop_p)
+        self.projhead = ProjectionHead(
+            in_channels=self.in_channels, out_channels=proj_n, proj=proj_mode)
 
     def forward(self, inputs):
         inputs = self._transform_inputs(inputs)
         output = {}
         output['seg'] = self.seghead(inputs)
 
         output['proj'] = self.projhead(inputs)
 
         return output
 
-
-
     def loss_by_feat(self, seg_logits: dict,
                      batch_data_samples: SampleList) -> dict:
         """Compute segmentation loss.
@@ -114,7 +122,7 @@ def loss_by_feat(self, seg_logits: dict,
 
         seg_label = self._stack_batch_gt(batch_data_samples)
         loss = dict()
 
         if self.sampler is not None:
             seg_weight = self.sampler.sample(seg_logits['seg'], seg_label)
         else:
@@ -140,7 +148,9 @@ def loss_by_feat(self, seg_logits: dict,
             ignore_index=self.ignore_index)
 
         loss['acc_seg'] = accuracy(
-            F.interpolate(seg_logits['seg'],seg_label.shape[1:]), seg_label, ignore_index=self.ignore_index)
+            F.interpolate(seg_logits['seg'], seg_label.shape[1:]),
+            seg_label,
+            ignore_index=self.ignore_index)
         return loss
 
     def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
@@ -162,5 +172,3 @@ def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
         seg_logits = self.forward(inputs)['seg']
 
         return self.predict_by_feat(seg_logits, batch_img_metas)
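
The head's forward pass returns a dict with 'seg' logits and L2-normalized 'proj' embeddings. The sketch below is a hypothetical standalone check, not part of this commit: the feature shapes mimic HRNet-18 outputs for an assumed 512x512 input, and the default cross-entropy loss_decode of BaseDecodeHead is used rather than the PixelContrastCrossEntropyLoss from the config.

# Hypothetical forward-pass sketch (not part of this commit). Shapes mimic
# HRNet-18 multi-resolution features; the contrastive loss is not exercised.
import torch

from mmseg.models.decode_heads import HRNetContrastHead

head = HRNetContrastHead(
    in_channels=[18, 36, 72, 144],
    in_index=(0, 1, 2, 3),
    channels=sum([18, 36, 72, 144]),
    input_transform='resize_concat',
    proj_n=256,
    proj_mode='convmlp',
    drop_p=0.1,
    num_classes=19)

# Four HRNet branches at strides 4/8/16/32 for an assumed 512x512 input.
feats = [
    torch.randn(1, c, 128 // 2**i, 128 // 2**i)
    for i, c in enumerate([18, 36, 72, 144])
]
out = head(feats)
print(out['seg'].shape)   # torch.Size([1, 19, 128, 128]), per-class logits
print(out['proj'].shape)  # torch.Size([1, 256, 128, 128]), normalized embeddings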

