Add support for sample weighting (dataset imbalance) #2738

Open · wants to merge 1 commit into master
2 changes: 2 additions & 0 deletions mmseg/datasets/pipelines/loading.py
@@ -147,6 +147,8 @@ def __call__(self, results):
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in results['label_map'].items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
        if 'weight' in results['ann_info']:
            results['weight'] = results['ann_info']['weight']
        results['gt_semantic_seg'] = gt_semantic_seg
        results['seg_fields'].append('gt_semantic_seg')
        return results
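For the new branch above to do anything, the dataset has to put a 'weight' entry into ann_info; the PR itself contains no dataset change. A minimal sketch of one way to supply it is below; WeightedDataset and the sample_weights mapping are illustrative names, not part of mmsegmentation or this diff.

# Illustrative only: a dataset subclass that attaches a per-sample weight
# (e.g. larger weights for images of under-represented classes) to the
# annotation record that LoadAnnotations later reads.
from mmseg.datasets import DATASETS
from mmseg.datasets.custom import CustomDataset


@DATASETS.register_module()
class WeightedDataset(CustomDataset):

    def __init__(self, sample_weights=None, **kwargs):
        super().__init__(**kwargs)
        # hypothetical mapping: image filename -> float weight
        self.sample_weights = sample_weights or {}

    def get_ann_info(self, idx):
        ann_info = super().get_ann_info(idx)
        ann_info['weight'] = self.sample_weights.get(
            self.img_infos[idx]['filename'], 1.0)
        return ann_info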
25 changes: 21 additions & 4 deletions mmseg/models/decode_heads/decode_head.py
@@ -240,7 +240,10 @@ def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self(inputs)
        losses = self.losses(seg_logits, gt_semantic_seg)
        weight = None
        if 'weight' in img_metas[0]:
            # keep the per-image weights on the same device as the logits so
            # they can later be combined with the sampler's seg_weight
            weight = torch.tensor(
                [meta['weight'] for meta in img_metas],
                device=seg_logits.device)
        losses = self.losses(seg_logits, gt_semantic_seg, weight)
        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
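forward_train reads the weight from img_metas, so the training pipeline also has to carry results['weight'] into the image metas. A hedged sketch of what the Collect step could look like follows; adding 'weight' to meta_keys is an assumption about the intended usage, and the surrounding transforms are only placeholders.

# Illustrative train pipeline: extending Collect's meta_keys with 'weight' is
# an assumption about how results['weight'] would reach img_metas; it is not
# part of this diff.
crop_size = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='RandomCrop', crop_size=crop_size),
    dict(type='RandomFlip', prob=0.5),
    dict(type='Normalize',
         mean=[123.675, 116.28, 103.53],
         std=[58.395, 57.12, 57.375],
         to_rgb=True),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect',
         keys=['img', 'gt_semantic_seg'],
         meta_keys=('filename', 'ori_filename', 'ori_shape', 'img_shape',
                    'pad_shape', 'scale_factor', 'flip', 'flip_direction',
                    'img_norm_cfg', 'weight')),
]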
@@ -268,7 +271,7 @@ def cls_seg(self, feat):
        return output

    @force_fp32(apply_to=('seg_logit', ))
    def losses(self, seg_logit, seg_label):
    def losses(self, seg_logit, seg_label, weight=None):
        """Compute segmentation loss."""
        loss = dict()
        if self.downsample_label_ratio > 0:
@@ -283,12 +286,26 @@ def losses(self, seg_logit, seg_label):
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)

        # seg weight mask from sampler
        seg_weight = None
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logit, seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)

        # per-image (item) weight
        if weight is not None:
            # expand so it broadcasts over the spatial dimensions of seg_label
            for _ in range(seg_label.dim() - 1):
                weight = weight.unsqueeze(-1)

        # combine the per-image weight with the per-pixel sampler weight
        if weight is not None and seg_weight is not None:
            seg_weight = seg_weight * weight
        elif seg_weight is None:
            seg_weight = weight

        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
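The expand-and-multiply logic above can be checked in isolation. A small standalone sketch with made-up shapes and values:

# Standalone illustration of the weight combination in losses(): a per-image
# weight of shape (N,) is expanded to (N, 1, 1) so it broadcasts against a
# per-pixel sampler weight of shape (N, H, W).
import torch

seg_weight = torch.ones(2, 4, 4)       # e.g. an OHEM mask from the sampler
weight = torch.tensor([1.0, 3.0])      # per-image weights, shape (2,)
for _ in range(seg_weight.dim() - 1):  # same expansion as in losses()
    weight = weight.unsqueeze(-1)      # (2,) -> (2, 1) -> (2, 1, 1)
combined = seg_weight * weight         # shape (2, 4, 4)
assert combined[1].allclose(torch.full((4, 4), 3.0))

Keeping the per-image weight as an (N, 1, 1) tensor means the same code path works whether or not a pixel sampler such as OHEM is configured.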