[Feature] Add BDD100K dataset (#3158)
## Motivation
Integrate the [BDD100K](https://paperswithcode.com/dataset/bdd100k) dataset.
It shares the same classes as Cityscapes and is commonly used for
evaluating segmentation/detection in driving scenes, e.g. in
[RobustNet](https://arxiv.org/abs/2103.15597) and
[WildNet](https://github.com/suhyeonlee/WildNet).

Enhancement for the feature request "Add BDD100K Dataset" (#2808).

---------

Co-authored-by: xiexinch <xiexinch@outlook.com>
CastleDream and xiexinch committed Jul 14, 2023
1 parent 7254f53 commit 057155d
Showing 33 changed files with 518 additions and 10 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -253,6 +253,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
- [x] [iSAID](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#isaid)
- [x] [Mapillary Vistas](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#mapillary-vistas-datasets)
- [x] [LEVIR-CD](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#levir-cd)
- [x] [BDD100K](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#bdd100k)

</details>

1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -247,6 +247,7 @@ MMSegmentation v1.x offers significant improvements over 0.x, providing
- [x] [iSAID](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/dataset_prepare.md#isaid)
- [x] [Mapillary Vistas](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#mapillary-vistas-datasets)
- [x] [LEVIR-CD](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#levir-cd)
- [x] [BDD100K](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#bdd100k)

</details>

70 changes: 70 additions & 0 deletions configs/_base_/datasets/bdd100k.py
@@ -0,0 +1,70 @@
# dataset settings
dataset_type = 'BDD100KDataset'
data_root = 'data/bdd100k/'

crop_size = (512, 1024)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    # Load annotations after ``Resize`` because the ground truth does not
    # need the resize transform.
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', backend_args=None),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/10k/train',
            seg_map_path='labels/sem_seg/masks/train'),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/10k/val',
            seg_map_path='labels/sem_seg/masks/val'),
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
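For reference, a minimal sketch of a model config that consumes this base dataset file; the `_base_` file names below are assumptions taken from the usual `configs/_base_/` layout and are not part of this commit:

```python
# Sketch of a PSPNet config on BDD100K (the _base_ paths are assumed and may
# need adjusting to where the config is placed in your checkout).
_base_ = [
    '../_base_/models/pspnet_r50-d8.py',
    '../_base_/datasets/bdd100k.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_80k.py'
]
crop_size = (512, 1024)
data_preprocessor = dict(size=crop_size)
model = dict(
    data_preprocessor=data_preprocessor,
    decode_head=dict(num_classes=19),       # 19 Cityscapes-style classes
    auxiliary_head=dict(num_classes=19))
```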
63 changes: 63 additions & 0 deletions docs/en/user_guides/2_dataset_prepare.md
@@ -178,6 +178,26 @@ mmsegmentation
| │ │ │ ├── labels
| │ │ │ ├── panoptic
| │   │   │ └── polygons
│   ├── bdd100k
│   │   ├── images
│   │   │   └── 10k
│   │   │       ├── test
│   │   │       ├── train
│   │   │       └── val
│   │   └── labels
│   │       └── sem_seg
│   │           ├── colormaps
│   │           │   ├── train
│   │           │   └── val
│   │           ├── masks
│   │           │   ├── train
│   │           │   └── val
│   │           ├── polygons
│   │           │   ├── sem_seg_train.json
│   │           │   └── sem_seg_val.json
│   │           └── rles
│   │               ├── sem_seg_train.json
│   │               └── sem_seg_val.json
```

## Cityscapes
@@ -653,3 +673,46 @@ python tools/dataset_converters/levircd.py --dataset-path /path/to/LEVIR-CD+ --o
```

The size of the cropped images is 256x256, consistent with the original paper.

## BDD100K

- You can download the BDD100K dataset from [here](https://bdd-data.berkeley.edu/) after registration.

- Download the images and masks by clicking the `10K Images` and `Segmentation` buttons.

- After downloading, unzip the archives with the following commands:

```bash
unzip ~/bdd100k_images_10k.zip -d ~/mmsegmentation/data/
unzip ~/bdd100k_sem_seg_labels_trainval.zip -d ~/mmsegmentation/data/
```

- You will get the following directory structure (a short sanity-check script follows the tree):

```none
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│   ├── bdd100k
│   │   ├── images
│   │   │   └── 10k
│   │   │       ├── test
│   │   │       ├── train
│   │   │       └── val
│   │   └── labels
│   │       └── sem_seg
│   │           ├── colormaps
│   │           │   ├── train
│   │           │   └── val
│   │           ├── masks
│   │           │   ├── train
│   │           │   └── val
│   │           ├── polygons
│   │           │   ├── sem_seg_train.json
│   │           │   └── sem_seg_val.json
│   │           └── rles
│   │               ├── sem_seg_train.json
│   │               └── sem_seg_val.json
```
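To quickly verify the layout before training, a small check like the following can help (a sketch only; it assumes the paths shown above and that images and masks share file stems, as the dataset class expects):

```python
from pathlib import Path

root = Path('data/bdd100k')
for split in ('train', 'val'):
    # 10k split: image .jpg files and their .png semantic masks
    imgs = {p.stem for p in (root / 'images' / '10k' / split).glob('*.jpg')}
    masks = {p.stem for p in (root / 'labels' / 'sem_seg' / 'masks' / split).glob('*.png')}
    print(f'{split}: {len(imgs)} images, {len(masks)} masks')
    assert imgs == masks, f'{split}: image/mask file names do not match'
```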
63 changes: 63 additions & 0 deletions docs/zh_cn/user_guides/2_dataset_prepare.md
@@ -178,6 +178,26 @@ mmsegmentation
| │ │ │ ├── labels
| │ │ │ ├── panoptic
| │   │   │ └── polygons
│   ├── bdd100k
│   │   ├── images
│   │   │   └── 10k
│   │   │       ├── test
│   │   │       ├── train
│   │   │       └── val
│   │   └── labels
│   │       └── sem_seg
│   │           ├── colormaps
│   │           │   ├── train
│   │           │   └── val
│   │           ├── masks
│   │           │   ├── train
│   │           │   └── val
│   │           ├── polygons
│   │           │   ├── sem_seg_train.json
│   │           │   └── sem_seg_val.json
│   │           └── rles
│   │               ├── sem_seg_train.json
│   │               └── sem_seg_val.json
```

## Cityscapes
@@ -649,3 +669,46 @@ python tools/dataset_converters/levircd.py --dataset-path /path/to/LEVIR-CD+ --o
```

The cropped images are 256x256, consistent with the original paper.

## BDD100K

- You can download the BDD100K dataset from the [official website](https://bdd-data.berkeley.edu/) (the semantic segmentation task mainly uses the 10K image subset). After registering and logging in as required, the data can be found [here](https://bdd-data.berkeley.edu/portal.html#download).

- The images are provided as `10K Images` and the semantic segmentation annotations as `Segmentation`.

- After downloading, unzip the archives with the following commands:

```bash
unzip ~/bdd100k_images_10k.zip -d ~/mmsegmentation/data/
unzip ~/bdd100k_sem_seg_labels_trainval.zip -d ~/mmsegmentation/data/
```

You will get the following directory structure (a quick label check follows the tree):

```none
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│   ├── bdd100k
│   │   ├── images
│   │   │   └── 10k
│   │   │       ├── test
│   │   │       ├── train
│   │   │       └── val
│   │   └── labels
│   │       └── sem_seg
│   │           ├── colormaps
│   │           │   ├── train
│   │           │   └── val
│   │           ├── masks
│   │           │   ├── train
│   │           │   └── val
│   │           ├── polygons
│   │           │   ├── sem_seg_train.json
│   │           │   └── sem_seg_val.json
│   │           └── rles
│   │               ├── sem_seg_train.json
│   │               └── sem_seg_val.json
```
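As an additional check that the `masks` PNGs are single-channel label maps, a sketch like the one below can be used (it assumes `numpy` and `Pillow` are installed; the expected values are the 19 train IDs 0-18 plus 255 for ignored pixels):

```python
from pathlib import Path

import numpy as np
from PIL import Image

# Open any mask from the val split and inspect its label values.
mask_path = next(Path('data/bdd100k/labels/sem_seg/masks/val').glob('*.png'))
mask = np.array(Image.open(mask_path))
print(mask_path.name, mask.shape, mask.dtype)   # expect a 2-D uint8 array
print('label ids present:', np.unique(mask))    # expect values in 0-18 and 255
```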
3 changes: 2 additions & 1 deletion mmseg/datasets/__init__.py
@@ -2,6 +2,7 @@
# yapf: disable
from .ade import ADE20KDataset
from .basesegdataset import BaseCDDataset, BaseSegDataset
from .bdd100k import BDD100KDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
@@ -57,5 +58,5 @@
    'SynapseDataset', 'REFUGEDataset', 'MapillaryDataset_v1',
    'MapillaryDataset_v2', 'Albu', 'LEVIRCDDataset',
    'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile',
    'ConcatCDInput', 'BaseCDDataset', 'DSDLSegDataset'
    'ConcatCDInput', 'BaseCDDataset', 'DSDLSegDataset', 'BDD100KDataset'
]
30 changes: 30 additions & 0 deletions mmseg/datasets/bdd100k.py
@@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.

from mmseg.datasets.basesegdataset import BaseSegDataset
from mmseg.registry import DATASETS


@DATASETS.register_module()
class BDD100KDataset(BaseSegDataset):
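    """BDD100K semantic segmentation dataset.

    The class names and palette are the same as Cityscapes (19 classes).
    In the official release, images use the ``.jpg`` suffix and annotation
    masks use ``.png``; ``reduce_zero_label`` stays ``False`` because the
    label IDs already start from 0.
    """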
    METAINFO = dict(
        classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
                 'traffic light', 'traffic sign', 'vegetation', 'terrain',
                 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
                 'motorcycle', 'bicycle'),
        palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70],
                 [102, 102, 156], [190, 153, 153], [153, 153, 153],
                 [250, 170, 30], [220, 220, 0], [107, 142, 35],
                 [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0],
                 [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
                 [0, 0, 230], [119, 11, 32]])

    def __init__(self,
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
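A minimal usage sketch (not part of the commit): building the val split directly with the paths from the prepared layout; an empty pipeline just indexes the image/mask pairs.

```python
from mmseg.datasets import BDD100KDataset

dataset = BDD100KDataset(
    data_root='data/bdd100k',
    data_prefix=dict(
        img_path='images/10k/val',
        seg_map_path='labels/sem_seg/masks/val'),
    pipeline=[])
print(len(dataset))                     # number of image/mask pairs found
print(dataset.metainfo['classes'][:5])  # ('road', 'sidewalk', 'building', 'wall', 'fence')
```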
23 changes: 22 additions & 1 deletion mmseg/utils/class_names.py
@@ -419,6 +419,26 @@ def lip_palette():
]


def bdd100k_classes():
"""BDD100K class names for external use(the class name is compatible with
Cityscapes )."""
return [
'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
'bicycle'
]


def bdd100k_palette():
"""bdd100k palette for external use(same with cityscapes)"""
return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
[190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
[107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
[0, 0, 230], [119, 11, 32]]


dataset_aliases = {
    'cityscapes': ['cityscapes'],
    'ade': ['ade', 'ade20k'],
@@ -435,7 +455,8 @@ def lip_palette():
    'stare': ['stare', 'STARE'],
    'lip': ['LIP', 'lip'],
    'mapillary_v1': ['mapillary_v1'],
    'mapillary_v2': ['mapillary_v2']
    'mapillary_v2': ['mapillary_v2'],
    'bdd100k': ['bdd100k']
}
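With the alias registered, the usual lookup helpers should resolve the new entry; a small sketch, assuming `get_classes`/`get_palette` from this module are exported via `mmseg.utils` as they are for the existing datasets:

```python
from mmseg.utils import get_classes, get_palette

classes = get_classes('bdd100k')   # resolved through dataset_aliases -> bdd100k_classes()
palette = get_palette('bdd100k')   # -> bdd100k_palette()
assert len(classes) == len(palette) == 19
print(classes[:3], palette[:3])
```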


50 changes: 50 additions & 0 deletions projects/bdd100k_dataset/README.md
@@ -0,0 +1,50 @@
# BDD100K Dataset

Support for the **`BDD100K Dataset`**.

## Description

Author: CastleDream

This project adds support for the **`BDD100K Dataset`**.

### Dataset preparation

Prepare the `BDD100K` dataset following the [BDD100K Dataset Preparing Guide](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#bdd100k)

```none
mmsegmentation/data
└── bdd100k
├── images
│ └── 10k
│ ├── test [2000 entries exceeds filelimit, not opening dir]
│ ├── train [7000 entries exceeds filelimit, not opening dir]
│ └── val [1000 entries exceeds filelimit, not opening dir]
└── labels
└── sem_seg
├── colormaps
│ ├── train [7000 entries exceeds filelimit, not opening dir]
│ └── val [1000 entries exceeds filelimit, not opening dir]
├── masks
│ ├── train [7000 entries exceeds filelimit, not opening dir]
│ └── val [1000 entries exceeds filelimit, not opening dir]
├── polygons
│ ├── sem_seg_train.json
│ └── sem_seg_val.json
└── rles
├── sem_seg_train.json
└── sem_seg_val.json
```

### Training commands

```bash
cd mmsegmentation
python tools/train.py projects/bdd100k_dataset/configs/pspnet_r50-d8_4xb2-80k_bdd100k-512x1024.py \
    --work-dir your_work_dir
```

## Thanks

- [\[Datasets\] Add Mapillary Vistas Datasets to MMSeg Core Package. #2576](https://github.com/open-mmlab/mmsegmentation/pull/2576/files)
- [\[Feature\] Support CIHP dataset #1493](https://github.com/open-mmlab/mmsegmentation/pull/1493/files)
