Created KITTI dataset for segmentation in autonomous driving scenario (#2730)

Note that this PR is a modified version of the withdrawn PR
#1748

## Motivation

In recent years, panoptic segmentation has moved into the focus of
research. Weber et al.
[[Link]](http://www.cvlibs.net/publications/Weber2021NEURIPSDATA.pdf)
have published a dataset in the same style as Cityscapes, but for KITTI
sequences. Since Cityscapes and KITTI-STEP share the same classes and a
comparable domain (dashcam view), they allow interesting investigations,
e.g. of the relations between the two domains.

Note that KITTI-STEP provides panoptic segmentation annotations, which
are out of scope for mmsegmentation.

## Modification

Mostly, I added the new dataset and dataset preparation file. To
simplify the first usage of the new dataset, I also added configs for
the dataset, segformer and deeplabv3plus.

## BC-breaking (Optional)

No BC-breaking

## Use cases (Optional)

Researchers want to test their new methods, e.g. for interpretable AI in
the context of semantic segmentation. They want to show that their
method is reproducible on comparable datasets. Thus, they can compare
Cityscapes and KITTI-STEP.

---------

Co-authored-by: CSH <40987381+csatsurnh@users.noreply.github.com>
Co-authored-by: csatsurnh <cshan1995@126.com>
Co-authored-by: 谢昕辰 <xiexinch@outlook.com>
4 people committed May 9, 2023
1 parent 83e7cc2 commit a85675c
Showing 9 changed files with 307 additions and 0 deletions.
97 changes: 97 additions & 0 deletions projects/kitti_step_dataset/README.md
@@ -0,0 +1,97 @@
# KITTI STEP Dataset

Support **`KITTI STEP Dataset`**

## Description

Author: TimoK93

This project implements **`KITTI STEP Dataset`**

### Dataset preparation

After registration, the images can be downloaded from [KITTI-STEP](http://www.cvlibs.net/datasets/kitti/eval_step.php).

After downloading the KITTI-STEP dataset, arrange it in the following structure:

```
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│ ├── kitti_step
│ │ ├── testing
│ │ ├── training
│ │ ├── panoptic_maps
```

Run the preparation script to generate the label files and the train/val image subsets:

```shell
python projects/kitti_step_dataset/tools/convert_datasets/kitti_step.py /path/to/kitti_step
```
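
As a rough illustration of what the label generation amounts to, here is a minimal sketch. It assumes the KITTI-STEP encoding in which the red channel of a panoptic map stores the semantic class ID and the green and blue channels store the instance ID; the helper names and the sample pixels are hypothetical and not part of this project. The preparation script keeps only the semantic channel.

```python
def decode_panoptic_pixel(r, g, b):
    """Split one RGB panoptic pixel into (semantic_id, instance_id).

    Assumed encoding: R = semantic ID, G * 256 + B = instance ID.
    """
    semantic_id = r
    instance_id = g * 256 + b
    return semantic_id, instance_id


def to_semantic_row(rgb_row):
    """Keep only the semantic ID of each pixel in a row of (R, G, B) tuples."""
    return [decode_panoptic_pixel(*pixel)[0] for pixel in rgb_row]


# Two made-up pixels: class 13 with instance 5, and class 0 without an instance.
row = [(13, 0, 5), (0, 0, 0)]
print(to_semantic_row(row))
```

Dropping the instance channels is what turns the panoptic annotation into a plain semantic segmentation label.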

After executing the script, your directory should look like this:

```
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│ ├── kitti_step
│ │ ├── testing
│ │ ├── training
│ │ ├── panoptic_maps
│ │ ├── training_openmmlab
│ │ ├── panoptic_maps_openmmlab
```
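
The mapping from the original files to the generated ones can be sketched as follows. This is a hedged illustration derived from the string replacements in the preparation script, not a replacement for it; `map_paths`, the sequence name `0002` and the frame name `000000` are hypothetical examples.

```python
from pathlib import PurePosixPath


def map_paths(root, split, sequence, frame):
    """Return (source, target) path pairs for one frame.

    `split` is 'train' or 'val'; `sequence` and `frame` are illustrative names.
    """
    root = PurePosixPath(root)
    label_src = root / 'panoptic_maps' / split / sequence / f'{frame}.png'
    label_dst = (root / 'panoptic_maps_openmmlab' / split / sequence /
                 f'{frame}_labelTrainIds.png')
    # Original images are not separated by split; the copies are.
    image_src = root / 'training' / 'image_02' / sequence / f'{frame}.png'
    image_dst = (root / 'training_openmmlab' / 'image_02' / split / sequence /
                 f'{frame}.png')
    return {'label': (label_src, label_dst), 'image': (image_src, image_dst)}


paths = map_paths('data/kitti_step', 'val', '0002', '000000')
for kind, (src, dst) in paths.items():
    print(f'{kind}: {src} -> {dst}')
```

The target directories match the `img_dir` and `ann_dir` values used in the dataset config.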

### Training commands

```bash
# Train on the dataset with 8 GPUs (the trailing argument)
# run from the `mmsegmentation` folder
bash tools/dist_train.sh projects/kitti_step_dataset/configs/segformer/segformer_mit-b5_368x368_160k_kittistep.py 8
```

### Testing commands

```bash
mim test mmsegmentation projects/kitti_step_dataset/configs/segformer/segformer_mit-b5_368x368_160k_kittistep.py --work-dir work_dirs/segformer_mit-b5_368x368_160k_kittistep --checkpoint ${CHECKPOINT_PATH} --eval mIoU
```

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | model | log |
| --------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | ---------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Segformer | MIT-B5 | 368x368 | 160000 | - | - | 65.05 | - | [config](configs/segformer/segformer_mit-b5_368x368_160k_kittistep.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_368x368_160k_kittistep/segformer_mit-b5_368x368_160k_kittistep_20230506_103002-20797496.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_368x368_160k_kittistep/segformer_mit-b5_368x368_160k_kittistep_20230506_103002.log.json) |
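
The mIoU column is the mean intersection over union across classes, in percent. The following is a minimal sketch of that metric on flat lists of class IDs, assuming label 255 marks ignored pixels (as `seg_pad_val=255` in the dataset config suggests). `mean_iou` and the toy inputs are hypothetical; this is not the evaluation code mmsegmentation runs.

```python
def mean_iou(preds, labels, num_classes, ignore_index=255):
    """Compute mean IoU over flat lists of predicted and ground-truth class IDs."""
    intersect = [0] * num_classes
    pred_area = [0] * num_classes
    label_area = [0] * num_classes
    for p, t in zip(preds, labels):
        if t == ignore_index:
            continue  # ignored pixels do not count toward any class
        pred_area[p] += 1
        label_area[t] += 1
        if p == t:
            intersect[p] += 1
    ious = []
    for c in range(num_classes):
        union = pred_area[c] + label_area[c] - intersect[c]
        if union > 0:  # skip classes absent from both prediction and label
            ious.append(intersect[c] / union)
    return sum(ious) / len(ious) if ious else float('nan')


# Toy example with three classes and one ignored pixel.
preds = [0, 0, 1, 1, 2]
labels = [0, 1, 1, 255, 2]
print(f'mIoU: {100 * mean_iou(preds, labels, num_classes=3):.2f}')
```

In practice the intersections and unions are accumulated over the whole validation set before averaging, rather than per image.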

## Checklist

- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.

  - [x] Finish the code

  - [x] Basic docstrings & proper citation

  - [ ] Test-time correctness

  - [x] A full README

- [ ] Milestone 2: Indicates a successful model implementation.

  - [ ] Training-time correctness

- [ ] Milestone 3: Good to be a part of our core package!

  - [ ] Type hints and docstrings

  - [ ] Unit tests

  - [ ] Code polishing

  - [ ] Metafile.yml

- [ ] Move your modules into the core package following the codebase's file hierarchy structure.

- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
54 changes: 54 additions & 0 deletions projects/kitti_step_dataset/configs/_base_/datasets/kittistep.py
@@ -0,0 +1,54 @@
# dataset settings
dataset_type = 'KITTISTEPDataset'
data_root = 'data/kitti_step/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (368, 368)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(1242, 375), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1242, 375),
        img_ratios=[1.0],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='training_openmmlab/image_02/train',
        ann_dir='panoptic_maps_openmmlab/train',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='training_openmmlab/image_02/val',
        ann_dir='panoptic_maps_openmmlab/val',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='training_openmmlab/image_02/val',
        ann_dir='panoptic_maps_openmmlab/val',
        pipeline=test_pipeline))
@@ -0,0 +1,10 @@
_base_ = [
    '../../../../configs/_base_/models/deeplabv3plus_r50-d8.py',
    '../_base_/datasets/kittistep.py',
    '../../../../configs/_base_/default_runtime.py',
    '../../../../configs/_base_/schedules/schedule_80k.py'
]
model = dict(
    decode_head=dict(align_corners=True),
    auxiliary_head=dict(align_corners=True),
    test_cfg=dict(mode='slide', crop_size=(769, 769), stride=(513, 513)))
@@ -0,0 +1,38 @@
_base_ = [
    '../../../../configs/_base_/models/segformer_mit-b0.py',
    '../_base_/datasets/kittistep.py',
    '../../../../configs/_base_/default_runtime.py',
    '../../../../configs/_base_/schedules/schedule_160k.py'
]

checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b0_20220624-7e0fe6dd.pth'  # noqa

model = dict(
    backbone=dict(init_cfg=dict(type='Pretrained', checkpoint=checkpoint)),
    test_cfg=dict(mode='slide', crop_size=(1024, 1024), stride=(768, 768)))

# optimizer
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_block': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            'head': dict(lr_mult=10.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

data = dict(samples_per_gpu=2, workers_per_gpu=2)
@@ -0,0 +1,9 @@
_base_ = ['./segformer_mit-b0_368x368_160k_kittistep.py']

checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b5_20220624-658746d9.pth' # noqa
model = dict(
    backbone=dict(
        init_cfg=dict(type='Pretrained', checkpoint=checkpoint),
        embed_dims=64,
        num_layers=[3, 6, 40, 3]),
    decode_head=dict(in_channels=[64, 128, 320, 512]))
6 changes: 6 additions & 0 deletions projects/kitti_step_dataset/mmseg/datasets/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .kitti_step import KITTISTEPDataset

__all__ = [
    'KITTISTEPDataset',
]
15 changes: 15 additions & 0 deletions projects/kitti_step_dataset/mmseg/datasets/kitti_step.py
@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.cityscapes import CityscapesDataset


@DATASETS.register_module()
class KITTISTEPDataset(CityscapesDataset):
    """KITTI-STEP dataset."""

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='_labelTrainIds.png',
                 **kwargs):
        super(KITTISTEPDataset, self).__init__(
            img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
77 changes: 77 additions & 0 deletions projects/kitti_step_dataset/tools/convert_datasets/kitti_step.py
@@ -0,0 +1,77 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import shutil

import cv2
import mmcv


def kitti_to_train_ids(input):
    src, gt_dir, new_gt_dir = input
    label_file = src.replace('.png',
                             '_labelTrainIds.png').replace(gt_dir, new_gt_dir)
    img = cv2.imread(src)
    dirname = os.path.dirname(label_file)
    os.makedirs(dirname, exist_ok=True)
    # cv2 loads images as BGR, so index 2 is the red channel of the map.
    sem_seg = img[:, :, 2]
    cv2.imwrite(label_file, sem_seg)


def copy_file(input):
    src, dst = input
    if not osp.exists(dst):
        os.makedirs(osp.dirname(dst), exist_ok=True)
        shutil.copyfile(src, dst)


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert KITTI-STEP annotations to TrainIds')
    parser.add_argument('kitti_path', help='kitti step data path')
    parser.add_argument('--gt-dir', default='panoptic_maps', type=str)
    parser.add_argument('-o', '--out-dir', help='output path')
    parser.add_argument(
        '--nproc', default=1, type=int, help='number of process')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    kitti_path = args.kitti_path
    out_dir = args.out_dir if args.out_dir else kitti_path
    mmcv.mkdir_or_exist(out_dir)

    gt_dir = osp.join(kitti_path, args.gt_dir)

    # Collect every panoptic map and write its semantic label file.
    ann_files = []
    for poly in mmcv.scandir(gt_dir, '.png', recursive=True):
        poly_file = osp.join(gt_dir, poly)
        ann_files.append([poly_file, args.gt_dir, args.gt_dir + '_openmmlab'])

    if args.nproc > 1:
        mmcv.track_parallel_progress(kitti_to_train_ids, ann_files, args.nproc)
    else:
        mmcv.track_progress(kitti_to_train_ids, ann_files)

    # Copy the matching images into separate train and val folders.
    copy_files = []
    for f in mmcv.scandir(gt_dir, '.png', recursive=True):
        original_f = osp.join(gt_dir, f).replace(args.gt_dir + '/train',
                                                 'training/image_02')
        new_f = osp.join(gt_dir, f).replace(args.gt_dir,
                                            'training_openmmlab/image_02')
        original_f = original_f.replace(args.gt_dir + '/val',
                                        'training/image_02')
        new_f = new_f.replace(args.gt_dir, 'training_openmmlab/image_02')
        copy_files.append([original_f, new_f])

    if args.nproc > 1:
        mmcv.track_parallel_progress(copy_file, copy_files, args.nproc)
    else:
        mmcv.track_progress(copy_file, copy_files)


if __name__ == '__main__':
    main()
1 change: 1 addition & 0 deletions requirements/docs.txt
@@ -4,3 +4,4 @@ myst-parser
sphinx==4.0.2
sphinx_copybutton
sphinx_markdown_tables
urllib3<2.0.0
