diff --git a/.circleci/test.yml b/.circleci/test.yml index 414e3c4ced..d460690065 100644 --- a/.circleci/test.yml +++ b/.circleci/test.yml @@ -61,9 +61,9 @@ jobs: command: | pip install git+https://github.com/open-mmlab/mmengine.git@main pip install -U openmim - mim install 'mmcv>=2.0.0rc3' + mim install 'mmcv==2.0.0rc3' pip install git+https://github.com/open-mmlab/mmclassification@dev-1.x - pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + mim install 'mmdet==3.0.0rc5' pip install -r requirements/tests.txt -r requirements/optional.txt - run: name: Build and install @@ -97,7 +97,6 @@ jobs: command: | git clone -b main --depth 1 https://github.com/open-mmlab/mmengine.git /home/circleci/mmengine git clone -b dev-1.x --depth 1 https://github.com/open-mmlab/mmclassification.git /home/circleci/mmclassification - git clone -b dev-3.x --depth 1 https://github.com/open-mmlab/mmdetection.git /home/circleci/mmdetection - run: name: Build Docker image command: | @@ -108,9 +107,9 @@ jobs: command: | docker exec mmseg pip install -e /mmengine docker exec mmseg pip install -U openmim - docker exec mmseg mim install 'mmcv>=2.0.0rc3' + docker exec mmseg mim install 'mmcv==2.0.0rc3' docker exec mmseg pip install -e /mmclassification - docker exec mmseg pip install -e /mmdetection + docker exec mmseg mim install 'mmdet==3.0.0rc5' docker exec mmseg pip install -r requirements/tests.txt -r requirements/optional.txt - run: name: Build and install @@ -149,13 +148,14 @@ workflows: name: minimum_version_cpu torch: 1.6.0 torchvision: 0.7.0 - python: 3.6.9 # The lowest python 3.6.x version available on CircleCI images + python: "3.7" requires: - lint - build_cpu: name: maximum_version_cpu - torch: 1.13.0 - torchvision: 0.14.0 + # TODO: Fix torch 1.13 forward crash + torch: 1.12.0 + torchvision: 0.13.0 python: 3.9.0 requires: - minimum_version_cpu @@ -187,4 +187,3 @@ workflows: only: - dev-1.x - 1.x - - master
diff --git a/.github/workflows/merge_stage_test.yml b/.github/workflows/merge_stage_test.yml index 65af4d2bcc..7728392481 100644 --- a/.github/workflows/merge_stage_test.yml +++ b/.github/workflows/merge_stage_test.yml @@ -22,7 +22,7 @@ jobs: runs-on: ubuntu-18.04 strategy: matrix: - python-version: [3.6, 3.8, 3.9] + python-version: [3.8, 3.9] torch: [1.8.1] include: - torch: 1.8.1 @@ -44,9 +44,9 @@ jobs: python -V pip install -U openmim pip install git+https://github.com/open-mmlab/mmengine.git - mim install 'mmcv>=2.0.0rc3' + mim install 'mmcv==2.0.0rc3' pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x - pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + mim install 'mmdet==3.0.0rc5' - name: Install unittest dependencies run: pip install -r requirements/tests.txt -r requirements/optional.txt - name: Build and install @@ -100,9 +100,9 @@ jobs: python -V pip install -U openmim pip install git+https://github.com/open-mmlab/mmengine.git - mim install 'mmcv>=2.0.0rc3' + mim install 'mmcv==2.0.0rc3' pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x - pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + mim install 'mmdet==3.0.0rc5' - name: Install unittest dependencies run: pip install -r requirements/tests.txt -r requirements/optional.txt - name: Build and install @@ -166,9 +166,9 @@ jobs: python -V pip install -U openmim pip install git+https://github.com/open-mmlab/mmengine.git - mim install 'mmcv>=2.0.0rc3' + mim install 'mmcv==2.0.0rc3' pip install
git+https://github.com/open-mmlab/mmclassification.git@dev-1.x - pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + mim install 'mmdet==3.0.0rc5' - name: Install unittest dependencies run: pip install -r requirements/tests.txt -r requirements/optional.txt - name: Build and install @@ -209,9 +209,9 @@ jobs: python -V pip install -U openmim pip install git+https://github.com/open-mmlab/mmengine.git - mim install 'mmcv>=2.0.0rc3' + mim install 'mmcv==2.0.0rc3' pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x - pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + mim install 'mmdet==3.0.0rc5' - name: Install unittest dependencies run: pip install -r requirements/tests.txt -r requirements/optional.txt - name: Build and install @@ -244,9 +244,9 @@ jobs: python -V pip install -U openmim pip install git+https://github.com/open-mmlab/mmengine.git - mim install 'mmcv>=2.0.0rc3' + mim install 'mmcv==2.0.0rc3' pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x - pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + mim install 'mmdet==3.0.0rc5' - name: Install unittest dependencies run: pip install -r requirements/tests.txt -r requirements/optional.txt - name: Build and install diff --git a/.github/workflows/pr_stage_test.yml b/.github/workflows/pr_stage_test.yml index 66661ec8f3..df73baba8e 100644 --- a/.github/workflows/pr_stage_test.yml +++ b/.github/workflows/pr_stage_test.yml @@ -44,9 +44,9 @@ jobs: run: | pip install -U openmim pip install git+https://github.com/open-mmlab/mmengine.git - mim install 'mmcv>=2.0.0rc3' + mim install 'mmcv==2.0.0rc3' pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x - pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + mim install 'mmdet==3.0.0rc5' - name: Install unittest dependencies run: pip install -r requirements/tests.txt -r requirements/optional.txt - name: Build and install @@ -100,9 +100,9 @@ jobs: python -V pip install -U openmim pip install git+https://github.com/open-mmlab/mmengine.git - mim install 'mmcv>=2.0.0rc3' + mim install 'mmcv==2.0.0rc3' pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x - pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + mim install 'mmdet==3.0.0rc5' - name: Install unittest dependencies run: pip install -r requirements/tests.txt -r requirements/optional.txt - name: Build and install @@ -135,9 +135,9 @@ jobs: python -V pip install -U openmim pip install git+https://github.com/open-mmlab/mmengine.git - mim install 'mmcv>=2.0.0rc3' + mim install 'mmcv==2.0.0rc3' pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x - pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + mim install 'mmdet==3.0.0rc5' - name: Install unittest dependencies run: pip install -r requirements/tests.txt -r requirements/optional.txt - name: Build and install diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 03b537683a..70952b7c9e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,8 +3,8 @@ repos: rev: 5.0.4 hooks: - id: flake8 - - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + - repo: https://github.com/zhouzaida/isort + rev: 5.12.1 hooks: - id: isort - repo: https://github.com/pre-commit/mirrors-yapf diff --git a/README.md b/README.md index f1d25cc0f6..d42be540dc 100644 --- a/README.md +++ b/README.md @@ -62,11 +62,11 @@ The 1.x branch works with **PyTorch 1.6+**. 
## What's New -v1.0.0rc3 was released in 31/12/2022. +v1.0.0rc4 was released on 30/01/2023. Please refer to [changelog.md](docs/en/notes/changelog.md) for details and release history. -- Support test time augmentation ([#2184](https://github.com/open-mmlab/mmsegmentation/pull/2184)) -- Add 'Projects/' folder and the first example project ([#2412](https://github.com/open-mmlab/mmsegmentation/pull/2412)) +- Support ISNet (ICCV'2021) in projects ([#2400](https://github.com/open-mmlab/mmsegmentation/pull/2400)) +- Support HSSN (CVPR'2022) in projects ([#2444](https://github.com/open-mmlab/mmsegmentation/pull/2444)) ## Installation diff --git a/README_zh-CN.md b/README_zh-CN.md index f31f816834..bbebab5d04 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -61,7 +61,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O ## 更新日志 -最新版本 v1.0.0rc3 在 2022.12.31 发布。 +最新版本 v1.0.0rc4 在 2023.01.30 发布。 如果想了解更多版本更新细节和历史信息,请阅读[更新日志](docs/en/notes/changelog.md)。 ## 安装 diff --git a/configs/_base_/datasets/synapse.py b/configs/_base_/datasets/synapse.py new file mode 100644 index 0000000000..86852918cd --- /dev/null +++ b/configs/_base_/datasets/synapse.py @@ -0,0 +1,41 @@ +dataset_type = 'SynapseDataset' +data_root = 'data/synapse/' +img_scale = (224, 224) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', scale=img_scale, keep_ratio=True), + dict(type='RandomRotFlip', rotate_prob=0.5, flip_prob=0.5, degree=20), + dict(type='PackSegInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=img_scale, keep_ratio=True), + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') +] +train_dataloader = dict( + batch_size=6, + num_workers=2, + persistent_workers=True, + sampler=dict(type='InfiniteSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='img_dir/train', seg_map_path='ann_dir/train'), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(img_path='img_dir/val', seg_map_path='ann_dir/val'), + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mDice']) +test_evaluator = val_evaluator diff --git a/docker/serve/Dockerfile b/docker/serve/Dockerfile index cea2694d6f..2dddc6cdf3 100644 --- a/docker/serve/Dockerfile +++ b/docker/serve/Dockerfile @@ -4,7 +4,7 @@ ARG CUDNN="8" FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel ARG MMCV="2.0.0rc3" -ARG MMSEG="1.0.0rc3" +ARG MMSEG="1.0.0rc4" ENV PYTHONUNBUFFERED TRUE diff --git a/docs/en/advanced_guides/datasets.md b/docs/en/advanced_guides/datasets.md index 157ea3aad8..733e2a26d9 100644 --- a/docs/en/advanced_guides/datasets.md +++ b/docs/en/advanced_guides/datasets.md @@ -1 +1,386 @@ -# Datasets +# Dataset + +Dataset classes in MMSegmentation have two functions: (1) load data information after [data preparation](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/user_guides/2_dataset_prepare.md) +and (2) send data into [dataset transform pipeline](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/datasets/basesegdataset.py#L141) to do [data augmentation](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/transforms.md). 
+There are two kinds of loaded information: (1) meta information, which is information about the dataset itself, such as the categories (classes) of the dataset and their corresponding palette, and (2) data information, which includes +the paths of dataset images and labels. +This tutorial covers the main interfaces of the MMSegmentation 1.x dataset classes: the methods of the base dataset class for loading data information and for modifying the dataset classes, and the relationship between datasets and the data transform pipeline. + +## Main Interfaces + +Take Cityscapes as an example. If you want to run the example, please download and [preprocess](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/user_guides/2_dataset_prepare.md#cityscapes) +the Cityscapes dataset into the `data` directory before running the demo code: + +Instantiate the Cityscapes training dataset: + +```python +from mmseg.datasets import CityscapesDataset +from mmseg.utils import register_all_modules +register_all_modules() + +data_root = 'data/cityscapes/' +data_prefix=dict(img_path='leftImg8bit/train', seg_map_path='gtFine/train') +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomCrop', crop_size=(512, 1024), cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PackSegInputs') +] + +dataset = CityscapesDataset(data_root=data_root, data_prefix=data_prefix, test_mode=False, pipeline=train_pipeline) +``` + +Get the length of the training set: + +```python +print(len(dataset)) + +2975 +``` + +Get data information: the type of data information is `dict`, which includes several keys: + +- `'img_path'`: path of images +- `'seg_map_path'`: path of segmentation labels +- `'seg_fields'`: fields that store segmentation labels +- `'sample_idx'`: the index of the current sample + +There are also `'label_map'` and `'reduce_zero_label'`, whose functions are introduced in the next section. + +```python +# Acquire data information of first sample in dataset +print(dataset.get_data_info(0)) + +{'img_path': 'data/cityscapes/leftImg8bit/train/aachen/aachen_000000_000019_leftImg8bit.png', + 'seg_map_path': 'data/cityscapes/gtFine/train/aachen/aachen_000000_000019_gtFine_labelTrainIds.png', + 'label_map': None, + 'reduce_zero_label': False, + 'seg_fields': [], + 'sample_idx': 0} +``` + +Get dataset meta information: the type of MMSegmentation meta information is also `dict`, which includes the `'classes'` field for dataset classes, the `'palette'` field for the corresponding colors in visualization, and the `'label_map'` and `'reduce_zero_label'` fields. + +```python +print(dataset.metainfo) + +{'classes': ('road', + 'sidewalk', + 'building', + 'wall', + 'fence', + 'pole', + 'traffic light', + 'traffic sign', + 'vegetation', + 'terrain', + 'sky', + 'person', + 'rider', + 'car', + 'truck', + 'bus', + 'train', + 'motorcycle', + 'bicycle'), + 'palette': [[128, 64, 128], + [244, 35, 232], + [70, 70, 70], + [102, 102, 156], + [190, 153, 153], + [153, 153, 153], + [250, 170, 30], + [220, 220, 0], + [107, 142, 35], + [152, 251, 152], + [70, 130, 180], + [220, 20, 60], + [255, 0, 0], + [0, 0, 142], + [0, 0, 70], + [0, 60, 100], + [0, 80, 100], + [0, 0, 230], + [119, 11, 32]], + 'label_map': None, + 'reduce_zero_label': False} +```
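As a quick aside, this meta information can be consumed directly; a minimal sketch (the two-line lookup helper below is hypothetical, not an MMSegmentation API; `dataset` is the Cityscapes training dataset instantiated above):

```python
# Hypothetical helper: build a class-name -> palette-color lookup
# from the metainfo printed above.
name2color = dict(zip(dataset.metainfo['classes'], dataset.metainfo['palette']))
print(name2color['car'])  # [0, 0, 142]
```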
+The return value of the dataset `__getitem__` method is the output data sample after data augmentation, whose type is also `dict`. It has two fields: `'inputs'`, corresponding to the image after data augmentation, +and `'data_samples'`, corresponding to [`SegDataSample`](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/structures.md), a new data structure in MMSegmentation 1.x, +whose `gt_sem_seg` field holds the labels after the data augmentation operations. + +```python +print(dataset[0]) + +{'inputs': tensor([[[131, 130, 130, ..., 23, 23, 23], + [132, 132, 132, ..., 23, 22, 23], + [134, 133, 133, ..., 23, 23, 23], + ..., + [ 66, 67, 67, ..., 71, 71, 71], + [ 66, 67, 66, ..., 68, 68, 68], + [ 67, 67, 66, ..., 70, 70, 70]], + + [[143, 143, 142, ..., 28, 28, 29], + [145, 145, 145, ..., 28, 28, 29], + [145, 145, 145, ..., 27, 28, 29], + ..., + [ 75, 75, 76, ..., 80, 81, 81], + [ 75, 76, 75, ..., 80, 80, 80], + [ 77, 76, 76, ..., 82, 82, 82]], + + [[126, 125, 126, ..., 21, 21, 22], + [127, 127, 128, ..., 21, 21, 22], + [127, 127, 126, ..., 21, 21, 22], + ..., + [ 63, 63, 64, ..., 69, 69, 70], + [ 64, 65, 64, ..., 69, 69, 69], + [ 65, 66, 66, ..., 72, 71, 71]]], dtype=torch.uint8), + 'data_samples': <SegDataSample(..., + _gt_sem_seg: ... + )>} +``` + +## BaseSegDataset + +As mentioned above, all dataset classes share the same two functions, so MMSegmentation implements [`BaseSegDataset`](https://mmsegmentation.readthedocs.io/en/dev-1.x/api.html?highlight=BaseSegDataset#mmseg.datasets.BaseSegDataset) to reuse the common logic. +It inherits from [`BaseDataset` of MMEngine](https://github.com/open-mmlab/mmengine/blob/main/docs/en/advanced_tutorials/basedataset.md), follows the unified OpenMMLab initialization process, supports an efficient internal data storage format, and provides functions like +dataset concatenation and repeated sampling. In MMSegmentation's `BaseSegDataset`, the **method of loading data information** (`load_data_list`) is redefined, and a new `get_label_map` method is added to **modify the dataset class information**. + +### Loading Dataset Information + +The loaded data information includes the paths of image samples and annotation samples; the detailed implementation can be found in +[`load_data_list`](https://github.com/open-mmlab/mmsegmentation/blob/163277bfe0fa8fefb63ee5137917fafada1b301c/mmseg/datasets/basesegdataset.py#L231) of `BaseSegDataset` in MMSegmentation. +There are two main ways to acquire the paths of images and labels: + +1. Load file paths according to the directory and suffix of input images and annotations + +If the dataset directory structure is organized as below, [`load_data_list`](https://github.com/open-mmlab/mmsegmentation/blob/163277bfe0fa8fefb63ee5137917fafada1b301c/mmseg/datasets/basesegdataset.py#L231) can parse the dataset directory structure: + +``` +├── data +│ ├── my_dataset +│ │ ├── img_dir +│ │ │ ├── train +│ │ │ │ ├── xxx{img_suffix} +│ │ │ │ ├── yyy{img_suffix} +│ │ │ ├── val +│ │ │ │ ├── zzz{img_suffix} +│ │ ├── ann_dir +│ │ │ ├── train +│ │ │ │ ├── xxx{seg_map_suffix} +│ │ │ │ ├── yyy{seg_map_suffix} +│ │ │ ├── val +│ │ │ │ ├── zzz{seg_map_suffix} +``` + +Here is an example of ADE20K; below is the directory structure of the dataset: + +``` +├── ade +│ ├── ADEChallengeData2016 +│ │ ├── annotations +│ │ │ ├── training +│ │ │ │ ├── ADE_train_00000001.png +│ │ │ │ ├── ... +│ │ │ │── validation +│ │ │ │ ├── ADE_val_00000001.png +│ │ │ │ ├── ... +│ │ ├── images +│ │ │ ├── training +│ │ │ │ ├── ADE_train_00000001.jpg +│ │ │ │ ├── ... +│ │ │ ├── validation +│ │ │ │ ├── ADE_val_00000001.jpg +│ │ │ │ ├── ...
+``` + +```python +from mmseg.datasets import ADE20KDataset + +ADE20KDataset(data_root='data/ade/ADEChallengeData2016', + data_prefix=dict(img_path='images/training', seg_map_path='annotations/training'), + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_zero_label=True) +``` + +2. Load file paths from an annotation file + +A dataset can also load an annotation file that lists the data sample paths of the dataset. +Take the PascalContext dataset as an example; its input annotation file is: + +```python +2008_000008 +... +``` + +The `ann_file` argument needs to be defined at instantiation: + +```python +PascalContextDataset(data_root='data/VOCdevkit/VOC2010/', + data_prefix=dict(img_path='JPEGImages', seg_map_path='SegmentationClassContext'), + ann_file='ImageSets/SegmentationContext/train.txt') +``` + +### Modification of Dataset Classes + +- Use `metainfo` input argument + +Meta information is defined as class variables, such as the `METAINFO` variable of Cityscapes: + +```python +class CityscapesDataset(BaseSegDataset): + """Cityscapes dataset. + + The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is + fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset. + """ + METAINFO = dict( + classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', + 'traffic light', 'traffic sign', 'vegetation', 'terrain', + 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', + 'motorcycle', 'bicycle'), + palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], + [190, 153, 153], [153, 153, 153], [250, 170, + 30], [220, 220, 0], + [107, 142, 35], [152, 251, 152], [70, 130, 180], + [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70], + [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]]) + +``` + +Here `'classes'` defines the class names of the Cityscapes dataset annotations. If users only care about some vehicle classes and want to **ignore the other classes**, +the meta information of the dataset can be modified through the input argument `metainfo` when instantiating the Cityscapes dataset: + +```python +from mmseg.datasets import CityscapesDataset + +data_root = 'data/cityscapes/' +data_prefix=dict(img_path='leftImg8bit/train', seg_map_path='gtFine/train') +# metainfo only keeps the classes below: +metainfo=dict(classes=( 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle')) +dataset = CityscapesDataset(data_root=data_root, data_prefix=data_prefix, metainfo=metainfo) + +print(dataset.metainfo) + +{'classes': ('car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle'), + 'palette': [[0, 0, 142], + [0, 0, 70], + [0, 60, 100], + [0, 80, 100], + [0, 0, 230], + [119, 11, 32], + [128, 64, 128], + [244, 35, 232], + [70, 70, 70], + [102, 102, 156], + [190, 153, 153], + [153, 153, 153], + [250, 170, 30], + [220, 220, 0], + [107, 142, 35], + [152, 251, 152], + [70, 130, 180], + [220, 20, 60], + [255, 0, 0]], + # pixels whose label index is 255 will be ignored when calculating the loss + 'label_map': {0: 255, + 1: 255, + 2: 255, + 3: 255, + 4: 255, + 5: 255, + 6: 255, + 7: 255, + 8: 255, + 9: 255, + 10: 255, + 11: 255, + 12: 255, + 13: 0, + 14: 1, + 15: 2, + 16: 3, + 17: 4, + 18: 5}, + 'reduce_zero_label': False} +``` + +The meta information now differs from the default Cityscapes setting. Moreover, the `label_map` field is defined, which is used to modify the label index of each pixel in the segmentation mask.
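To make the effect of `label_map` concrete before looking at the internal implementation, here is a minimal sketch (plain NumPy on a toy 2×2 array instead of a real Cityscapes mask; the mapping values are a subset of the `label_map` printed above):

```python
import numpy as np

# Toy ground-truth mask holding the default Cityscapes train ids:
# 0 = road, 7 = traffic sign, 13 = car, 15 = bus.
gt = np.array([[13, 15], [0, 7]], dtype=np.uint8)

# Subset of the label_map printed above, truncated for brevity.
label_map = {0: 255, 7: 255, 13: 0, 15: 2}

remapped = gt.copy()
for old_id, new_id in label_map.items():
    remapped[gt == old_id] = new_id

print(remapped)
# [[  0   2]
#  [255 255]]  -> kept vehicle classes get new ids, ignored classes become 255
```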
+Inside the dataset, the segmentation label is re-mapped according to `label_map`; [here](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/datasets/basesegdataset.py#L151) is the detailed implementation: + +```python +gt_semantic_seg_copy = gt_semantic_seg.copy() +for old_id, new_id in results['label_map'].items(): + gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id +``` + +- Use `reduce_zero_label` input argument + +To ignore label 0 (as in the ADE20K dataset), we can use the `reduce_zero_label` argument (default: `False`) of BaseSegDataset and its subclasses. +When `reduce_zero_label` is `True`, label 0 in the segmentation annotations is set to 255 (MMSegmentation models ignore label 255 when calculating the loss) and the indices of the other labels are decreased by 1: + +```python +gt_semantic_seg[gt_semantic_seg == 0] = 255 +gt_semantic_seg = gt_semantic_seg - 1 +gt_semantic_seg[gt_semantic_seg == 254] = 255 +``` + +## Dataset and Data Transform Pipeline + +If the argument `pipeline` is defined, the return value of the `__getitem__` method is the data sample after data augmentation. +If the dataset is instantiated without a `pipeline` argument, `__getitem__` returns the same value as the `get_data_info` method. + +```python +from mmseg.datasets import CityscapesDataset + +data_root = 'data/cityscapes/' +data_prefix=dict(img_path='leftImg8bit/train', seg_map_path='gtFine/train') +dataset = CityscapesDataset(data_root=data_root, data_prefix=data_prefix, test_mode=False) + +print(dataset[0]) + +{'img_path': 'data/cityscapes/leftImg8bit/train/aachen/aachen_000000_000019_leftImg8bit.png', + 'seg_map_path': 'data/cityscapes/gtFine/train/aachen/aachen_000000_000019_gtFine_labelTrainIds.png', + 'label_map': None, + 'reduce_zero_label': False, + 'seg_fields': [], + 'sample_idx': 0} +```
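Before moving on, it may help to see numerically what the `reduce_zero_label` conversion from the previous section does; the following self-contained sketch simply replays those three lines on a toy 1-D NumPy annotation (illustrative only, not MMSegmentation code):

```python
import numpy as np

# Toy annotation: 0 is the to-be-ignored label, 1..3 are real classes.
gt_semantic_seg = np.array([0, 1, 2, 3], dtype=np.uint8)

gt_semantic_seg[gt_semantic_seg == 0] = 255
gt_semantic_seg = gt_semantic_seg - 1
gt_semantic_seg[gt_semantic_seg == 254] = 255

print(gt_semantic_seg)  # [255   0   1   2]: label 0 ignored, others shifted down by 1
```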
diff --git a/docs/en/api.rst b/docs/en/api.rst index 12ec13b2bd..94f64313d0 100644 --- a/docs/en/api.rst +++ b/docs/en/api.rst @@ -11,8 +11,13 @@ datasets .. automodule:: mmseg.datasets :members: -transforms +samplers ^^^^^^^^^^ +.. automodule:: mmseg.datasets.samplers + :members: + +transforms +^^^^^^^^^^^^ .. automodule:: mmseg.datasets.transforms :members: @@ -25,12 +30,12 @@ hooks :members: optimizers -^^^^^^^^^^ +^^^^^^^^^^^^^^^ .. automodule:: mmseg.engine.optimizers :members: mmseg.evaluation ------------------ +-------------- metrics ^^^^^^^^^^ @@ -40,51 +45,42 @@ metrics mmseg.models -------------- -models -^^^^^^^^^^ -.. automodule:: mmseg.models - :members: - -segmentors -^^^^^^^^^^ -.. automodule:: mmseg.models.segmentors - :members: - backbones -^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^ .. automodule:: mmseg.models.backbones :members: decode_heads -^^^^^^^^^^^^ +^^^^^^^^^^^^^^^ .. automodule:: mmseg.models.decode_heads :members: -losses +segmentors ^^^^^^^^^^ -.. automodule:: mmseg.models.losses +.. automodule:: mmseg.models.segmentors :members: -utils +losses ^^^^^^^^^^ -.. automodule:: mmseg.models.utils +.. automodule:: mmseg.models.losses :members: necks -^^^^^^^^^^ +^^^^^^^^^^^^ .. automodule:: mmseg.models.necks :members: -mmseg.registry --------------- -.. automodule:: mmseg.registry +utils +^^^^^^^^^^ +.. automodule:: mmseg.models.utils :members: + mmseg.structures ------------------ +-------------------- structures -^^^^^^^^^^ +^^^^^^^^^^^^^^^^^ .. automodule:: mmseg.structures :members: @@ -93,12 +89,12 @@ sampler .. automodule:: mmseg.structures.sampler :members: +mmseg.visualization +-------------------- +.. automodule:: mmseg.visualization + :members: + mmseg.utils -------------- .. automodule:: mmseg.utils :members: - -mmseg.visualization ----------------------- -.. automodule:: mmseg.visualization - :members:
diff --git a/docs/en/migration/package.md b/docs/en/migration/package.md index ca24df5887..95fefe1310 100644 --- a/docs/en/migration/package.md +++ b/docs/en/migration/package.md @@ -96,13 +96,13 @@ OpenMMLab 2.0 defines the `BaseDataset` to function and interface of dataset, an | Packages/Modules | Changes | | :-------------------: | :------------------------------------------------------------------------------------------ | -| `mmseg.pipelines` | Renamed to `mmseg.transforms` | -| `mmseg.sampler` | Move in `mmengine.dataset.sampler` | -| `CustomDataset` | Renamed to `BaseDataset` and inherited from `BaseDataset` in MMEngine | +| `mmseg.pipelines` | Moved to `mmcv.transforms` | +| `mmseg.sampler` | Moved to `mmengine.dataset.sampler` | +| `CustomDataset` | Renamed to `BaseSegDataset` and inherited from `BaseDataset` in MMEngine | | `DefaultFormatBundle` | Replaced with `PackSegInputs` | -| `LoadImageFromFile` | Move in `mmcv.transforms.LoadImageFromFile` | +| `LoadImageFromFile` | Moved to `mmcv.transforms.LoadImageFromFile` | | `LoadAnnotations` | Moved to `mmcv.transforms.LoadAnnotations` | -| `Resize` | Moved in `mmcv.transforms` and split into `Resize`, `RandomResize` and `RandomChoiseResize` | +| `Resize` | Moved to `mmcv.transforms` and split into `Resize`, `RandomResize` and `RandomChoiceResize` | | `RandomFlip` | Moved to `mmcv.transforms.RandomFlip` | | `Pad` | Moved to `mmcv.transforms.Pad` | | `Normalize` | Moved to `mmcv.transforms.Normalize` |
diff --git a/docs/en/notes/changelog.md b/docs/en/notes/changelog.md index afed3dd084..ae9e565333 100644 --- a/docs/en/notes/changelog.md +++ b/docs/en/notes/changelog.md @@ -1,5 +1,47 @@ # Changelog of v1.x +## v1.0.0rc4(01/30/2023) + +### Highlights + +- Support ISNet (ICCV'2021) in projects ([#2400](https://github.com/open-mmlab/mmsegmentation/pull/2400)) +- Support HSSN (CVPR'2022) in projects ([#2444](https://github.com/open-mmlab/mmsegmentation/pull/2444)) + +### Features + +- Add Gaussian Noise and Blur for biomedical data ([#2373](https://github.com/open-mmlab/mmsegmentation/pull/2373)) +- Add BioMedicalRandomGamma ([#2406](https://github.com/open-mmlab/mmsegmentation/pull/2406)) +- Add BioMedical3DPad ([#2383](https://github.com/open-mmlab/mmsegmentation/pull/2383)) +- Add BioMedical3DRandomFlip ([#2404](https://github.com/open-mmlab/mmsegmentation/pull/2404)) +- Add `gt_edge_map` field to SegDataSample ([#2466](https://github.com/open-mmlab/mmsegmentation/pull/2466)) +- Support synapse dataset ([#2432](https://github.com/open-mmlab/mmsegmentation/pull/2432), [#2465](https://github.com/open-mmlab/mmsegmentation/pull/2465)) +- Support Mapillary Vistas Dataset in projects ([#2484](https://github.com/open-mmlab/mmsegmentation/pull/2484)) +- Switch order of `reduce_zero_label` and applying `label_map` ([#2517](https://github.com/open-mmlab/mmsegmentation/pull/2517)) + +### Documentation + +- Add ZN Customized_runtime Doc ([#2502](https://github.com/open-mmlab/mmsegmentation/pull/2502)) +- Add EN datasets.md ([#2464](https://github.com/open-mmlab/mmsegmentation/pull/2464)) +- Fix minor typo in migration `package.md` ([#2518](https://github.com/open-mmlab/mmsegmentation/pull/2518)) + +### Bug fix + +- Fix incorrect `img_shape` value assignment in RandomCrop ([#2469](https://github.com/open-mmlab/mmsegmentation/pull/2469)) +- Fix inference api and support setting palette to SegLocalVisualizer
([#2475](https://github.com/open-mmlab/mmsegmentation/pull/2475)) +- Fix unfinished label conversion from `-1` to `255` ([#2516](https://github.com/open-mmlab/mmsegmentation/pull/2516)) + +### New Contributors + +- @blueyo0 made their first contribution in https://github.com/open-mmlab/mmsegmentation/pull/2373 +- @Fivethousand5k made their first contribution in https://github.com/open-mmlab/mmsegmentation/pull/2406 +- @suyanzhou626 made their first contribution in https://github.com/open-mmlab/mmsegmentation/pull/2383 +- @unrealMJ made their first contribution in https://github.com/open-mmlab/mmsegmentation/pull/2400 +- @Dominic23331 made their first contribution in https://github.com/open-mmlab/mmsegmentation/pull/2432 +- @AI-Tianlong made their first contribution in https://github.com/open-mmlab/mmsegmentation/pull/2444 +- @morkovka1337 made their first contribution in https://github.com/open-mmlab/mmsegmentation/pull/2492 +- @Leeinsn made their first contribution in https://github.com/open-mmlab/mmsegmentation/pull/2404 +- @siddancha made their first contribution in https://github.com/open-mmlab/mmsegmentation/pull/2516 + ## v1.0.0rc3(31/12/2022) ### Highlights
diff --git a/docs/en/notes/faq.md b/docs/en/notes/faq.md index 4903747ed2..48e97429c1 100644 --- a/docs/en/notes/faq.md +++ b/docs/en/notes/faq.md @@ -6,33 +6,35 @@ We list some common troubles faced by many users and their corresponding solutions The compatible MMSegmentation and MMCV versions are as below. Please install the correct version of MMCV to avoid installation issues. -| MMSegmentation version | MMCV version | MMClassification (optional) version | MMDetection (optional) version | -| :--------------------: | :-------------------------: | :---------------------------------: | :----------------------------: | -| 1.0.0rc3 | mmcv >= 2.0.0rc3 | mmcls>=1.0.0rc0 | mmdet>=3.0.0rc4 | -| 1.0.0rc2 | mmcv >= 2.0.0rc3 | mmcls>=1.0.0rc0 | mmdet>=3.0.0rc4 | -| 1.0.0rc1 | mmcv >= 2.0.0rc1 | mmcls>=1.0.0rc0 | Not required | -| 1.0.0rc0 | mmcv >= 2.0.0rc1 | mmcls>=1.0.0rc0 | Not required | -| master | mmcv-full>=1.4.4, \<=1.6.0 | mmcls>=0.20.1, \<=1.0.0 | Not required | -| 0.24.1 | mmcv-full>=1.4.4, \<=1.6.0 | mmcls>=0.20.1, \<=1.0.0 | Not required | -| 0.23.0 | mmcv-full>=1.4.4, \<=1.6.0 | mmcls>=0.20.1, \<=1.0.0 | Not required | -| 0.22.0 | mmcv-full>=1.4.4, \<=1.6.0 | mmcls>=0.20.1, \<=1.0.0 | Not required | -| 0.21.1 | mmcv-full>=1.4.4, \<=1.6.0 | Not required | Not required | -| 0.20.2 | mmcv-full>=1.3.13, \<=1.6.0 | Not required | Not required | -| 0.19.0 | mmcv-full>=1.3.13, \<1.3.17 | Not required | Not required | -| 0.18.0 | mmcv-full>=1.3.13, \<1.3.17 | Not required | Not required | -| 0.17.0 | mmcv-full>=1.3.7, \<1.3.17 | Not required | Not required | -| 0.16.0 | mmcv-full>=1.3.7, \<1.3.17 | Not required | Not required | -| 0.15.0 | mmcv-full>=1.3.7, \<1.3.17 | Not required | Not required | -| 0.14.1 | mmcv-full>=1.3.7, \<1.3.17 | Not required | Not required | -| 0.14.0 | mmcv-full>=1.3.1, \<1.3.2 | Not required | Not required | -| 0.13.0 | mmcv-full>=1.3.1, \<1.3.2 | Not required | Not required | -| 0.12.0 | mmcv-full>=1.1.4, \<1.3.2 | Not required | Not required | -| 0.11.0 | mmcv-full>=1.1.4, \<1.3.0 | Not required | Not required | -| 0.10.0 | mmcv-full>=1.1.4, \<1.3.0 | Not required | Not required | -| 0.9.0 | mmcv-full>=1.1.4, \<1.3.0 | Not required | Not required | -| 0.8.0 | mmcv-full>=1.1.4, \<1.2.0 | Not required | Not required | -| 0.7.0 | mmcv-full>=1.1.2, \<1.2.0 | Not required | Not required | -| 0.6.0 |
mmcv-full>=1.1.2, \<1.2.0 | Not required | Not required | +| MMSegmentation version | MMCV version | MMClassification (optional) version | MMDetection (optional) version | +| :--------------------: | :----------------------------: | :---------------------------------: | :----------------------------: | +| 1.x/dev-1.x branch | mmcv == 2.0.0rc3 | mmcls>=1.0.0rc0 | mmdet>=3.0.0rc4, \<=3.0.0rc5 | +| 1.0.0rc4 | mmcv == 2.0.0rc3 | mmcls>=1.0.0rc0 | mmdet>=3.0.0rc4, \<=3.0.0rc5 | +| 1.0.0rc3 | mmcv == 2.0.0rc3 | mmcls>=1.0.0rc0 | mmdet>=3.0.0rc4, \<=3.0.0rc5 | +| 1.0.0rc2 | mmcv == 2.0.0rc3 | mmcls>=1.0.0rc0 | mmdet>=3.0.0rc4, \<=3.0.0rc5 | +| 1.0.0rc1 | mmcv >= 2.0.0rc1, \<=2.0.0rc3 | mmcls>=1.0.0rc0 | Not required | +| 1.0.0rc0 | mmcv >= 2.0.0rc1, \<=2.0.0rc3 | mmcls>=1.0.0rc0 | Not required | +| master | mmcv-full>=1.4.4, \<=1.6.0 | mmcls>=0.20.1, \<=1.0.0 | Not required | +| 0.24.1 | mmcv-full>=1.4.4, \<=1.6.0 | mmcls>=0.20.1, \<=1.0.0 | Not required | +| 0.23.0 | mmcv-full>=1.4.4, \<=1.6.0 | mmcls>=0.20.1, \<=1.0.0 | Not required | +| 0.22.0 | mmcv-full>=1.4.4, \<=1.6.0 | mmcls>=0.20.1, \<=1.0.0 | Not required | +| 0.21.1 | mmcv-full>=1.4.4, \<=1.6.0 | Not required | Not required | +| 0.20.2 | mmcv-full>=1.3.13, \<=1.6.0 | Not required | Not required | +| 0.19.0 | mmcv-full>=1.3.13, \<1.3.17 | Not required | Not required | +| 0.18.0 | mmcv-full>=1.3.13, \<1.3.17 | Not required | Not required | +| 0.17.0 | mmcv-full>=1.3.7, \<1.3.17 | Not required | Not required | +| 0.16.0 | mmcv-full>=1.3.7, \<1.3.17 | Not required | Not required | +| 0.15.0 | mmcv-full>=1.3.7, \<1.3.17 | Not required | Not required | +| 0.14.1 | mmcv-full>=1.3.7, \<1.3.17 | Not required | Not required | +| 0.14.0 | mmcv-full>=1.3.1, \<1.3.2 | Not required | Not required | +| 0.13.0 | mmcv-full>=1.3.1, \<1.3.2 | Not required | Not required | +| 0.12.0 | mmcv-full>=1.1.4, \<1.3.2 | Not required | Not required | +| 0.11.0 | mmcv-full>=1.1.4, \<1.3.0 | Not required | Not required | +| 0.10.0 | mmcv-full>=1.1.4, \<1.3.0 | Not required | Not required | +| 0.9.0 | mmcv-full>=1.1.4, \<1.3.0 | Not required | Not required | +| 0.8.0 | mmcv-full>=1.1.4, \<1.2.0 | Not required | Not required | +| 0.7.0 | mmcv-full>=1.1.2, \<1.2.0 | Not required | Not required | +| 0.6.0 | mmcv-full>=1.1.2, \<1.2.0 | Not required | Not required | ## How to know the number of GPUs needed to train the model
diff --git a/docs/en/user_guides/2_dataset_prepare.md b/docs/en/user_guides/2_dataset_prepare.md index a795e3bfcc..e9c7683dc0 100644 --- a/docs/en/user_guides/2_dataset_prepare.md +++ b/docs/en/user_guides/2_dataset_prepare.md @@ -138,6 +138,13 @@ mmsegmentation │ │ ├── ann_dir │ │ │ ├── train │ │ │ ├── val +│ ├── synapse +│ │ ├── img_dir +│ │ │ ├── train +│ │ │ ├── val +│ │ ├── ann_dir +│ │ │ ├── train +│ │ │ ├── val ``` ### Cityscapes @@ -323,7 +330,7 @@ For Potsdam dataset, please run the following command to download and re-organize python tools/dataset_converters/potsdam.py /path/to/potsdam ``` -In our default setting, it will generate 3456 images for training and 2016 images for validation. +In our default setting, it will generate 3,456 images for training and 2,016 images for validation. ### ISPRS Vaihingen @@ -376,7 +383,7 @@ You may need to follow the following structure for dataset preparation after downloading python tools/dataset_converters/isaid.py /path/to/iSAID ``` -In our default setting (`patch_width`=896, `patch_height`=896, `overlap_area`=384), it will generate 33978 images for training and 11644 images for validation.
+In our default setting (`patch_width`=896, `patch_height`=896, `overlap_area`=384), it will generate 33,978 images for training and 11,644 images for validation. ## LIP(Look Into Person) dataset @@ -414,3 +421,86 @@ The contents of LIP datasets include: │   │ │ ├── 100034_483681.png │   │ │ ├── ... ``` + +## Synapse dataset + +This dataset can be downloaded from [this page](https://www.synapse.org/#!Synapse:syn3193805/wiki/). + +We follow the data preparation setting of [TransUNet](https://arxiv.org/abs/2102.04306), which splits the original training set (30 scans) +into a new training set (18 scans) and a validation set (12 scans). Please run the following commands to prepare the dataset. + +```shell +unzip RawData.zip +cd ./RawData/Training +``` + +Then create `train.txt` and `val.txt` to split the dataset. + +According to TransUNet, the dataset is divided as follows: + +train.txt + +```none +img0005.nii.gz +img0006.nii.gz +img0007.nii.gz +img0009.nii.gz +img0010.nii.gz +img0021.nii.gz +img0023.nii.gz +img0024.nii.gz +img0026.nii.gz +img0027.nii.gz +img0028.nii.gz +img0030.nii.gz +img0031.nii.gz +img0033.nii.gz +img0034.nii.gz +img0037.nii.gz +img0039.nii.gz +img0040.nii.gz +``` + +val.txt + +```none +img0008.nii.gz +img0022.nii.gz +img0038.nii.gz +img0036.nii.gz +img0032.nii.gz +img0002.nii.gz +img0029.nii.gz +img0003.nii.gz +img0001.nii.gz +img0004.nii.gz +img0025.nii.gz +img0035.nii.gz +``` + +The contents of the Synapse dataset include: + +```none +├── Training +│ ├── img +│ │ ├── img0001.nii.gz +│ │ ├── img0002.nii.gz +│ │ ├── ... +│ ├── label +│ │ ├── label0001.nii.gz +│ │ ├── label0002.nii.gz +│ │ ├── ... +│ ├── train.txt +│ ├── val.txt +``` + +Then, use this command to convert the Synapse dataset: + +```shell +python tools/dataset_converters/synapse.py --dataset-path /path/to/synapse +``` + +In our default setting, it will generate 2,211 2D images for training and 1,568 2D images for validation. + +Note that MMSegmentation's default evaluation metrics (such as the mean Dice value) are calculated on 2D slice images, +which is not comparable to the results on 3D scans reported in some papers such as [TransUNet](https://arxiv.org/abs/2102.04306).
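To make this caveat concrete, the sketch below (plain NumPy on a fabricated two-slice binary "volume"; illustrative only, this is not how MMSegmentation computes `mDice`) shows that averaging Dice over 2D slices generally differs from Dice over the whole 3D volume:

```python
import numpy as np

def dice(pred, gt):
    """Binary Dice coefficient."""
    inter = np.logical_and(pred, gt).sum()
    denom = pred.sum() + gt.sum()
    return 2 * inter / denom if denom else 1.0

# Fabricated prediction / ground truth: two 2x2 slices.
pred = np.array([[[1, 1], [0, 0]],
                 [[0, 0], [0, 0]]])
gt = np.array([[[1, 0], [0, 0]],
               [[0, 1], [0, 0]]])

print(np.mean([dice(pred[i], gt[i]) for i in range(2)]))  # mean 2D-slice Dice: 0.333...
print(dice(pred, gt))                                     # whole-volume Dice: 0.5
```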
diff --git a/docs/zh_cn/advanced_guides/customize_runtime.md b/docs/zh_cn/advanced_guides/customize_runtime.md index 1afd95a9a6..a80aca6345 100644 --- a/docs/zh_cn/advanced_guides/customize_runtime.md +++ b/docs/zh_cn/advanced_guides/customize_runtime.md @@ -1,248 +1,162 @@ -# 自定义运行设定(待更新) +# 自定义运行设定 -## 自定义优化设定 +## 实现自定义钩子 -### 自定义 PyTorch 支持的优化器 +### Step 1: 创建一个新的钩子 -我们已经支持 PyTorch 自带的所有优化器,唯一需要修改的地方是在配置文件里的 `optimizer` 域里面。 -例如,如果您想使用 `ADAM` (注意如下操作可能会让模型表现下降),可以使用如下修改: +MMEngine 已实现了训练和测试常用的[钩子](https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/hook.md), +当有定制化需求时, 可以按照如下示例实现适用于自身训练需求的钩子, 例如想修改一个超参数 `model.hyper_parameter` 的值, 让它随着训练迭代次数而变化: ```python -optimizer = dict(type='Adam', lr=0.0003, weight_decay=0.0001) -``` - -为了修改模型的学习率,使用者仅需要修改配置文件里 optimizer 的 `lr` 即可。 -使用者可以参照 PyTorch 的 [API 文档](https://pytorch.org/docs/stable/optim.html?highlight=optim#module-torch.optim) -直接设置参数。 - -### 自定义自己实现的优化器 - -#### 1. 定义一个新的优化器 +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence -一个自定义的优化器可以按照如下去定义: +from mmengine.hooks import Hook +from mmengine.model import is_model_wrapper -假如您想增加一个叫做 `MyOptimizer` 的优化器,它的参数分别有 `a`, `b`, 和 `c`。 -您需要创建一个叫 `mmseg/core/optimizer` 的新文件夹。 -然后再在文件,即 `mmseg/core/optimizer/my_optimizer.py` 里面去实现这个新优化器: +from mmseg.registry import HOOKS -```python -from .registry import OPTIMIZERS -from torch.optim import Optimizer +@HOOKS.register_module() +class NewHook(Hook): + """Docstring for NewHook. + """ -@OPTIMIZERS.register_module() -class MyOptimizer(Optimizer): - - def __init__(self, a, b, c) + def __init__(self, a: int, b: int) -> None: + self.a = a + self.b = b + def before_train_iter(self, + runner, + batch_idx: int, + data_batch: Optional[Sequence[dict]] = None) -> None: + cur_iter = runner.iter + # 当模型被包在 wrapper 里时获取这个模型 + if is_model_wrapper(runner.model): + model = runner.model.module + else: + model = runner.model + model.hyper_parameter = self.a * cur_iter + self.b ``` -#### 2. 增加优化器到注册表 (registry) +### Step 2: 导入一个新的钩子 -为了让上述定义的模块被框架发现,首先这个模块应该被导入到主命名空间 (main namespace) 里。 -有两种方式可以实现它。 +为了让上面定义的模块可以被执行的程序发现, 这个模块需要先被导入主命名空间 (main namespace) 里面, +假设 NewHook 在 `mmseg/engine/hooks/new_hook.py` 里面, 有两种方式去实现它: -- 修改 `mmseg/core/optimizer/__init__.py` 来导入它 - - 新的被定义的模块应该被导入到 `mmseg/core/optimizer/__init__.py` 这样注册表将会发现新的模块并添加它 +- 修改 `mmseg/engine/hooks/__init__.py` 来导入它. + 新定义的模块应该在 `mmseg/engine/hooks/__init__.py` 里面导入, 这样注册器可以发现并添加这个新的模块: ```python -from .my_optimizer import MyOptimizer +from .new_hook import NewHook + +__all__ = [..., 'NewHook'] ``` -- 在配置文件里使用 `custom_imports` 去手动导入它 +- 在配置文件里使用 custom_imports 来手动导入它. ```python -custom_imports = dict(imports=['mmseg.core.optimizer.my_optimizer'], allow_failed_imports=False) +custom_imports = dict(imports=['mmseg.engine.hooks.new_hook'], allow_failed_imports=False) ``` -`mmseg.core.optimizer.my_optimizer` 模块将会在程序运行的开始被导入,并且 `MyOptimizer` 类将会自动注册。 -需要注意只有包含 `MyOptimizer` 类的包 (package) 应当被导入。 -而 `mmseg.core.optimizer.my_optimizer.MyOptimizer` **不能** 被直接导入。 - -事实上,使用者完全可以用另一个按这样导入方法的文件夹结构,只要模块的根路径已经被添加到 `PYTHONPATH` 里面。 - -#### 3. 在配置文件里定义优化器 +### Step 3: 修改配置文件 -之后您可以在配置文件的 `optimizer` 域里面使用 `MyOptimizer` -在配置文件里,优化器被定义在 `optimizer` 域里,如下所示: +可以按照如下方式, 在训练或测试中配置并使用自定义的钩子. 不同钩子在同一位点的优先级可以参考[这里](https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/hook.md#%E5%86%85%E7%BD%AE%E9%92%A9%E5%AD%90), 自定义钩子如果没有指定优先级, 默认是 `NORMAL`. ```python -optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) +custom_hooks = [ + dict(type='NewHook', a=a_value, b=b_value, priority='ABOVE_NORMAL') +] ``` -为了使用您自己的优化器,这个域可以被改成: - -```python -optimizer = dict(type='MyOptimizer', a=a_value, b=b_value, c=c_value) -``` +## 实现自定义优化器 -### 自定义优化器的构造器 (constructor) +### Step 1: 创建一个新的优化器 -有些模型可能需要在优化器里有一些特别参数的设置,例如 批归一化层 (BatchNorm layers) 的 权重衰减 (weight decay)。 -使用者可以通过自定义优化器的构造器去微调这些细粒度参数。 +如果增加一个叫作 `MyOptimizer` 的优化器, 它有参数 `a`, `b` 和 `c`.
推荐在 `mmseg/engine/optimizers/my_optimizer.py` 文件中实现 ```python -from mmcv.utils import build_from_cfg - -from mmcv.runner.optimizer import OPTIMIZER_BUILDERS, OPTIMIZERS -from mmseg.utils import get_root_logger -from .my_optimizer import MyOptimizer - - -@OPTIMIZER_BUILDERS.register_module() -class MyOptimizerConstructor(object): +from mmseg.registry import OPTIMIZERS +from torch.optim import Optimizer - def __init__(self, optim_wrapper_cfg, paramwise_cfg=None): - def __call__(self, model): - - return my_optimizer +@OPTIMIZERS.register_module() +class MyOptimizer(Optimizer): + def __init__(self, a, b, c): + ... ``` -默认的优化器构造器的实现可以参照 [这里](https://github.com/open-mmlab/mmcv/blob/9ecd6b0d5ff9d2172c49a182eaa669e9f27bb8e7/mmcv/runner/optimizer/default_constructor.py#L11) ,它也可以被用作新的优化器构造器的模板。 - -### 额外的设置 - -优化器没有实现的一些技巧应该通过优化器构造器 (optimizer constructor) 或者钩子 (hook) 去实现,如设置基于参数的学习率 (parameter-wise learning rates)。我们列出一些常见的设置,它们可以稳定或加速模型的训练。 -如果您有更多的设置,欢迎在 PR 和 issue 里面提交。 - -- __使用梯度截断 (gradient clip) 去稳定训练__: - - 一些模型需要梯度截断去稳定训练过程,如下所示 - - ```python - optimizer_config = dict( - _delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) - ``` - - 如果您的配置继承自已经设置了 `optimizer_config` 的基础配置 (base config),您可能需要 `_delete_=True` 来重写那些不需要的设置。更多细节请参照 [配置文件文档](https://mmsegmentation.readthedocs.io/en/latest/config.html) 。 - -- __使用动量计划表 (momentum schedule) 去加速模型收敛__: - - 我们支持动量计划表去让模型基于学习率修改动量,这样可能让模型收敛地更快。 - 动量计划表经常和学习率计划表 (LR scheduler) 一起使用,例如如下配置文件就在 3D 检测里经常使用以加速收敛。 - 更多细节请参考 [CyclicLrUpdater](https://github.com/open-mmlab/mmcv/blob/f48241a65aebfe07db122e9db320c31b685dc674/mmcv/runner/hooks/lr_updater.py#L327) 和 [CyclicMomentumUpdater](https://github.com/open-mmlab/mmcv/blob/f48241a65aebfe07db122e9db320c31b685dc674/mmcv/runner/hooks/momentum_updater.py#L130) 的实现。 - - ```python - lr_config = dict( - policy='cyclic', - target_ratio=(10, 1e-4), - cyclic_times=1, - step_ratio_up=0.4, - ) - momentum_config = dict( - policy='cyclic', - target_ratio=(0.85 / 0.95, 1), - cyclic_times=1, - step_ratio_up=0.4, - ) - ``` - -## 自定义训练计划表 - -我们根据默认的训练迭代步数 40k/80k 来设置学习率,这在 MMCV 里叫做 [`PolyLrUpdaterHook`](https://github.com/open-mmlab/mmcv/blob/826d3a7b68596c824fa1e2cb89b6ac274f52179c/mmcv/runner/hooks/lr_updater.py#L196) 。 -我们也支持许多其他的学习率计划表:[这里](https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py) ,例如 `CosineAnnealing` 和 `Poly` 计划表。下面是一些例子: - -- 步计划表 Step schedule: - - ```python - lr_config = dict(policy='step', step=[9, 10]) - ``` ### Step 2: 导入一个新的优化器 -余弦退火计划表 ConsineAnnealing schedule: +为了让上面定义的模块可以被执行的程序发现, 这个模块需要先被导入主命名空间 (main namespace) 里面, +假设 `MyOptimizer` 在 `mmseg/engine/optimizers/my_optimizer.py` 里面, 有两种方式去实现它: - ```python - lr_config = dict( - policy='CosineAnnealing', - warmup='linear', - warmup_iters=1000, - warmup_ratio=1.0 / 10, - min_lr_ratio=1e-5) - ``` -## 自定义工作流 (workflow) -工作流是一个专门定义运行顺序和轮数 (running order and epochs) 的列表 (phase, epochs)。 -默认情况下它设置成: +- 修改 `mmseg/engine/optimizers/__init__.py` 来导入它. + 新定义的模块应该在 `mmseg/engine/optimizers/__init__.py` 里面导入, 这样注册器可以发现并添加这个新的模块: ```python -workflow = [('train', 1)] +from .my_optimizer import MyOptimizer ``` -意思是训练是跑 1 个 epoch。有时候使用者可能想检查模型在验证集上的一些指标(如 损失 loss,精确性 accuracy),我们可以这样设置工作流: +- 在配置文件里使用 `custom_imports` 来手动导入它. ```python -[('train', 1), ('val', 1)] +custom_imports = dict(imports=['mmseg.engine.optimizers.my_optimizer'], allow_failed_imports=False) ``` -于是 1 个 epoch 训练,1 个 epoch 验证将交替运行。 +### Step 3: 修改配置文件 -**注意**: +随后需要修改配置文件 `optim_wrapper` 里的 `optimizer` 参数, 如果要使用你自己的优化器 `MyOptimizer`, 字段可以被修改成: -1.
模型的参数在验证的阶段不会被自动更新 -2. 配置文件里的关键词 `total_epochs` 仅控制训练的 epochs 数目,而不会影响验证时的工作流 -3. 工作流 `[('train', 1), ('val', 1)]` 和 `[('train', 1)]` 将不会改变 `EvalHook` 的行为,因为 `EvalHook` 被 `after_train_epoch` - 调用而且验证的工作流仅仅影响通过调用 `after_val_epoch` 的钩子 (hooks)。因此, `[('train', 1), ('val', 1)]` 和 `[('train', 1)]` - 的区别仅在于 runner 将在每次训练 epoch 结束后计算在验证集上的损失 +```python +optim_wrapper = dict(type='OptimWrapper', + optimizer=dict(type='MyOptimizer', + a=a_value, b=b_value, c=c_value), + clip_grad=None) +``` -## 自定义钩 (hooks) +## 实现自定义优化器封装构造器 -### 使用 MMCV 实现的钩子 (hooks) +### Step 1: 创建一个新的优化器封装构造器 -如果钩子已经在 MMCV 里被实现,如下所示,您可以直接修改配置文件来使用钩子: +构造器可以用来创建优化器, 优化器封装, 以及自定义模型网络不同层的超参数. 一些模型的优化器可能会根据特定的参数而调整, 例如 BatchNorm 层的 weight decay. 使用者可以通过自定义优化器构造器来精细化设定不同参数的优化策略. ```python -custom_hooks = [ - dict(type='MyHook', a=a_value, b=b_value, priority='NORMAL') -] -``` +from mmengine.optim import DefaultOptimWrapperConstructor +from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS -### 修改默认的运行时间钩子 (runtime hooks) +@OPTIM_WRAPPER_CONSTRUCTORS.register_module() +class MyOptimizerConstructor(DefaultOptimWrapperConstructor): + def __init__(self, optim_wrapper_cfg, paramwise_cfg=None): + super().__init__(optim_wrapper_cfg, paramwise_cfg) -以下的常用的钩子没有被 `custom_hooks` 注册: + def __call__(self, model): + ... -- log_config -- checkpoint_config -- evaluation -- lr_config -- optimizer_config -- momentum_config + return my_optimizer +``` -在这些钩子里,只有 logger hook 有 `VERY_LOW` 优先级,其他的优先级都是 `NORMAL`。 -上述提及的教程已经包括了如何修改 `optimizer_config`,`momentum_config` 和 `lr_config`。 -这里我们展示我们如何处理 `log_config`, `checkpoint_config` 和 `evaluation`。 +默认的优化器构造器在[这里](https://github.com/open-mmlab/mmengine/blob/main/mmengine/optim/optimizer/default_constructor.py#L19) 被实现, 它也可以用来作为新的优化器构造器的模板. -#### 检查点配置文件 (Checkpoint config) ### Step 2: 导入一个新的优化器封装构造器 -MMCV runner 将使用 `checkpoint_config` 去初始化 [`CheckpointHook`](https://github.com/open-mmlab/mmcv/blob/9ecd6b0d5ff9d2172c49a182eaa669e9f27bb8e7/mmcv/runner/hooks/checkpoint.py#L9). +为了让上面定义的模块可以被执行的程序发现, 这个模块需要先被导入主命名空间 (main namespace) 里面, 假设 `MyOptimizerConstructor` 在 `mmseg/engine/optimizers/my_optimizer_constructor.py` 里面, 有两种方式去实现它: + +- 修改 `mmseg/engine/optimizers/__init__.py` 来导入它. + 新定义的模块应该在 `mmseg/engine/optimizers/__init__.py` 里面导入, 这样注册器可以发现并添加这个新的模块: ```python -checkpoint_config = dict(interval=1) +from .my_optimizer_constructor import MyOptimizerConstructor ``` -使用者可以设置 `max_keep_ckpts` 来仅保存一小部分检查点或者通过 `save_optimizer` 来决定是否保存优化器的状态字典 (state dict of optimizer)。 更多使用参数的细节请参考 [这里](https://mmcv.readthedocs.io/en/latest/api.html#mmcv.runner.CheckpointHook) 。 -#### 日志配置文件 (Log config) -`log_config` 包裹了许多日志钩 (logger hooks) 而且能去设置间隔 (intervals)。现在 MMCV 支持 `WandbLoggerHook`, `MlflowLoggerHook` 和 `TensorboardLoggerHook`。 -详细的使用请参照 [文档](https://mmcv.readthedocs.io/en/latest/api.html#mmcv.runner.LoggerHook) 。 +- 在配置文件里使用 `custom_imports` 来手动导入它.
```python -log_config = dict( - interval=50, - hooks=[ - dict(type='TextLoggerHook'), - dict(type='TensorboardLoggerHook') - ]) +custom_imports = dict(imports=['mmseg.engine.optimizers.my_optimizer_constructor'], allow_failed_imports=False) ``` -#### 评估配置文件 (Evaluation config) +### Step 3: 修改配置文件 -`evaluation` 的配置文件将被用来初始化 [`EvalHook`](https://github.com/open-mmlab/mmsegmentation/blob/e3f6f655d69b777341aec2fe8829871cc0beadcb/mmseg/core/evaluation/eval_hooks.py#L7) 。 -除了 `interval` 键,其他的像 `metric` 这样的参数将被传递给 `dataset.evaluate()` 。 +随后需要修改配置文件 `optim_wrapper` 里的 `constructor` 参数, 如果要使用你自己的优化器封装构造器 `MyOptimizerConstructor`, 字段可以被修改成: ```python -evaluation = dict(interval=1, metric='mIoU') +optim_wrapper = dict(type='OptimWrapper', + constructor='MyOptimizerConstructor', + clip_grad=None) ``` diff --git a/docs/zh_cn/advanced_guides/datasets.md b/docs/zh_cn/advanced_guides/datasets.md index 06a75e54bd..0f3ad2b682 100644 --- a/docs/zh_cn/advanced_guides/datasets.md +++ b/docs/zh_cn/advanced_guides/datasets.md @@ -1,4 +1,4 @@ -# 数据集 +# 数据集 在 MMSegmentation 算法库中, 所有 Dataset 类的功能有两个: 加载[预处理](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/user_guides/2_dataset_prepare.md) 之后的数据集的信息, 和将数据送入[数据集变换流水线](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/datasets/basesegdataset.py#L141) 中, 进行[数据变换操作](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/transforms.md). 加载的数据集信息包括两类: 元信息 (meta information), 数据集本身的信息, 例如数据集总共的类别, 和它们对应调色盘信息: 数据信息 (data information) 是指每组数据中图片和对应标签的路径. 下文中介绍了 MMSegmentation 1.x 中数据集的常用接口, 和 mmseg 数据集基类中数据信息加载与修改数据集类别的逻辑, 以及数据集与数据变换流水线 (pipeline) 的关系. @@ -165,7 +165,7 @@ print(dataset[0]) ## BaseSegDataset -由于 MMSegmentation 中的所有数据集的基本功能均包括加载[预处理](https://mmsegmentation.readthedocs.io/en/dev-1.x/advanced_guides/models.html#id2) 之后的数据集的信息, 和将数据送入数据集变换流水线中, 因此在 MMSegmentation 中将其中的共同接口抽象成 [`BaseSegDataset`](https://mmsegmentation.readthedocs.io/en/dev-1.x/api.html?highlight=BaseSegDataset#mmseg.datasets.BaseSegDataset),它继承自 [MMEngine 的 `BaseDataset`](https://github.com/open-mmlab/mmengine/blob/main/docs/en/advanced_tutorials/basedataset.md), 遵循 OpenMMLab 数据集初始化统一流程, 支持高效的内部数据存储格式, 支持数据集拼接、数据集重复采样等功能. +由于 MMSegmentation 中的所有数据集的基本功能均包括(1) 加载[数据集预处理](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/user_guides/2_dataset_prepare.md) 之后的数据信息和 (2) 将数据送入数据变换流水线中进行数据变换, 因此在 MMSegmentation 中将其中的共同接口抽象成 [`BaseSegDataset`](https://mmsegmentation.readthedocs.io/en/dev-1.x/api.html?highlight=BaseSegDataset#mmseg.datasets.BaseSegDataset),它继承自 [MMEngine 的 `BaseDataset`](https://github.com/open-mmlab/mmengine/blob/main/docs/en/advanced_tutorials/basedataset.md), 遵循 OpenMMLab 数据集初始化统一流程, 支持高效的内部数据存储格式, 支持数据集拼接、数据集重复采样等功能. 在 MMSegmentation BaseSegDataset 中重新定义了**数据信息加载方法**(`load_data_list`)和并新增了 `get_label_map` 方法用来**修改数据集的类别信息**. ### 数据信息加载 diff --git a/docs/zh_cn/api.rst b/docs/zh_cn/api.rst index be68c7579d..94f64313d0 100644 --- a/docs/zh_cn/api.rst +++ b/docs/zh_cn/api.rst @@ -11,8 +11,13 @@ datasets .. automodule:: mmseg.datasets :members: -transforms +samplers ^^^^^^^^^^ +.. automodule:: mmseg.datasets.samplers + :members: + +transforms +^^^^^^^^^^^^ .. automodule:: mmseg.datasets.transforms :members: @@ -25,7 +30,7 @@ hooks :members: optimizers -^^^^^^^^^^ +^^^^^^^^^^^^^^^ .. automodule:: mmseg.engine.optimizers :members: @@ -40,51 +45,42 @@ metrics mmseg.models -------------- -models -^^^^^^^^^^ -.. automodule:: mmseg.models - :members: - -segmentors -^^^^^^^^^^ -.. 
automodule:: mmseg.models.segmentors - :members: - backbones -^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^ .. automodule:: mmseg.models.backbones :members: decode_heads -^^^^^^^^^^^^ +^^^^^^^^^^^^^^^ .. automodule:: mmseg.models.decode_heads :members: -losses +segmentors ^^^^^^^^^^ -.. automodule:: mmseg.models.losses +.. automodule:: mmseg.models.segmentors :members: -utils +losses ^^^^^^^^^^ -.. automodule:: mmseg.models.utils +.. automodule:: mmseg.models.losses :members: necks -^^^^^^^^^^ +^^^^^^^^^^^^ .. automodule:: mmseg.models.necks :members: -mmseg.registry --------------- -.. automodule:: mmseg.registry +utils +^^^^^^^^^^ +.. automodule:: mmseg.models.utils :members: + mmseg.structures --------------- +-------------------- structures -^^^^^^^^^^ +^^^^^^^^^^^^^^^^^ .. automodule:: mmseg.structures :members: @@ -93,12 +89,12 @@ sampler .. automodule:: mmseg.structures.sampler :members: -mmseg.utils --------------- -.. automodule:: mmseg.utils +mmseg.visualization +-------------------- +.. automodule:: mmseg.visualization :members: -mmseg.visualization +mmseg.utils -------------- -.. automodule:: mmseg.visualization +.. automodule:: mmseg.utils :members: diff --git a/docs/zh_cn/user_guides/2_dataset_prepare.md b/docs/zh_cn/user_guides/2_dataset_prepare.md index a546b1a3d0..a8dde9211a 100644 --- a/docs/zh_cn/user_guides/2_dataset_prepare.md +++ b/docs/zh_cn/user_guides/2_dataset_prepare.md @@ -1,6 +1,6 @@ ## 准备数据集(待更新) -推荐用软链接,将数据集根目录链接到 `$MMSEGMENTATION/data` 里。如果您的文件夹结构是不同的,您也许可以试着修改配置文件里对应的路径。 +推荐用软链接, 将数据集根目录链接到 `$MMSEGMENTATION/data` 里. 如果您的文件夹结构是不同的, 您也许可以试着修改配置文件里对应的路径. ```none mmsegmentation @@ -119,51 +119,58 @@ mmsegmentation │ │ ├── ann_dir │ │ │ ├── train │ │ │ ├── val +│ ├── synapse +│ │ ├── img_dir +│ │ │ ├── train +│ │ │ ├── val +│ │ ├── ann_dir +│ │ │ ├── train +│ │ │ ├── val ``` ### Cityscapes -注册成功后,数据集可以在 [这里](https://www.cityscapes-dataset.com/downloads/) 下载。 +注册成功后, 数据集可以在 [这里](https://www.cityscapes-dataset.com/downloads/) 下载. -通常情况下,`**labelTrainIds.png` 被用来训练 cityscapes。 +通常情况下, `**labelTrainIds.png` 被用来训练 cityscapes. 基于 [cityscapesscripts](https://github.com/mcordts/cityscapesScripts), 我们提供了一个 [脚本](https://github.com/open-mmlab/mmsegmentation/blob/master/tools/convert_datasets/cityscapes.py), -去生成 `**labelTrainIds.png`。 +去生成 `**labelTrainIds.png`. ```shell -# --nproc 8 意味着有 8 个进程用来转换,它也可以被忽略。 +# --nproc 8 意味着有 8 个进程用来转换,它也可以被忽略. python tools/convert_datasets/cityscapes.py data/cityscapes --nproc 8 ``` ### Pascal VOC -Pascal VOC 2012 可以在 [这里](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar) 下载。 -此外,许多最近在 Pascal VOC 数据集上的工作都会利用增广的数据,它们可以在 [这里](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz) 找到。 +Pascal VOC 2012 可以在 [这里](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar) 下载. +此外, 许多最近在 Pascal VOC 数据集上的工作都会利用增广的数据, 它们可以在 [这里](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz) 找到. -如果您想使用增广后的 VOC 数据集,请运行下面的命令来将数据增广的标注转成正确的格式。 +如果您想使用增广后的 VOC 数据集, 请运行下面的命令来将数据增广的标注转成正确的格式. ```shell -# --nproc 8 意味着有 8 个进程用来转换,它也可以被忽略。 +# --nproc 8 意味着有 8 个进程用来转换,它也可以被忽略. 
python tools/convert_datasets/voc_aug.py data/VOCdevkit data/VOCdevkit/VOCaug --nproc 8 ``` -关于如何拼接数据集 (concatenate) 并一起训练它们,更多细节请参考 [拼接连接数据集](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/tutorials/customize_datasets.md#%E6%8B%BC%E6%8E%A5%E6%95%B0%E6%8D%AE%E9%9B%86) 。 +关于如何拼接数据集 (concatenate) 并一起训练它们, 更多细节请参考 [拼接连接数据集](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/tutorials/customize_datasets.md#%E6%8B%BC%E6%8E%A5%E6%95%B0%E6%8D%AE%E9%9B%86) . ### ADE20K -ADE20K 的训练集和验证集可以在 [这里](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip) 下载。 -您还可以在 [这里](http://data.csail.mit.edu/places/ADEchallenge/release_test.zip) 下载验证集。 +ADE20K 的训练集和验证集可以在 [这里](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip) 下载. +您还可以在 [这里](http://data.csail.mit.edu/places/ADEchallenge/release_test.zip) 下载验证集. ### Pascal Context -Pascal Context 的训练集和验证集可以在 [这里](http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar) 下载。 -注册成功后,您还可以在 [这里](http://host.robots.ox.ac.uk:8080/eval/downloads/VOC2010test.tar) 下载验证集。 +Pascal Context 的训练集和验证集可以在 [这里](http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar) 下载. +注册成功后, 您还可以在 [这里](http://host.robots.ox.ac.uk:8080/eval/downloads/VOC2010test.tar) 下载验证集. -为了从原始数据集里切分训练集和验证集, 您可以在 [这里](https://codalabuser.blob.core.windows.net/public/trainval_merged.json) -下载 trainval_merged.json。 +为了从原始数据集里切分训练集和验证集, 您可以在 [这里](https://codalabuser.blob.core.windows.net/public/trainval_merged.json) +下载 trainval_merged.json. -如果您想使用 Pascal Context 数据集, -请安装 [细节](https://github.com/zhanghang1989/detail-api) 然后再运行如下命令来把标注转换成正确的格式。 +如果您想使用 Pascal Context 数据集, +请安装 [细节](https://github.com/zhanghang1989/detail-api) 然后再运行如下命令来把标注转换成正确的格式. ```shell python tools/convert_datasets/pascal_context.py data/VOCdevkit data/VOCdevkit/VOC2010/trainval_merged.json @@ -171,64 +178,64 @@ python tools/convert_datasets/pascal_context.py data/VOCdevkit data/VOCdevkit/VO ### CHASE DB1 -CHASE DB1 的训练集和验证集可以在 [这里](https://staffnet.kingston.ac.uk/~ku15565/CHASE_DB1/assets/CHASEDB1.zip) 下载。 +CHASE DB1 的训练集和验证集可以在 [这里](https://staffnet.kingston.ac.uk/~ku15565/CHASE_DB1/assets/CHASEDB1.zip) 下载. -为了将 CHASE DB1 数据集转换成 MMSegmentation 的格式,您需要运行如下命令: +为了将 CHASE DB1 数据集转换成 MMSegmentation 的格式,您需要运行如下命令: ```shell python tools/convert_datasets/chase_db1.py /path/to/CHASEDB1.zip ``` -这个脚本将自动生成正确的文件夹结构。 +这个脚本将自动生成正确的文件夹结构. ### DRIVE -DRIVE 的训练集和验证集可以在 [这里](https://drive.grand-challenge.org/) 下载。 -在此之前,您需要注册一个账号,当前 '1st_manual' 并未被官方提供,因此需要您从其他地方获取。 +DRIVE 的训练集和验证集可以在 [这里](https://drive.grand-challenge.org/) 下载. +在此之前, 您需要注册一个账号, 当前 '1st_manual' 并未被官方提供, 因此需要您从其他地方获取. -为了将 DRIVE 数据集转换成 MMSegmentation 格式,您需要运行如下命令: +为了将 DRIVE 数据集转换成 MMSegmentation 格式, 您需要运行如下命令: ```shell python tools/convert_datasets/drive.py /path/to/training.zip /path/to/test.zip ``` -这个脚本将自动生成正确的文件夹结构。 +这个脚本将自动生成正确的文件夹结构. 
### HRF -首先,下载 [healthy.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/healthy.zip) [glaucoma.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/glaucoma.zip), [diabetic_retinopathy.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/diabetic_retinopathy.zip), [healthy_manualsegm.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/healthy_manualsegm.zip), [glaucoma_manualsegm.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/glaucoma_manualsegm.zip) 以及 [diabetic_retinopathy_manualsegm.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/diabetic_retinopathy_manualsegm.zip) 。 +首先, 下载 [healthy.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/healthy.zip) [glaucoma.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/glaucoma.zip), [diabetic_retinopathy.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/diabetic_retinopathy.zip), [healthy_manualsegm.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/healthy_manualsegm.zip), [glaucoma_manualsegm.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/glaucoma_manualsegm.zip) 以及 [diabetic_retinopathy_manualsegm.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/diabetic_retinopathy_manualsegm.zip). -为了将 HRF 数据集转换成 MMSegmentation 格式,您需要运行如下命令: +为了将 HRF 数据集转换成 MMSegmentation 格式, 您需要运行如下命令: ```shell python tools/convert_datasets/hrf.py /path/to/healthy.zip /path/to/healthy_manualsegm.zip /path/to/glaucoma.zip /path/to/glaucoma_manualsegm.zip /path/to/diabetic_retinopathy.zip /path/to/diabetic_retinopathy_manualsegm.zip ``` -这个脚本将自动生成正确的文件夹结构。 +这个脚本将自动生成正确的文件夹结构. ### STARE -首先,下载 [stare-images.tar](http://cecas.clemson.edu/~ahoover/stare/probing/stare-images.tar), [labels-ah.tar](http://cecas.clemson.edu/~ahoover/stare/probing/labels-ah.tar) 和 [labels-vk.tar](http://cecas.clemson.edu/~ahoover/stare/probing/labels-vk.tar) 。 +首先, 下载 [stare-images.tar](http://cecas.clemson.edu/~ahoover/stare/probing/stare-images.tar), [labels-ah.tar](http://cecas.clemson.edu/~ahoover/stare/probing/labels-ah.tar) 和 [labels-vk.tar](http://cecas.clemson.edu/~ahoover/stare/probing/labels-vk.tar). -为了将 STARE 数据集转换成 MMSegmentation 格式,您需要运行如下命令: +为了将 STARE 数据集转换成 MMSegmentation 格式, 您需要运行如下命令: ```shell python tools/convert_datasets/stare.py /path/to/stare-images.tar /path/to/labels-ah.tar /path/to/labels-vk.tar ``` -这个脚本将自动生成正确的文件夹结构。 +这个脚本将自动生成正确的文件夹结构. ### Dark Zurich -因为我们只支持在此数据集上测试模型,所以您只需下载[验证集](https://data.vision.ee.ethz.ch/csakarid/shared/GCMA_UIoU/Dark_Zurich_val_anon.zip) 。 +因为我们只支持在此数据集上测试模型, 所以您只需下载[验证集](https://data.vision.ee.ethz.ch/csakarid/shared/GCMA_UIoU/Dark_Zurich_val_anon.zip). ### Nighttime Driving -因为我们只支持在此数据集上测试模型,所以您只需下载[测试集](http://data.vision.ee.ethz.ch/daid/NighttimeDriving/NighttimeDrivingTest.zip) 。 +因为我们只支持在此数据集上测试模型,所以您只需下载[测试集](http://data.vision.ee.ethz.ch/daid/NighttimeDriving/NighttimeDrivingTest.zip). ### LoveDA -可以从 Google Drive 里下载 [LoveDA数据集](https://drive.google.com/drive/folders/1ibYV0qwn4yuuh068Rnc-w4tPi0U0c-ti?usp=sharing) 。 +可以从 Google Drive 里下载 [LoveDA数据集](https://drive.google.com/drive/folders/1ibYV0qwn4yuuh068Rnc-w4tPi0U0c-ti?usp=sharing). 
或者它还可以从 [zenodo](https://zenodo.org/record/5706578#.YZvN7SYRXdF) 下载, 您需要运行如下命令: @@ -241,46 +248,46 @@ wget https://zenodo.org/record/5706578/files/Val.zip wget https://zenodo.org/record/5706578/files/Test.zip ``` -对于 LoveDA 数据集,请运行以下命令下载并重新组织数据集 +对于 LoveDA 数据集,请运行以下命令下载并重新组织数据集: ```shell python tools/convert_datasets/loveda.py /path/to/loveDA ``` -请参照 [这里](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/inference.md) 来使用训练好的模型去预测 LoveDA 测试集并且提交到官网。 +请参照 [这里](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/inference.md) 来使用训练好的模型去预测 LoveDA 测试集并且提交到官网. -关于 LoveDA 的更多细节可以在[这里](https://github.com/Junjue-Wang/LoveDA) 找到。 +关于 LoveDA 的更多细节可以在[这里](https://github.com/Junjue-Wang/LoveDA) 找到. ### ISPRS Potsdam [Potsdam](https://www2.isprs.org/commissions/comm2/wg4/benchmark/2d-sem-label-potsdam/) -数据集是一个有着2D 语义分割内容标注的城市遥感数据集。 -数据集可以从挑战[主页](https://www2.isprs.org/commissions/comm2/wg4/benchmark/data-request-form/) 获得。 -需要其中的 '2_Ortho_RGB.zip' 和 '5_Labels_all_noBoundary.zip'。 +数据集是一个有着2D 语义分割内容标注的城市遥感数据集. +数据集可以从挑战[主页](https://www2.isprs.org/commissions/comm2/wg4/benchmark/data-request-form/) 获得. +需要其中的 `2_Ortho_RGB.zip` 和 `5_Labels_all_noBoundary.zip`. -对于 Potsdam 数据集,请运行以下命令下载并重新组织数据集 +对于 Potsdam 数据集,请运行以下命令下载并重新组织数据集 ```shell python tools/convert_datasets/potsdam.py /path/to/potsdam ``` -使用我们默认的配置, 将生成 3456 张图片的训练集和 2016 张图片的验证集。 +使用我们默认的配置, 将生成 3,456 张图片的训练集和 2,016 张图片的验证集. ### ISPRS Vaihingen [Vaihingen](https://www2.isprs.org/commissions/comm2/wg4/benchmark/2d-sem-label-vaihingen/) -数据集是一个有着2D 语义分割内容标注的城市遥感数据集。 +数据集是一个有着2D 语义分割内容标注的城市遥感数据集. 数据集可以从挑战 [主页](https://www2.isprs.org/commissions/comm2/wg4/benchmark/data-request-form/). -需要其中的 'ISPRS_semantic_labeling_Vaihingen.zip' 和 'ISPRS_semantic_labeling_Vaihingen_ground_truth_eroded_COMPLETE.zip'。 +需要其中的 'ISPRS_semantic_labeling_Vaihingen.zip' 和 'ISPRS_semantic_labeling_Vaihingen_ground_truth_eroded_COMPLETE.zip'. -对于 Vaihingen 数据集,请运行以下命令下载并重新组织数据集 +对于 Vaihingen 数据集, 请运行以下命令下载并重新组织数据集 ```shell python tools/convert_datasets/vaihingen.py /path/to/vaihingen ``` -使用我们默认的配置 (`clip_size`=512, `stride_size`=256), 将生成 344 张图片的训练集和 398 张图片的验证集。 +使用我们默认的配置 (`clip_size`=512, `stride_size`=256), 将生成 344 张图片的训练集和 398 张图片的验证集. ### iSAID @@ -290,7 +297,7 @@ iSAID 数据集(训练集/验证集)的注释可以从 [iSAID](https://captain-w 该数据集是一个大规模的实例分割(也可以用于语义分割)的遥感数据集. -下载后,在数据集转换前,您需要将数据集文件夹调整成如下格式. +下载后, 在数据集转换前, 您需要将数据集文件夹调整成如下格式. ``` │ ├── iSAID @@ -316,4 +323,84 @@ iSAID 数据集(训练集/验证集)的注释可以从 [iSAID](https://captain-w python tools/convert_datasets/isaid.py /path/to/iSAID ``` -使用我们默认的配置 (`patch_width`=896, `patch_height`=896, `overlap_area`=384), 将生成 33978 张图片的训练集和 11644 张图片的验证集。 +使用我们默认的配置 (`patch_width`=896, `patch_height`=896, `overlap_area`=384), 将生成 33,978 张图片的训练集和 11,644 张图片的验证集. + +## Synapse dataset + +这个数据集可以在这个[网页](https://www.synapse.org/#!Synapse:syn3193805/wiki/) 里被下载. +我们参考了 [TransUNet](https://arxiv.org/abs/2102.04306) 里面的数据集预处理的设置, 它将原始数据集 (30 套 3D 样例) 切分出 18 套用于训练, 12 套用于验证. 请参考以下步骤来准备该数据集: + +```shell +unzip RawData.zip +cd ./RawData/Training +``` + +随后新建 `train.txt` 和 `val.txt`. 
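A minimal sketch for writing the two split files (the helper is illustrative; the full ID lists are exactly the ones given below):

```python
# Write the TransUNet train/val split files, one case ID per line.
train_ids = ['img0005.nii.gz', 'img0006.nii.gz']  # ... remaining train IDs below
val_ids = ['img0008.nii.gz', 'img0022.nii.gz']    # ... remaining val IDs below

with open('train.txt', 'w') as f:
    f.write('\n'.join(train_ids) + '\n')
with open('val.txt', 'w') as f:
    f.write('\n'.join(val_ids) + '\n')
```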
+ +根据 TransUNet 来将训练集和验证集如下划分: + +train.txt + +```none +img0005.nii.gz +img0006.nii.gz +img0007.nii.gz +img0009.nii.gz +img0010.nii.gz +img0021.nii.gz +img0023.nii.gz +img0024.nii.gz +img0026.nii.gz +img0027.nii.gz +img0028.nii.gz +img0030.nii.gz +img0031.nii.gz +img0033.nii.gz +img0034.nii.gz +img0037.nii.gz +img0039.nii.gz +img0040.nii.gz +``` + +val.txt + +```none +img0008.nii.gz +img0022.nii.gz +img0038.nii.gz +img0036.nii.gz +img0032.nii.gz +img0002.nii.gz +img0029.nii.gz +img0003.nii.gz +img0001.nii.gz +img0004.nii.gz +img0025.nii.gz +img0035.nii.gz +``` + +此时, synapse 数据集包括了以下内容: + +```none +├── Training +│ ├── img +│ │ ├── img0001.nii.gz +│ │ ├── img0002.nii.gz +│ │ ├── ... +│ ├── label +│ │ ├── label0001.nii.gz +│ │ ├── label0002.nii.gz +│ │ ├── ... +│ ├── train.txt +│ ├── val.txt +``` + +随后, 运行下面的数据集转换脚本来处理 synapse 数据集: + +```shell +python tools/dataset_converters/synapse.py --dataset-path /path/to/synapse +``` + +使用我们默认的配置, 将生成 2,211 张 2D 图片的训练集和 1,568 张图片的验证集. + +需要注意的是 MMSegmentation 默认的评价指标 (例如平均 Dice 值) 都是基于每帧 2D 图片计算的, 这与基于每套 3D 图片计算评价指标的 [TransUNet](https://arxiv.org/abs/2102.04306) 是不同的. diff --git a/mmseg/__init__.py b/mmseg/__init__.py index b395013526..59380655a2 100644 --- a/mmseg/__init__.py +++ b/mmseg/__init__.py @@ -8,7 +8,7 @@ from .version import __version__, version_info MMCV_MIN = '2.0.0rc3' -MMCV_MAX = '2.1.0' +MMCV_MAX = '2.0.0rc3' MMENGINE_MIN = '0.1.0' MMENGINE_MAX = '1.0.0' @@ -58,9 +58,9 @@ def digit_version(version_str: str, length: int = 4): mmcv_version = digit_version(mmcv.__version__) -assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \ +assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \ f'MMCV=={mmcv.__version__} is used but incompatible. ' \ - f'Please install mmcv>={mmcv_min_version}, <{mmcv_max_version}.' + f'Please install mmcv==2.0.0rc3.' 
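# Note: with MMCV_MIN and MMCV_MAX both set to '2.0.0rc3', the inclusive bounds above reduce to an exact pin: only mmcv 2.0.0rc3 passes this import-time check, which is what `mim install 'mmcv==2.0.0rc3'` provides.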
mmengine_min_version = digit_version(MMENGINE_MIN) mmengine_max_version = digit_version(MMENGINE_MAX) diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 9abc85d627..d1cc545598 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -93,8 +93,11 @@ def init_model(config: Union[str, Path, Config], def _preprare_data(imgs: ImageType, model: BaseSegmentor): cfg = model.cfg - if dict(type='LoadAnnotations') in cfg.test_pipeline: - cfg.test_pipeline.remove(dict(type='LoadAnnotations')) + # rebuild the pipeline rather than calling remove() while iterating, + # which would skip the element right after each removal + cfg.test_pipeline = [ + t for t in cfg.test_pipeline if t.get('type') != 'LoadAnnotations' + ] is_batch = True if not isinstance(imgs, (list, tuple)): diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py index 58f71b62a2..8aa2e8d1a8 100644 --- a/mmseg/datasets/__init__.py +++ b/mmseg/datasets/__init__.py @@ -18,29 +18,35 @@ from .pascal_context import PascalContextDataset, PascalContextDataset59 from .potsdam import PotsdamDataset from .stare import STAREDataset -from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop, - GenerateEdge, LoadAnnotations, +from .synapse import SynapseDataset +# yapf: disable +from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad, + BioMedical3DRandomCrop, BioMedical3DRandomFlip, + BioMedicalGaussianBlur, BioMedicalGaussianNoise, + BioMedicalRandomGamma, GenerateEdge, LoadAnnotations, LoadBiomedicalAnnotation, LoadBiomedicalData, LoadBiomedicalImageFromFile, LoadImageFromNDArray, PackSegInputs, PhotoMetricDistortion, RandomCrop, - RandomCutOut, RandomMosaic, RandomRotate, Rerange, - ResizeShortestEdge, ResizeToMultiple, RGB2Gray, - SegRescale) + RandomCutOut, RandomMosaic, RandomRotate, + RandomRotFlip, Rerange, ResizeShortestEdge, + ResizeToMultiple, RGB2Gray, SegRescale) from .voc import PascalVOCDataset # yapf: enable - __all__ = [ - 'BaseSegDataset', 'BioMedical3DRandomCrop', 'CityscapesDataset', - 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset', - 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', - 'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset', - 'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset', - 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', 'LoadAnnotations', - 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', - 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut', - 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple', + 'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip', + 'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset', + 'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset', + 'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset', + 'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset', + 'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', + 'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion', + 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', + 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', - 'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge' + 'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge', + 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur', + 'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip', + 'SynapseDataset' ] diff --git a/mmseg/datasets/basesegdataset.py b/mmseg/datasets/basesegdataset.py index e97f8ca9d1..e7f96f7d2c 100644 ---
a/mmseg/datasets/basesegdataset.py +++ b/mmseg/datasets/basesegdataset.py @@ -47,7 +47,7 @@ class BaseSegDataset(BaseDataset): data_root (str, optional): The root directory for ``data_prefix`` and ``ann_file``. Defaults to None. data_prefix (dict, optional): Prefix for training data. Defaults to - dict(img_path=None, seg_path=None). + dict(img_path=None, seg_map_path=None). img_suffix (str): Suffix of images. Default: '.jpg' seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' filter_cfg (dict, optional): Config for filter data. Defaults to None. @@ -220,7 +220,7 @@ def _update_palette(self) -> list: # return subset of palette for old_id, new_id in sorted( self.label_map.items(), key=lambda x: x[1]): - if new_id != -1: + if new_id != 255: new_palette.append(palette[old_id]) new_palette = type(palette)(new_palette) else: diff --git a/mmseg/datasets/synapse.py b/mmseg/datasets/synapse.py new file mode 100644 index 0000000000..6f83b64150 --- /dev/null +++ b/mmseg/datasets/synapse.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class SynapseDataset(BaseSegDataset): + """Synapse dataset. + + Before dataset preprocess of Synapse, there are total 13 categories of + foreground which does not include background. After preprocessing, 8 + foreground categories are kept while the other 5 foreground categories are + handled as background. The ``img_suffix`` is fixed to '.jpg' and + ``seg_map_suffix`` is fixed to '.png'. + """ + METAINFO = dict( + classes=('background', 'aorta', 'gallbladder', 'left_kidney', + 'right_kidney', 'liver', 'pancreas', 'spleen', 'stomach'), + palette=[[0, 0, 0], [0, 0, 255], [0, 255, 0], [255, 0, 0], + [0, 255, 255], [255, 0, 255], [255, 255, 0], [60, 255, 255], + [240, 240, 240]]) + + def __init__(self, + img_suffix='.jpg', + seg_map_suffix='.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) diff --git a/mmseg/datasets/transforms/__init__.py b/mmseg/datasets/transforms/__init__.py index 7f67acec02..25f4ee4a98 100644 --- a/mmseg/datasets/transforms/__init__.py +++ b/mmseg/datasets/transforms/__init__.py @@ -3,17 +3,24 @@ from .loading import (LoadAnnotations, LoadBiomedicalAnnotation, LoadBiomedicalData, LoadBiomedicalImageFromFile, LoadImageFromNDArray) -from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop, - GenerateEdge, PhotoMetricDistortion, RandomCrop, - RandomCutOut, RandomMosaic, RandomRotate, Rerange, +# yapf: disable +from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad, + BioMedical3DRandomCrop, BioMedical3DRandomFlip, + BioMedicalGaussianBlur, BioMedicalGaussianNoise, + BioMedicalRandomGamma, GenerateEdge, + PhotoMetricDistortion, RandomCrop, RandomCutOut, + RandomMosaic, RandomRotate, RandomRotFlip, Rerange, ResizeShortestEdge, ResizeToMultiple, RGB2Gray, SegRescale) +# yapf: enable __all__ = [ 'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', - 'ResizeShortestEdge' + 'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur', + 'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad', + 'RandomRotFlip' ] diff 
--git a/mmseg/datasets/transforms/formatting.py b/mmseg/datasets/transforms/formatting.py index bb4db4484e..f4018f788f 100644 --- a/mmseg/datasets/transforms/formatting.py +++ b/mmseg/datasets/transforms/formatting.py @@ -73,6 +73,12 @@ def transform(self, results: dict) -> dict: ...].astype(np.int64))) data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data) + if 'gt_edge_map' in results: + gt_edge_data = dict( + data=to_tensor(results['gt_edge_map'][None, + ...].astype(np.int64))) + data_sample.set_data(dict(gt_edge_map=PixelData(**gt_edge_data))) + img_meta = {} for key in self.meta_keys: if key in results: diff --git a/mmseg/datasets/transforms/loading.py b/mmseg/datasets/transforms/loading.py index ea51e0df59..65c0dfec47 100644 --- a/mmseg/datasets/transforms/loading.py +++ b/mmseg/datasets/transforms/loading.py @@ -96,14 +96,6 @@ def _load_seg_map(self, results: dict) -> None: img_bytes, flag='unchanged', backend=self.imdecode_backend).squeeze().astype(np.uint8) - # modify if custom classes - if results.get('label_map', None) is not None: - # Add deep copy to solve bug of repeatedly - # replace `gt_semantic_seg`, which is reported in - # https://github.com/open-mmlab/mmsegmentation/pull/1445/ - gt_semantic_seg_copy = gt_semantic_seg.copy() - for old_id, new_id in results['label_map'].items(): - gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id # reduce zero_label if self.reduce_zero_label is None: self.reduce_zero_label = results['reduce_zero_label'] @@ -116,6 +108,14 @@ def _load_seg_map(self, results: dict) -> None: gt_semantic_seg[gt_semantic_seg == 0] = 255 gt_semantic_seg = gt_semantic_seg - 1 gt_semantic_seg[gt_semantic_seg == 254] = 255 + # modify if custom classes + if results.get('label_map', None) is not None: + # Add deep copy to solve bug of repeatedly + # replace `gt_semantic_seg`, which is reported in + # https://github.com/open-mmlab/mmsegmentation/pull/1445/ + gt_semantic_seg_copy = gt_semantic_seg.copy() + for old_id, new_id in results['label_map'].items(): + gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id results['gt_seg_map'] = gt_semantic_seg results['seg_fields'].append('gt_seg_map') diff --git a/mmseg/datasets/transforms/transforms.py b/mmseg/datasets/transforms/transforms.py index 5d1173f254..ef4e78dd8c 100644 --- a/mmseg/datasets/transforms/transforms.py +++ b/mmseg/datasets/transforms/transforms.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy import warnings -from typing import Dict, Sequence, Tuple, Union +from typing import Dict, List, Optional, Sequence, Tuple, Union import cv2 import mmcv @@ -10,6 +10,7 @@ from mmcv.transforms.utils import cache_randomness from mmengine.utils import is_tuple_of from numpy import random +from scipy.ndimage import gaussian_filter from mmseg.datasets.dataset_wrappers import MultiImageMixDataset from mmseg.registry import TRANSFORMS @@ -313,9 +314,9 @@ def transform(self, results: dict) -> dict: # crop semantic seg for key in results.get('seg_fields', []): results[key] = self.crop(results[key], crop_bbox) - img_shape = img.shape + results['img'] = img - results['img_shape'] = img_shape + results['img_shape'] = img.shape[:2] return results def __repr__(self): @@ -860,6 +861,84 @@ def __repr__(self): return repr_str +@TRANSFORMS.register_module() +class RandomRotFlip(BaseTransform): + """Rotate and flip the image & seg or just rotate the image & seg. 
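+ With probability ``rotate_prob`` the image and segmentation map are rotated by a random angle drawn from ``degree``; otherwise, with probability ``flip_prob``, they are rotated by a random multiple of 90 degrees and flipped along a random axis.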
+ + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + + Args: + rotate_prob (float): The probability of rotate image. + flip_prob (float): The probability of rotate&flip image. + degree (float, tuple[float]): Range of degrees to select from. If + degree is a number instead of tuple like (min, max), + the range of degree will be (``-degree``, ``+degree``) + """ + + def __init__(self, rotate_prob=0.5, flip_prob=0.5, degree=(-20, 20)): + self.rotate_prob = rotate_prob + self.flip_prob = flip_prob + assert 0 <= rotate_prob <= 1 and 0 <= flip_prob <= 1 + if isinstance(degree, (float, int)): + assert degree > 0, f'degree {degree} should be positive' + self.degree = (-degree, degree) + else: + self.degree = degree + assert len(self.degree) == 2, f'degree {self.degree} should be a ' \ + f'tuple of (min, max)' + + def random_rot_flip(self, results: dict) -> dict: + k = np.random.randint(0, 4) + results['img'] = np.rot90(results['img'], k) + for key in results.get('seg_fields', []): + results[key] = np.rot90(results[key], k) + axis = np.random.randint(0, 2) + results['img'] = np.flip(results['img'], axis=axis).copy() + for key in results.get('seg_fields', []): + results[key] = np.flip(results[key], axis=axis).copy() + return results + + def random_rotate(self, results: dict) -> dict: + angle = np.random.uniform(min(*self.degree), max(*self.degree)) + results['img'] = mmcv.imrotate(results['img'], angle=angle) + for key in results.get('seg_fields', []): + results[key] = mmcv.imrotate(results[key], angle=angle) + return results + + def transform(self, results: dict) -> dict: + """Call function to rotate or rotate & flip image, semantic + segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Rotated or rotated & flipped results. + """ + rotate_flag = 0 + if random.random() < self.rotate_prob: + results = self.random_rotate(results) + rotate_flag = 1 + if random.random() < self.flip_prob and rotate_flag == 0: + results = self.random_rot_flip(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(rotate_prob={self.rotate_prob}, ' \ + f'flip_prob={self.flip_prob}, ' \ + f'degree={self.degree})' + return repr_str + + @TRANSFORMS.register_module() class RandomMosaic(BaseTransform): """Mosaic augmentation. Given 4 images, mosaic transform combines them into @@ -1166,7 +1245,7 @@ class GenerateEdge(BaseTransform): - gt_seg_map Added Keys: - - gt_edge (np.ndarray, uint8): The edge annotation generated from the + - gt_edge_map (np.ndarray, uint8): The edge annotation generated from the seg map by extracting border between different semantics. Args: @@ -1217,7 +1296,7 @@ def transform(self, results: Dict) -> Dict: (self.edge_width, self.edge_width)) edge = cv2.dilate(edge, kernel) - results['gt_edge'] = edge + results['gt_edge_map'] = edge results['edge_width'] = self.edge_width return results @@ -1507,3 +1586,551 @@ def transform(self, results: dict) -> dict: def __repr__(self): return self.__class__.__name__ + f'(crop_shape={self.crop_shape})' + + +@TRANSFORMS.register_module() +class BioMedicalGaussianNoise(BaseTransform): + """Add random Gaussian noise to image. 
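+ For each selected sample, the noise standard deviation is itself drawn uniformly from ``[0, std]`` before the additive Gaussian noise is sampled.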
+ + Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L53 # noqa:E501 + + Copyright (c) German Cancer Research Center (DKFZ) + Licensed under the Apache License, Version 2.0 + + Required Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X), + N is the number of modalities, and data type is float32. + + Modified Keys: + + - img + + Args: + prob (float): Probability to add Gaussian noise for + each sample. Default to 0.1. + mean (float): Mean or “centre” of the distribution. Default to 0.0. + std (float): Standard deviation of distribution. Default to 0.1. + """ + + def __init__(self, + prob: float = 0.1, + mean: float = 0.0, + std: float = 0.1) -> None: + super().__init__() + assert 0.0 <= prob <= 1.0 and std >= 0.0 + self.prob = prob + self.mean = mean + self.std = std + + def transform(self, results: Dict) -> Dict: + """Call function to add random Gaussian noise to image. + + Args: + results (dict): Result dict. + + Returns: + dict: Result dict with random Gaussian noise. + """ + if np.random.rand() < self.prob: + rand_std = np.random.uniform(0, self.std) + noise = np.random.normal( + self.mean, rand_std, size=results['img'].shape) + # noise is float64 array, convert to the results['img'].dtype + noise = noise.astype(results['img'].dtype) + results['img'] = results['img'] + noise + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'mean={self.mean}, ' + repr_str += f'std={self.std})' + return repr_str + + +@TRANSFORMS.register_module() +class BioMedicalGaussianBlur(BaseTransform): + """Add Gaussian blur with random sigma to image. + + Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L81 # noqa:E501 + + Copyright (c) German Cancer Research Center (DKFZ) + Licensed under the Apache License, Version 2.0 + + Required Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X), + N is the number of modalities, and data type is float32. + + Modified Keys: + + - img + + Args: + sigma_range (Tuple[float, float]|float): range to randomly + select sigma value. Default to (0.5, 1.0). + prob (float): Probability to apply Gaussian blur + for each sample. Default to 0.2. + prob_per_channel (float): Probability to apply Gaussian blur + for each channel (axis N of the image). Default to 0.5. + different_sigma_per_channel (bool): whether to use different + sigma for each channel (axis N of the image). Default to True. + different_sigma_per_axis (bool): whether to use different + sigma for axis Z, X and Y of the image. Default to True. + """ + + def __init__(self, + sigma_range: Tuple[float, float] = (0.5, 1.0), + prob: float = 0.2, + prob_per_channel: float = 0.5, + different_sigma_per_channel: bool = True, + different_sigma_per_axis: bool = True) -> None: + super().__init__() + assert 0.0 <= prob <= 1.0 + assert 0.0 <= prob_per_channel <= 1.0 + assert isinstance(sigma_range, Sequence) and len(sigma_range) == 2 + self.sigma_range = sigma_range + self.prob = prob + self.prob_per_channel = prob_per_channel + self.different_sigma_per_channel = different_sigma_per_channel + self.different_sigma_per_axis = different_sigma_per_axis + + def _get_valid_sigma(self, value_range) -> Tuple[float, ...]: + """Ensure the `value_range` to be either a single value or a sequence + of two values. 
If the `value_range` is a sequence, generate a random + value with `[value_range[0], value_range[1]]` based on uniform + sampling. + + Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/augmentations/utils.py#L625 # noqa:E501 + + Args: + value_range (tuple|list|float|int): the input value range + """ + if (isinstance(value_range, (list, tuple))): + if (value_range[0] == value_range[1]): + value = value_range[0] + else: + orig_type = type(value_range[0]) + value = np.random.uniform(value_range[0], value_range[1]) + value = orig_type(value) + return value + + def _gaussian_blur(self, data_sample: np.ndarray) -> np.ndarray: + """Random generate sigma and apply Gaussian Blur to the data + Args: + data_sample (np.ndarray): data sample with multiple modalities, + the data shape is (N, Z, Y, X) + """ + sigma = None + for c in range(data_sample.shape[0]): + if np.random.rand() < self.prob_per_channel: + # if no `sigma` is generated, generate one + # if `self.different_sigma_per_channel` is True, + # re-generate random sigma for each channel + if (sigma is None or self.different_sigma_per_channel): + if (not self.different_sigma_per_axis): + sigma = self._get_valid_sigma(self.sigma_range) + else: + sigma = [ + self._get_valid_sigma(self.sigma_range) + for _ in data_sample.shape[1:] + ] + # apply gaussian filter with `sigma` + data_sample[c] = gaussian_filter( + data_sample[c], sigma, order=0) + return data_sample + + def transform(self, results: Dict) -> Dict: + """Call function to add random Gaussian blur to image. + + Args: + results (dict): Result dict. + + Returns: + dict: Result dict with random Gaussian noise. + """ + if np.random.rand() < self.prob: + results['img'] = self._gaussian_blur(results['img']) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'prob_per_channel={self.prob_per_channel}, ' + repr_str += f'sigma_range={self.sigma_range}, ' + repr_str += 'different_sigma_per_channel='\ + f'{self.different_sigma_per_channel}, ' + repr_str += 'different_sigma_per_axis='\ + f'{self.different_sigma_per_axis})' + return repr_str + + +@TRANSFORMS.register_module() +class BioMedicalRandomGamma(BaseTransform): + """Using random gamma correction to process the biomedical image. + + Modified from + https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/color_transforms.py#L132 # noqa:E501 + With licence: Apache 2.0 + + Required Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X), + N is the number of modalities, and data type is float32. + + Modified Keys: + - img + + Args: + prob (float): The probability to perform this transform. Default: 0.5. + gamma_range (Tuple[float]): Range of gamma values. Default: (0.5, 2). + invert_image (bool): Whether invert the image before applying gamma + augmentation. Default: False. + per_channel (bool): Whether perform the transform each channel + individually. Default: False + retain_stats (bool): Gamma transformation will alter the mean and std + of the data in the patch. If retain_stats=True, the data will be + transformed to match the mean and standard deviation before gamma + augmentation. Default: False. 
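+ Note: as in batchgenerators, gamma is drawn from ``[gamma_range[0], 1)`` with probability 0.5 when ``gamma_range[0] < 1``, and from ``[max(gamma_range[0], 1), gamma_range[1]]`` otherwise.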
+ """ + + def __init__(self, + prob: float = 0.5, + gamma_range: Tuple[float] = (0.5, 2), + invert_image: bool = False, + per_channel: bool = False, + retain_stats: bool = False): + assert 0 <= prob and prob <= 1 + assert isinstance(gamma_range, tuple) and len(gamma_range) == 2 + assert isinstance(invert_image, bool) + assert isinstance(per_channel, bool) + assert isinstance(retain_stats, bool) + self.prob = prob + self.gamma_range = gamma_range + self.invert_image = invert_image + self.per_channel = per_channel + self.retain_stats = retain_stats + + @cache_randomness + def _do_gamma(self): + """Whether do adjust gamma for image.""" + return np.random.rand() < self.prob + + def _adjust_gamma(self, img: np.array): + """Gamma adjustment for image. + + Args: + img (np.array): Input image before gamma adjust. + + Returns: + np.arrays: Image after gamma adjust. + """ + + if self.invert_image: + img = -img + + def _do_adjust(img): + if retain_stats_here: + img_mean = img.mean() + img_std = img.std() + if np.random.random() < 0.5 and self.gamma_range[0] < 1: + gamma = np.random.uniform(self.gamma_range[0], 1) + else: + gamma = np.random.uniform( + max(self.gamma_range[0], 1), self.gamma_range[1]) + img_min = img.min() + img_range = img.max() - img_min # range + img = np.power(((img - img_min) / float(img_range + 1e-7)), + gamma) * img_range + img_min + if retain_stats_here: + img = img - img.mean() + img = img / (img.std() + 1e-8) * img_std + img = img + img_mean + return img + + if not self.per_channel: + retain_stats_here = self.retain_stats + img = _do_adjust(img) + else: + for c in range(img.shape[0]): + img[c] = _do_adjust(img[c]) + if self.invert_image: + img = -img + return img + + def transform(self, results: dict) -> dict: + """Call function to perform random gamma correction + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with random gamma correction performed. + """ + do_gamma = self._do_gamma() + + if do_gamma: + results['img'] = self._adjust_gamma(results['img']) + else: + pass + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'gamma_range={self.gamma_range},' + repr_str += f'invert_image={self.invert_image},' + repr_str += f'per_channel={self.per_channel},' + repr_str += f'retain_stats={self.retain_stats}' + return repr_str + + +@TRANSFORMS.register_module() +class BioMedical3DPad(BaseTransform): + """Pad the biomedical 3d image & biomedical 3d semantic segmentation maps. + + Required Keys: + + - img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. + + Modified Keys: + + - img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. + + Added Keys: + + - pad_shape (Tuple[int, int, int]): The padded shape. + + Args: + pad_shape (Tuple[int, int, int]): Fixed padding size. + Expected padding shape (Z, Y, X). + pad_val (float): Padding value for biomedical image. + The padding mode is set to "constant". The value + to be filled in padding area. Default: 0. + seg_pad_val (int): Padding value for biomedical 3d semantic + segmentation maps. The padding mode is set to "constant". + The value to be filled in padding area. Default: 0. 
+ """ + + def __init__(self, + pad_shape: Tuple[int, int, int], + pad_val: float = 0., + seg_pad_val: int = 0) -> None: + + # check pad_shape + assert pad_shape is not None + if not isinstance(pad_shape, tuple): + assert len(pad_shape) == 3 + + self.pad_shape = pad_shape + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + + def _pad_img(self, results: dict) -> None: + """Pad images according to ``self.pad_shape`` + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: The dict contains the padded image and shape + information. + """ + padded_img = self._to_pad( + results['img'], pad_shape=self.pad_shape, pad_val=self.pad_val) + + results['img'] = padded_img + results['pad_shape'] = padded_img.shape[1:] + + def _pad_seg(self, results: dict) -> None: + """Pad semantic segmentation map according to ``self.pad_shape`` if + ``gt_seg_map`` is not None in results dict. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Update the padded gt seg map in dict. + """ + if results.get('gt_seg_map', None) is not None: + pad_gt_seg = self._to_pad( + results['gt_seg_map'][None, ...], + pad_shape=results['pad_shape'], + pad_val=self.seg_pad_val) + results['gt_seg_map'] = pad_gt_seg[1:] + + @staticmethod + def _to_pad(img: np.ndarray, + pad_shape: Tuple[int, int, int], + pad_val: Union[int, float] = 0) -> np.ndarray: + """Pad the given 3d image to a certain shape with specified padding + value. + + Args: + img (ndarray): Biomedical image with shape (N, Z, Y, X) + to be padded. N is the number of modalities. + pad_shape (Tuple[int,int,int]): Expected padding shape (Z, Y, X). + pad_val (float, int): Values to be filled in padding areas + and the padding_mode is set to 'constant'. Default: 0. + + Returns: + ndarray: The padded image. + """ + # compute pad width + d = max(pad_shape[0] - img.shape[1], 0) + pad_d = (d // 2, d - d // 2) + h = max(pad_shape[1] - img.shape[2], 0) + pad_h = (h // 2, h - h // 2) + w = max(pad_shape[2] - img.shape[2], 0) + pad_w = (w // 2, w - w // 2) + + pad_list = [(0, 0), pad_d, pad_h, pad_w] + + img = np.pad(img, pad_list, mode='constant', constant_values=pad_val) + return img + + def transform(self, results: dict) -> dict: + """Call function to pad images, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Updated result dict. + """ + self._pad_img(results) + self._pad_seg(results) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'pad_shape={self.pad_shape}, ' + repr_str += f'pad_val={self.pad_val}), ' + repr_str += f'seg_pad_val={self.seg_pad_val})' + return repr_str + + +@TRANSFORMS.register_module() +class BioMedical3DRandomFlip(BaseTransform): + """Flip biomedical 3D images and segmentations. + + Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/spatial_transforms.py # noqa:E501 + + Copyright 2021 Division of + Medical Image Computing, German Cancer Research Center (DKFZ) and Applied + Computer Vision Lab, Helmholtz Imaging Platform. + Licensed under the Apache-2.0 License. + + Required Keys: + + - img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. + + Modified Keys: + + - img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. 
+ - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. + + Added Keys: + + - do_flip + - flip_axes + + Args: + prob (float): Flipping probability. + axes (Tuple[int, ...]): Flipping axes with order 'ZXY'. + swap_label_pairs (Optional[List[Tuple[int, int]]]): + The segmentation label pairs that are swapped when flipping. + """ + + def __init__(self, + prob: float, + axes: Tuple[int, ...], + swap_label_pairs: Optional[List[Tuple[int, int]]] = None): + self.prob = prob + self.axes = axes + self.swap_label_pairs = swap_label_pairs + assert prob >= 0 and prob <= 1 + if axes is not None: + assert max(axes) <= 2 + + @staticmethod + def _flip(img, direction: Tuple[bool, bool, bool]) -> np.ndarray: + if direction[0]: + img[:, :] = img[:, ::-1] + if direction[1]: + img[:, :, :] = img[:, :, ::-1] + if direction[2]: + img[:, :, :, :] = img[:, :, :, ::-1] + return img + + def _do_flip(self, img: np.ndarray) -> Tuple[bool, bool, bool]: + """Call function to determine which axis to flip. + + Args: + img (np.ndarry): Image or segmentation map array. + Returns: + tuple: Flip action, whether to flip on the z, x, and y axes. + """ + flip_c, flip_x, flip_y = False, False, False + if self.axes is not None: + flip_c = 0 in self.axes and np.random.rand() < self.prob + flip_x = 1 in self.axes and np.random.rand() < self.prob + if len(img.shape) == 4: + flip_y = 2 in self.axes and np.random.rand() < self.prob + return flip_c, flip_x, flip_y + + def _swap_label(self, seg: np.ndarray) -> np.ndarray: + out = seg.copy() + for first, second in self.swap_label_pairs: + first_area = (seg == first) + second_area = (seg == second) + out[first_area] = second + out[second_area] = first + return out + + def transform(self, results: Dict) -> Dict: + """Call function to flip and swap pair labels. + + Args: + results (dict): Result dict. + Returns: + dict: Flipped results, 'do_flip', 'flip_axes' keys are added into + result dict. 
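+ The flip decision is sampled once and cached in ``results['do_flip']``, so the image and its segmentation map are always flipped consistently.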
+ """ + # get actual flipped axis + if 'do_flip' not in results: + results['do_flip'] = self._do_flip(results['img']) + if 'flip_axes' not in results: + results['flip_axes'] = self.axes + # flip image + results['img'] = self._flip( + results['img'], direction=results['do_flip']) + # flip seg + if results['gt_seg_map'] is not None: + if results['gt_seg_map'].shape != results['img'].shape: + results['gt_seg_map'] = results['gt_seg_map'][None, :] + results['gt_seg_map'] = self._flip( + results['gt_seg_map'], direction=results['do_flip']) + results['gt_seg_map'] = results['gt_seg_map'].squeeze() + # swap label pairs + if self.swap_label_pairs is not None: + results['gt_seg_map'] = self._swap_label(results['gt_seg_map']) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, axes={self.axes}, ' \ + f'swap_label_pairs={self.swap_label_pairs})' + return repr_str diff --git a/mmseg/utils/__init__.py b/mmseg/utils/__init__.py index 0fd58218d6..661796147d 100644 --- a/mmseg/utils/__init__.py +++ b/mmseg/utils/__init__.py @@ -6,8 +6,8 @@ get_palette, isaid_classes, isaid_palette, loveda_classes, loveda_palette, potsdam_classes, potsdam_palette, stare_classes, stare_palette, - vaihingen_classes, vaihingen_palette, voc_classes, - voc_palette) + synapse_classes, synapse_palette, vaihingen_classes, + vaihingen_palette, voc_classes, voc_palette) # yapf: enable from .collect_env import collect_env from .io import datafrombytes @@ -27,5 +27,5 @@ 'cityscapes_palette', 'ade_palette', 'voc_palette', 'cocostuff_palette', 'loveda_palette', 'potsdam_palette', 'vaihingen_palette', 'isaid_palette', 'stare_palette', 'dataset_aliases', 'get_classes', 'get_palette', - 'datafrombytes' + 'datafrombytes', 'synapse_palette', 'synapse_classes' ] diff --git a/mmseg/utils/class_names.py b/mmseg/utils/class_names.py index a62eaac973..662199f21e 100644 --- a/mmseg/utils/class_names.py +++ b/mmseg/utils/class_names.py @@ -265,6 +265,20 @@ def stare_palette(): return [[120, 120, 120], [6, 230, 230]] +def synapse_palette(): + """Synapse palette for external use.""" + return [[0, 0, 0], [0, 0, 255], [0, 255, 0], [255, 0, 0], [0, 255, 255], + [255, 0, 255], [255, 255, 0], [60, 255, 255], [240, 240, 240]] + + +def synapse_classes(): + """Synapse class names for external use.""" + return [ + 'background', 'aorta', 'gallbladder', 'left_kidney', 'right_kidney', + 'liver', 'pancreas', 'spleen', 'stomach' + ] + + def lip_classes(): """LIP class names for external use.""" return [ diff --git a/mmseg/utils/misc.py b/mmseg/utils/misc.py index aa30893609..09d2349c15 100644 --- a/mmseg/utils/misc.py +++ b/mmseg/utils/misc.py @@ -98,6 +98,11 @@ def stack_batch(inputs: List[torch.Tensor], del data_sample.gt_sem_seg.data data_sample.gt_sem_seg.data = F.pad( gt_sem_seg, padding_size, value=seg_pad_val) + if 'gt_edge_map' in data_sample: + gt_edge_map = data_sample.gt_edge_map.data + del data_sample.gt_edge_map.data + data_sample.gt_edge_map.data = F.pad( + gt_edge_map, padding_size, value=seg_pad_val) data_sample.set_metainfo({ 'img_shape': tensor.shape[-2:], 'pad_shape': data_sample.gt_sem_seg.shape, diff --git a/mmseg/version.py b/mmseg/version.py index 6931108fe7..ae61f8bf7b 100644 --- a/mmseg/version.py +++ b/mmseg/version.py @@ -1,6 +1,6 @@ # Copyright (c) Open-MMLab. All rights reserved. 
-__version__ = '1.0.0rc3' +__version__ = '1.0.0rc4' def parse_version_info(version_str): diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py index 070b06b73b..27443f2c57 100644 --- a/mmseg/visualization/local_visualizer.py +++ b/mmseg/visualization/local_visualizer.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import mmcv import numpy as np @@ -9,6 +9,7 @@ from mmseg.registry import VISUALIZERS from mmseg.structures import SegDataSample +from mmseg.utils import get_classes, get_palette @VISUALIZERS.register_module() @@ -55,14 +56,23 @@ def __init__(self, image: Optional[np.ndarray] = None, vis_backends: Optional[Dict] = None, save_dir: Optional[str] = None, + palette: Optional[Union[str, List]] = None, + classes: Optional[Union[str, List]] = None, + dataset_name: Optional[str] = None, alpha: float = 0.8, **kwargs): super().__init__(name, image, vis_backends, save_dir, **kwargs) - self.alpha = alpha + self.alpha: float = alpha # Set default value. When calling # `SegLocalVisualizer().dataset_meta=xxx`, # it will override the default value. - self.dataset_meta = {} + if dataset_name is None: + dataset_name = 'cityscapes' + classes = classes if classes else get_classes(dataset_name) + palette = palette if palette else get_palette(dataset_name) + assert len(classes) == len( + palette), 'The length of classes should be equal to palette' + self.dataset_meta: dict = {'classes': classes, 'palette': palette} def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData, classes: Optional[Tuple[str]], diff --git a/projects/HieraSeg/README.md b/projects/HieraSeg/README.md new file mode 100644 index 0000000000..5519ec6916 --- /dev/null +++ b/projects/HieraSeg/README.md @@ -0,0 +1,93 @@ +# HieraSeg + +Support the `Deep Hierarchical Semantic Segmentation` inference interface on the `cityscapes` dataset + +## Description + +Author: AI-Tianlong + +This project implements `HieraSeg` inference on the `cityscapes` dataset + +## Usage + +### Prerequisites + +- Python 3.8 +- PyTorch 1.6 or higher +- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc3 +- mmcv v2.0.0rc3 +- mmengine + +### Dataset preparation + +Prepare the `cityscapes` dataset following this [structure](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#prepare-datasets) + +### Testing commands + +Please download [`hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth`](https://download.openmmlab.com/mmsegmentation/v0.5/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth) and put it under `mmsegmentation/checkpoints` + +#### Multi-GPU Test + +```bash +# --tta is optional; it enables multi-scale testing and requires mmengine >= 0.4.0 +bash tools/dist_test.sh [config] [model weights] [number of gpus] --tta +``` + +#### Example + +```shell +bash tools/dist_test.sh projects/HieraSeg/configs/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80l_cityscapes-512x1024.py checkpoints/hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth 2 --tta +``` + +## Results + +### Cityscapes + +| Method | Backbone | Crop Size | mIoU | mIoU (ms+flip) | config | model | | :--------: | :------: | :-------: | :---: | :------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
:-----------------------------------------------------------------------------------------------------------------------------------------------------------: | +| DeeplabV3+ | R-101-D8 | 512x1024 | 81.61 | 82.71 | [config](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/HieraSeg/configs/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80l_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth) | + + + +## Citation + +This project is modified from [qhanghu/HSSN_pytorch](https://github.com/qhanghu/HSSN_pytorch) + +```bibtex +@article{li2022deep, + title={Deep Hierarchical Semantic Segmentation}, + author={Li, Liulei and Zhou, Tianfei and Wang, Wenguan and Li, Jianwu and Yang, Yi}, + journal={CVPR}, + year={2022} +} +``` + +## Checklist + +- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`. + + - [x] Finish the code + + - [x] Basic docstrings & proper citation + + - [x] Test-time correctness + + - [x] A full README + +- [ ] Milestone 2: Indicates a successful model implementation. + + - [ ] Training-time correctness + +- [ ] Milestone 3: Good to be a part of our core package! + + - [ ] Type hints and docstrings + + - [ ] Unit tests + + - [ ] Code polishing + + - [ ] Metafile.yml + +- [ ] Move your modules into the core package following the codebase's file hierarchy structure. + +- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure. diff --git a/projects/HieraSeg/configs/_base_/datasets/cityscapes.py b/projects/HieraSeg/configs/_base_/datasets/cityscapes.py new file mode 100644 index 0000000000..1698e04721 --- /dev/null +++ b/projects/HieraSeg/configs/_base_/datasets/cityscapes.py @@ -0,0 +1,67 @@ +# dataset settings +dataset_type = 'CityscapesDataset' +data_root = 'data/cityscapes/' +crop_size = (512, 1024) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict( + type='RandomResize', + scale=(2048, 1024), + ratio_range=(0.5, 2.0), + keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='PackSegInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=(2048, 1024), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') +] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type='InfiniteSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='leftImg8bit/train', seg_map_path='gtFine/train'), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + 
data_root=data_root, + data_prefix=dict( + img_path='leftImg8bit/val', seg_map_path='gtFine/val'), + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator diff --git a/projects/HieraSeg/configs/_base_/default_runtime.py b/projects/HieraSeg/configs/_base_/default_runtime.py new file mode 100644 index 0000000000..272b4d2467 --- /dev/null +++ b/projects/HieraSeg/configs/_base_/default_runtime.py @@ -0,0 +1,15 @@ +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer') +log_processor = dict(by_epoch=False) +log_level = 'INFO' +load_from = None +resume = False + +tta_model = dict(type='SegTTAModel') diff --git a/projects/HieraSeg/configs/_base_/models/deeplabv3plus_r50-d8_vd_contrast.py b/projects/HieraSeg/configs/_base_/models/deeplabv3plus_r50-d8_vd_contrast.py new file mode 100644 index 0000000000..a6af45ce84 --- /dev/null +++ b/projects/HieraSeg/configs/_base_/models/deeplabv3plus_r50-d8_vd_contrast.py @@ -0,0 +1,55 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +data_preprocessor = dict( + type='SegDataPreProcessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_val=0, + seg_pad_val=255) +model = dict( + type='EncoderDecoder', + data_preprocessor=data_preprocessor, + pretrained=None, + backbone=dict( + type='ResNetV1d', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='DepthwiseSeparableASPPContrastHead', + in_channels=2048, + in_index=3, + channels=512, + dilations=(1, 12, 24, 36), + c1_in_channels=256, + c1_channels=48, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + proj='convmlp', + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/projects/HieraSeg/configs/_base_/schedules/schedule_80k.py b/projects/HieraSeg/configs/_base_/schedules/schedule_80k.py new file mode 100644 index 0000000000..0dcd6c4d1b --- /dev/null +++ b/projects/HieraSeg/configs/_base_/schedules/schedule_80k.py @@ -0,0 +1,24 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None) +# learning policy +param_scheduler = [ + dict( + type='PolyLR', + eta_min=1e-4, + power=0.9, + begin=0, + end=80000, + by_epoch=False) +] +# training schedule for 80k +train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=8000) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False), + param_scheduler=dict(type='ParamSchedulerHook'), + 
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=8000), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook')) diff --git a/projects/HieraSeg/configs/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80l_cityscapes-512x1024.py b/projects/HieraSeg/configs/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80l_cityscapes-512x1024.py new file mode 100644 index 0000000000..0d02bef5dc --- /dev/null +++ b/projects/HieraSeg/configs/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80l_cityscapes-512x1024.py @@ -0,0 +1,21 @@ +_base_ = [ + '../_base_/models/deeplabv3plus_r50-d8_vd_contrast.py', + '../_base_/datasets/cityscapes.py', '../_base_/default_runtime.py', + '../_base_/schedules/schedule_80k.py' +] + +custom_imports = dict(imports=[ + 'projects.HieraSeg.decode_head.sep_aspp_contrast_head', + 'projects.HieraSeg.losses.hiera_triplet_loss_cityscape' +]) + +model = dict( + pretrained=None, + backbone=dict(depth=101), + decode_head=dict( + num_classes=26, + loss_decode=dict( + type='HieraTripletLossCityscape', num_classes=19, + loss_weight=1.0)), + auxiliary_head=dict(num_classes=19), + test_cfg=dict(mode='whole', is_hiera=True, hiera_num_classes=7)) diff --git a/projects/HieraSeg/decode_head/__init__.py b/projects/HieraSeg/decode_head/__init__.py new file mode 100644 index 0000000000..da454ea339 --- /dev/null +++ b/projects/HieraSeg/decode_head/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .sep_aspp_contrast_head import DepthwiseSeparableASPPContrastHead + +__all__ = ['DepthwiseSeparableASPPContrastHead'] diff --git a/projects/HieraSeg/decode_head/sep_aspp_contrast_head.py b/projects/HieraSeg/decode_head/sep_aspp_contrast_head.py new file mode 100644 index 0000000000..75f67e7457 --- /dev/null +++ b/projects/HieraSeg/decode_head/sep_aspp_contrast_head.py @@ -0,0 +1,178 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from torch import Tensor + +from mmseg.models.decode_heads.sep_aspp_head import DepthwiseSeparableASPPHead +from mmseg.models.losses import accuracy +from mmseg.models.utils import resize +from mmseg.registry import MODELS + + +class ProjectionHead(nn.Module): + """ProjectionHead, project feature map to specific channels. + + Args: + dim_in (int): Input channels. + norm_cfg (dict): config of norm layer. + proj_dim (int): Output channels. Default: 256. + proj (str): Projection type, 'linear' or 'convmlp'. Default: 'convmlp' + """ + + def __init__(self, + dim_in: int, + norm_cfg: dict, + proj_dim: int = 256, + proj: str = 'convmlp'): + super().__init__() + assert proj in ['convmlp', 'linear'] + if proj == 'linear': + self.proj = nn.Conv2d(dim_in, proj_dim, kernel_size=1) + elif proj == 'convmlp': + self.proj = nn.Sequential( + nn.Conv2d(dim_in, dim_in, kernel_size=1), + build_norm_layer(norm_cfg, dim_in)[1], nn.ReLU(inplace=True), + nn.Conv2d(dim_in, proj_dim, kernel_size=1)) + + def forward(self, x): + return torch.nn.functional.normalize(self.proj(x), p=2, dim=1) + + +@MODELS.register_module() +class DepthwiseSeparableASPPContrastHead(DepthwiseSeparableASPPHead): + """Deep Hierarchical Semantic Segmentation. This head is the implementation + of ``_. + + Based on Encoder-Decoder with Atrous Separable Convolution for + Semantic Image Segmentation. + `DeepLabV3+ `_. 
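+ Compared with ``DepthwiseSeparableASPPHead``, ``forward`` additionally returns an L2-normalized projection embedding used by the hierarchy-aware loss, and ``predict_by_feat`` folds the parent-level logits back into the leaf classes at test time.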
+ + Args: + proj (str): The type of ProjectionHead, 'linear' or 'convmlp', + default 'convmlp' + """ + + def __init__(self, proj: str = 'convmlp', **kwargs): + super().__init__(**kwargs) + self.proj_head = ProjectionHead( + dim_in=2048, norm_cfg=self.norm_cfg, proj=proj) + self.register_buffer('step', torch.zeros(1)) + + def forward(self, inputs): + """Forward function.""" + self.step += 1 + embedding = self.proj_head(inputs[-1]) + x = self._transform_inputs(inputs) + aspp_outs = [ + resize( + self.image_pool(x), + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + ] + aspp_outs.extend(self.aspp_modules(x)) + aspp_outs = torch.cat(aspp_outs, dim=1) + output = self.bottleneck(aspp_outs) + if self.c1_bottleneck is not None: + c1_output = self.c1_bottleneck(inputs[0]) + output = resize( + input=output, + size=c1_output.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + output = torch.cat([output, c1_output], dim=1) + output = self.sep_bottleneck(output) + output = self.cls_seg(output) + return output, embedding + + def predict_by_feat(self, seg_logits: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Transform a batch of output seg_logits to the input shape. + + Args: + seg_logits (Tensor): The output from decode head forward function. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + Tensor: Outputs segmentation logits map. + """ + # HieraSeg decode_head output is: (out, embedding) :tuple, + # only need 'out' here. + if isinstance(seg_logits, tuple): + seg_logit = seg_logits[0] + + if seg_logit.size(1) == 26: + seg_logit[:, 0:2] += seg_logit[:, -7] + seg_logit[:, 2:5] += seg_logit[:, -6] + seg_logit[:, 5:8] += seg_logit[:, -5] + seg_logit[:, 8:10] += seg_logit[:, -4] + seg_logit[:, 10:11] += seg_logit[:, -3] + seg_logit[:, 11:13] += seg_logit[:, -2] + seg_logit[:, 13:19] += seg_logit[:, -1] + elif seg_logit.size(1) == 12: + seg_logit[:, 0:1] = seg_logit[:, 0:1] + \ + seg_logit[:, 7] + seg_logit[:, 10] + seg_logit[:, 1:5] = seg_logit[:, 1:5] + \ + seg_logit[:, 8] + seg_logit[:, 11] + seg_logit[:, 5:7] = seg_logit[:, 5:7] + \ + seg_logit[:, 9] + seg_logit[:, 11] + elif seg_logit.size(1) == 25: + seg_logit[:, 0:1] = seg_logit[:, 0:1] + \ + seg_logit[:, 20] + seg_logit[:, 23] + seg_logit[:, 1:8] = seg_logit[:, 1:8] + \ + seg_logit[:, 21] + seg_logit[:, 24] + seg_logit[:, 10:12] = seg_logit[:, 10:12] + \ + seg_logit[:, 21] + seg_logit[:, 24] + seg_logit[:, 13:16] = seg_logit[:, 13:16] + \ + seg_logit[:, 21] + seg_logit[:, 24] + seg_logit[:, 8:10] = seg_logit[:, 8:10] + \ + seg_logit[:, 22] + seg_logit[:, 24] + seg_logit[:, 12:13] = seg_logit[:, 12:13] + \ + seg_logit[:, 22] + seg_logit[:, 24] + seg_logit[:, 16:20] = seg_logit[:, 16:20] + \ + seg_logit[:, 22] + seg_logit[:, 24] + + # seg_logit = seg_logit[:,:-self.test_cfg['hiera_num_classes']] + seg_logit = seg_logit[:, :-7] + seg_logit = resize( + input=seg_logit, + size=batch_img_metas[0]['img_shape'], + mode='bilinear', + align_corners=self.align_corners) + + return seg_logit + + def losses(self, results, seg_label): + """Compute segmentation loss.""" + seg_logit_before = results[0] + embedding = results[1] + loss = dict() + seg_logit = resize( + input=seg_logit_before, + size=seg_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + if self.sampler is not None: + seg_weight = self.sampler.sample(seg_logit, seg_label) + else: + seg_weight = None + seg_label = seg_label.squeeze(1) + seg_logit_before = resize( 
+        # downscale the raw logits before handing them to the decode loss
+        seg_logit_before = resize(
+            input=seg_logit_before,
+            scale_factor=0.5,
+            mode='bilinear',
+            align_corners=self.align_corners)
+        loss['loss_seg'] = self.loss_decode(
+            self.step,
+            embedding,
+            seg_logit_before,
+            seg_logit,
+            seg_label,
+            weight=seg_weight,
+            ignore_index=self.ignore_index)
+        loss['acc_seg'] = accuracy(seg_logit, seg_label)
+        return loss
diff --git a/projects/HieraSeg/losses/__init__.py b/projects/HieraSeg/losses/__init__.py
new file mode 100644
index 0000000000..47d2686482
--- /dev/null
+++ b/projects/HieraSeg/losses/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .hiera_triplet_loss_cityscape import HieraTripletLossCityscape
+
+__all__ = ['HieraTripletLossCityscape']
diff --git a/projects/HieraSeg/losses/hiera_triplet_loss_cityscape.py b/projects/HieraSeg/losses/hiera_triplet_loss_cityscape.py
new file mode 100644
index 0000000000..a784f13e62
--- /dev/null
+++ b/projects/HieraSeg/losses/hiera_triplet_loss_cityscape.py
@@ -0,0 +1,218 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmseg.models.builder import LOSSES
+from mmseg.models.losses.cross_entropy_loss import CrossEntropyLoss
+from .tree_triplet_loss import TreeTripletLoss
+
+hiera_map = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 5, 5, 6, 6, 6, 6, 6, 6]
+hiera_index = [[0, 2], [2, 5], [5, 8], [8, 10], [10, 11], [11, 13], [13, 19]]
+
+hiera = {
+    'hiera_high': {
+        'flat': [0, 2],
+        'construction': [2, 5],
+        'object': [5, 8],
+        'nature': [8, 10],
+        'sky': [10, 11],
+        'human': [11, 13],
+        'vehicle': [13, 19]
+    }
+}
+
+
+def prepare_targets(targets):
+    b, h, w = targets.shape
+    targets_high = torch.ones(
+        (b, h, w), dtype=targets.dtype, device=targets.device) * 255
+    indices_high = []
+    for index, high in enumerate(hiera['hiera_high'].keys()):
+        indices = hiera['hiera_high'][high]
+        for ii in range(indices[0], indices[1]):
+            targets_high[targets == ii] = index
+        indices_high.append(indices)
+
+    return targets, targets_high, indices_high
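As a quick illustration of `prepare_targets`, here is a minimal check (assuming the module above is importable) that a fine label inside the 'vehicle' range is mapped to coarse index 6:

```python
import torch

# fine label 13 ('car' in the Cityscapes train IDs) lies in [13, 19),
# the 'vehicle' range, whose coarse index is 6
labels = torch.full((1, 4, 4), 13, dtype=torch.long)
_, labels_high, indices_high = prepare_targets(labels)
assert (labels_high == 6).all()
assert indices_high == [[0, 2], [2, 5], [5, 8], [8, 10], [10, 11],
                        [11, 13], [13, 19]]
```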
+ """ + b, _, h, w = predictions.shape + predictions = torch.sigmoid(predictions.float()) + void_indices = (targets == 255) + targets[void_indices] = 0 + targets = F.one_hot(targets, num_classes=num_classes).permute(0, 3, 1, 2) + void_indices2 = (targets_top == 255) + targets_top[void_indices2] = 0 + targets_top = F.one_hot(targets_top, num_classes=7).permute(0, 3, 1, 2) + + MCMA = predictions[:, :num_classes, :, :] + MCMB = torch.zeros((b, 7, h, w)).to(predictions) + for ii in range(7): + MCMB[:, ii:ii + 1, :, :] = torch.max( + torch.cat([ + predictions[:, indices_high[ii][0]:indices_high[ii][1], :, :], + predictions[:, num_classes + ii:num_classes + ii + 1, :, :] + ], + dim=1), 1, True)[0] + + MCLB = predictions[:, num_classes:num_classes + 7, :, :] + MCLA = predictions[:, :num_classes, :, :].clone() + for ii in range(7): + for jj in range(indices_high[ii][0], indices_high[ii][1]): + MCLA[:, jj:jj + 1, :, :] = torch.min( + torch.cat([ + predictions[:, jj:jj + 1, :, :], MCLB[:, ii:ii + 1, :, :] + ], + dim=1), 1, True)[0] + + valid_indices = (~void_indices).unsqueeze(1) + num_valid = valid_indices.sum() + valid_indices2 = (~void_indices2).unsqueeze(1) + num_valid2 = valid_indices2.sum() + # channel_num*sum()/one_channel_valid already has a weight + loss = ( + (-targets[:, :num_classes, :, :] * torch.log(MCLA + eps) - + (1.0 - targets[:, :num_classes, :, :]) * torch.log(1.0 - MCMA + eps)) + * valid_indices).sum() / num_valid / num_classes + loss += ((-targets_top[:, :, :, :] * torch.log(MCLB + eps) - + (1.0 - targets_top[:, :, :, :]) * torch.log(1.0 - MCMB + eps)) * + valid_indices2).sum() / num_valid2 / 7 + + return 5 * loss + + +def losses_hiera_focal(predictions, + targets, + targets_top, + num_classes, + indices_high, + eps=1e-8, + gamma=2): + """Implementation of hiera loss. + + Args: + predictions (torch.Tensor): seg logits produced by decode head. + targets (torch.Tensor): The learning label of the prediction. + targets_top (torch.Tensor): The hierarchy ground truth of the learning + label. + num_classes (int): Number of categories. + indices_high (List[List[int]]): Hierarchy indices of each hierarchy. + eps (float):Term added to the Logarithm to improve numerical stability. + Defaults: 1e-8. + gamma (int): The exponent value. Defaults: 2. 
+ """ + b, _, h, w = predictions.shape + predictions = torch.sigmoid(predictions.float()) + void_indices = (targets == 255) + targets[void_indices] = 0 + targets = F.one_hot(targets, num_classes=num_classes).permute(0, 3, 1, 2) + void_indices2 = (targets_top == 255) + targets_top[void_indices2] = 0 + targets_top = F.one_hot(targets_top, num_classes=7).permute(0, 3, 1, 2) + + MCMA = predictions[:, :num_classes, :, :] + MCMB = torch.zeros((b, 7, h, w), + dtype=predictions.dtype, + device=predictions.device) + for ii in range(7): + MCMB[:, ii:ii + 1, :, :] = torch.max( + torch.cat([ + predictions[:, indices_high[ii][0]:indices_high[ii][1], :, :], + predictions[:, num_classes + ii:num_classes + ii + 1, :, :] + ], + dim=1), 1, True)[0] + + MCLB = predictions[:, num_classes:num_classes + 7, :, :] + MCLA = predictions[:, :num_classes, :, :].clone() + for ii in range(7): + for jj in range(indices_high[ii][0], indices_high[ii][1]): + MCLA[:, jj:jj + 1, :, :] = torch.min( + torch.cat([ + predictions[:, jj:jj + 1, :, :], MCLB[:, ii:ii + 1, :, :] + ], + dim=1), 1, True)[0] + + valid_indices = (~void_indices).unsqueeze(1) + num_valid = valid_indices.sum() + valid_indices2 = (~void_indices2).unsqueeze(1) + num_valid2 = valid_indices2.sum() + # channel_num*sum()/one_channel_valid already has a weight + loss = ((-targets[:, :num_classes, :, :] * torch.pow( + (1.0 - MCLA), gamma) * torch.log(MCLA + eps) - + (1.0 - targets[:, :num_classes, :, :]) * torch.pow(MCMA, gamma) * + torch.log(1.0 - MCMA + eps)) * + valid_indices).sum() / num_valid / num_classes + loss += ( + (-targets_top[:, :, :, :] * torch.pow( + (1.0 - MCLB), gamma) * torch.log(MCLB + eps) - + (1.0 - targets_top[:, :, :, :]) * torch.pow(MCMB, gamma) * + torch.log(1.0 - MCMB + eps)) * valid_indices2).sum() / num_valid2 / 7 + + return 5 * loss + + +@LOSSES.register_module() +class HieraTripletLossCityscape(nn.Module): + """Modified from https://github.com/qhanghu/HSSN_pytorch/blob/main/mmseg/mo + dels/losses/hiera_triplet_loss_cityscape.py.""" + + def __init__(self, num_classes, use_sigmoid=False, loss_weight=1.0): + super().__init__() + self.num_classes = num_classes + self.loss_weight = loss_weight + self.treetripletloss = TreeTripletLoss(num_classes, hiera_map, + hiera_index) + self.ce = CrossEntropyLoss() + + def forward(self, + step, + embedding, + cls_score_before, + cls_score, + label, + weight=None, + **kwargs): + targets, targets_top, indices_top = prepare_targets(label) + + loss = losses_hiera(cls_score, targets, targets_top, self.num_classes, + indices_top) + ce_loss = self.ce(cls_score[:, :-7], label) + ce_loss2 = self.ce(cls_score[:, -7:], targets_top) + loss = loss + ce_loss + ce_loss2 + + loss_triplet, class_count = self.treetripletloss(embedding, label) + class_counts = [ + torch.ones_like(class_count) + for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(class_counts, class_count, async_op=False) + class_counts = torch.cat(class_counts, dim=0) + + if torch.distributed.get_world_size() == torch.nonzero( + class_counts, as_tuple=False).size(0): + factor = 1 / 4 * (1 + torch.cos( + torch.tensor((step.item() - 80000) / 80000 * + math.pi))) if step.item() < 80000 else 0.5 + loss += factor * loss_triplet + + return loss * self.loss_weight diff --git a/projects/HieraSeg/losses/tree_triplet_loss.py b/projects/HieraSeg/losses/tree_triplet_loss.py new file mode 100644 index 0000000000..ccc0937405 --- /dev/null +++ b/projects/HieraSeg/losses/tree_triplet_loss.py @@ -0,0 +1,86 @@ +# Copyright (c) OpenMMLab. 
diff --git a/projects/HieraSeg/losses/tree_triplet_loss.py b/projects/HieraSeg/losses/tree_triplet_loss.py
new file mode 100644
index 0000000000..ccc0937405
--- /dev/null
+++ b/projects/HieraSeg/losses/tree_triplet_loss.py
@@ -0,0 +1,86 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmseg.models.builder import LOSSES
+
+
+@LOSSES.register_module()
+class TreeTripletLoss(nn.Module):
+    """TreeTripletLoss. Modified from
+    https://github.com/qhanghu/HSSN_pytorch/blob/main/mmseg/models/losses/tree_triplet_loss.py
+
+    Args:
+        num_classes (int): Number of categories.
+        hiera_map (List[int]): Hierarchy map of each category.
+        hiera_index (List[List[int]]): Hierarchy indices of each hierarchy.
+        ignore_index (int): Specifies a target value that is ignored and
+            does not contribute to the input gradients. Defaults: 255.
+
+    Examples:
+        >>> num_classes = 19
+        >>> hiera_map = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 5, 5, 6, 6, 6, 6, 6, 6]
+        >>> hiera_index = [[0, 2], [2, 5], [5, 8], [8, 10], [10, 11],
+        ...                [11, 13], [13, 19]]
+    """
+
+    def __init__(self, num_classes, hiera_map, hiera_index, ignore_index=255):
+        super().__init__()
+
+        self.ignore_label = ignore_index
+        self.num_classes = num_classes
+        self.hiera_map = hiera_map
+        self.hiera_index = hiera_index
+
+    def forward(self, feats: torch.Tensor, labels=None, max_triplet=200):
+        labels = labels.unsqueeze(1).float().clone()
+        labels = torch.nn.functional.interpolate(
+            labels, (feats.shape[2], feats.shape[3]), mode='nearest')
+        labels = labels.squeeze(1).long()
+        assert labels.shape[-1] == feats.shape[-1], '{} {}'.format(
+            labels.shape, feats.shape)
+
+        labels = labels.view(-1)
+        feats = feats.permute(0, 2, 3, 1)
+        feats = feats.contiguous().view(-1, feats.shape[-1])
+
+        triplet_loss = 0
+        exist_classes = torch.unique(labels)
+        exist_classes = [x for x in exist_classes if x != 255]
+        class_count = 0
+
+        for ii in exist_classes:
+            index_range = self.hiera_index[self.hiera_map[ii]]
+            index_anchor = labels == ii
+            index_pos = (labels >= index_range[0]) & (
+                labels < index_range[-1]) & (~index_anchor)
+            index_neg = (labels < index_range[0]) | (labels >= index_range[-1])
+
+            min_size = min(
+                torch.sum(index_anchor), torch.sum(index_pos),
+                torch.sum(index_neg), max_triplet)
+
+            feats_anchor = feats[index_anchor][:min_size]
+            feats_pos = feats[index_pos][:min_size]
+            feats_neg = feats[index_neg][:min_size]
+
+            distance = torch.zeros(min_size, 2).to(feats)
+            distance[:, 0:1] = 1 - (feats_anchor * feats_pos).sum(1, True)
+            distance[:, 1:2] = 1 - (feats_anchor * feats_neg).sum(1, True)
+
+            # the margin is always 0.1 + (4 - 2) / 4 = 0.6 since the
+            # hierarchy has three levels
+            # TODO: handle the case where the label of pos is the same as
+            # the anchor
+            margin = 0.6 * torch.ones(min_size).to(feats)
+
+            tl = distance[:, 0] - distance[:, 1] + margin
+            tl = F.relu(tl)
+
+            if tl.size(0) > 0:
+                triplet_loss += tl.mean()
+                class_count += 1
+        if class_count == 0:
+            return None, torch.tensor([0]).to(feats)
+        triplet_loss /= class_count
+        return triplet_loss, torch.tensor([class_count]).to(feats)
diff --git a/projects/isnet/README.md b/projects/isnet/README.md
new file mode 100644
index 0000000000..3a3172a9d9
--- /dev/null
+++ b/projects/isnet/README.md
@@ -0,0 +1,117 @@
+# ISNet
+
+[ISNet: Integrate Image-Level and Semantic-Level Context for Semantic Segmentation](https://arxiv.org/pdf/2108.12382.pdf)
+
+## Description
+
+This is an implementation of [ISNet](https://arxiv.org/pdf/2108.12382.pdf).
+[Official Repo](https://github.com/SegmentationBLWX/sssegmentation)
+
+## Usage
+
+### Prerequisites
+
+- Python 3.7
+- PyTorch 1.6 or higher
+- [MIM](https://github.com/open-mmlab/mim) v0.33 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc2 or higher
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In the `isnet/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Training commands
+
+```shell
+mim train mmsegmentation configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py --work-dir work_dirs/isnet
+```
+
+To train on multiple GPUs, e.g., 8 GPUs, run the following command:
+
+```shell
+mim train mmsegmentation configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py --work-dir work_dirs/isnet --launcher pytorch --gpus 8
+```
+
+### Testing commands
+
+```shell
+mim test mmsegmentation configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py --work-dir work_dirs/isnet --checkpoint ${CHECKPOINT_PATH}
+```
+
+| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) |  mIoU | mIoU(ms+flip) | config                                                          | download                                                                                                                 |
+| ------ | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | --------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------ |
+| ISNet  | R-50-D8  | 512x1024  |       - | -        | -              | 79.32 |         80.88 | [config](configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/isnet/isnet_r50-d8_cityscapes-512x1024_20230104-a7a8ccf2.pth) |
+
+## Citation
+
+```bibtex
+@article{Jin2021ISNetII,
+  title={ISNet: Integrate Image-Level and Semantic-Level Context for Semantic Segmentation},
+  author={Zhenchao Jin and B. Liu and Qi Chu and Nenghai Yu},
+  journal={2021 IEEE/CVF International Conference on Computer Vision (ICCV)},
+  year={2021},
+  pages={7169-7178}
+}
+```
+
+## Checklist
+
+The progress of ISNet.
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+  - [x] Finish the code
+
+  - [x] Basic docstrings & proper citation
+
+  - [x] Test-time correctness
+
+  - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+  - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+  - [ ] Type hints and docstrings
+
+  - [ ] Unit tests
+
+  - [ ] Code polishing
+
+  - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
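For quick local inspection outside of MIM, inference can also be run through the Python API. This is a minimal sketch assuming the MMSegmentation 1.x API; the checkpoint and image paths are placeholders:

```python
from mmseg.apis import inference_model, init_model
from mmseg.utils import register_all_modules

register_all_modules()  # register mmseg modules with the default scope

config = 'configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py'
checkpoint = 'work_dirs/isnet/iter_160000.pth'  # placeholder path
model = init_model(config, checkpoint, device='cuda:0')
result = inference_model(model, 'demo.png')  # placeholder image
```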
diff --git a/projects/isnet/configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py b/projects/isnet/configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py new file mode 100644 index 0000000000..a00d39237d --- /dev/null +++ b/projects/isnet/configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py @@ -0,0 +1,80 @@ +_base_ = [ + '../../../configs/_base_/datasets/cityscapes.py', + '../../../configs/_base_/default_runtime.py', + '../../../configs/_base_/schedules/schedule_80k.py' +] + +data_root = '../../data/cityscapes/' +train_dataloader = dict(dataset=dict(data_root=data_root)) +val_dataloader = dict(dataset=dict(data_root=data_root)) +test_dataloader = dict(dataset=dict(data_root=data_root)) + +custom_imports = dict(imports=['projects.isnet.decode_heads']) + +norm_cfg = dict(type='SyncBN', requires_grad=True) +data_preprocessor = dict( + type='SegDataPreProcessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_val=0, + seg_pad_val=255) + +model = dict( + type='EncoderDecoder', + data_preprocessor=data_preprocessor, + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='ISNetHead', + in_channels=(256, 512, 1024, 2048), + input_transform='multiple_select', + in_index=(0, 1, 2, 3), + channels=512, + dropout_ratio=0.1, + transform_channels=256, + concat_input=True, + with_shortcut=False, + shortcut_in_channels=256, + shortcut_feat_channels=48, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=[ + dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0, + loss_name='loss_o'), + dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=0.4, + loss_name='loss_d'), + ]), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=512, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + train_cfg=dict(), + # test_cfg=dict(mode='slide', crop_size=(769, 769), stride=(513, 513)) + test_cfg=dict(mode='whole')) diff --git a/projects/isnet/decode_heads/__init__.py b/projects/isnet/decode_heads/__init__.py new file mode 100644 index 0000000000..a451629c4c --- /dev/null +++ b/projects/isnet/decode_heads/__init__.py @@ -0,0 +1,3 @@ +from .isnet_head import ISNetHead + +__all__ = ['ISNetHead'] diff --git a/projects/isnet/decode_heads/isnet_head.py b/projects/isnet/decode_heads/isnet_head.py new file mode 100644 index 0000000000..9c8df540ee --- /dev/null +++ b/projects/isnet/decode_heads/isnet_head.py @@ -0,0 +1,337 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from torch import Tensor + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.models.losses import accuracy +from mmseg.models.utils import SelfAttentionBlock, resize +from mmseg.registry import MODELS +from mmseg.utils import SampleList + + +class ImageLevelContext(nn.Module): + """ Image-Level Context Module + Args: + feats_channels (int): Input channels of query/key feature. + transform_channels (int): Output channels of key/query transform. + concat_input (bool): whether to concat input feature. 
+ align_corners (bool): align_corners argument of F.interpolate. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, + feats_channels, + transform_channels, + concat_input=False, + align_corners=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None): + super().__init__() + self.align_corners = align_corners + self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.correlate_net = SelfAttentionBlock( + key_in_channels=feats_channels * 2, + query_in_channels=feats_channels, + channels=transform_channels, + out_channels=feats_channels, + share_key_query=False, + query_downsample=None, + key_downsample=None, + key_query_num_convs=2, + value_out_num_convs=1, + key_query_norm=True, + value_out_norm=True, + matmul_norm=True, + with_out=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + if concat_input: + self.bottleneck = ConvModule( + feats_channels * 2, + feats_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + + '''forward''' + + def forward(self, x): + x_global = self.global_avgpool(x) + x_global = resize( + x_global, + size=x.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + feats_il = self.correlate_net(x, torch.cat([x_global, x], dim=1)) + if hasattr(self, 'bottleneck'): + feats_il = self.bottleneck(torch.cat([x, feats_il], dim=1)) + return feats_il + + +class SemanticLevelContext(nn.Module): + """ Semantic-Level Context Module + Args: + feats_channels (int): Input channels of query/key feature. + transform_channels (int): Output channels of key/query transform. + concat_input (bool): whether to concat input feature. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. 
+ """ + + def __init__(self, + feats_channels, + transform_channels, + concat_input=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None): + super().__init__() + self.correlate_net = SelfAttentionBlock( + key_in_channels=feats_channels, + query_in_channels=feats_channels, + channels=transform_channels, + out_channels=feats_channels, + share_key_query=False, + query_downsample=None, + key_downsample=None, + key_query_num_convs=2, + value_out_num_convs=1, + key_query_norm=True, + value_out_norm=True, + matmul_norm=True, + with_out=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + if concat_input: + self.bottleneck = ConvModule( + feats_channels * 2, + feats_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + + '''forward''' + + def forward(self, x, preds, feats_il): + inputs = x + batch_size, num_channels, h, w = x.size() + num_classes = preds.size(1) + feats_sl = torch.zeros(batch_size, h * w, num_channels).type_as(x) + for batch_idx in range(batch_size): + # (C, H, W), (num_classes, H, W) --> (H*W, C), (H*W, num_classes) + feats_iter, preds_iter = x[batch_idx], preds[batch_idx] + feats_iter, preds_iter = feats_iter.reshape( + num_channels, -1), preds_iter.reshape(num_classes, -1) + feats_iter, preds_iter = feats_iter.permute(1, + 0), preds_iter.permute( + 1, 0) + # (H*W, ) + argmax = preds_iter.argmax(1) + for clsid in range(num_classes): + mask = (argmax == clsid) + if mask.sum() == 0: + continue + feats_iter_cls = feats_iter[mask] + preds_iter_cls = preds_iter[:, clsid][mask] + weight = torch.softmax(preds_iter_cls, dim=0) + feats_iter_cls = feats_iter_cls * weight.unsqueeze(-1) + feats_iter_cls = feats_iter_cls.sum(0) + feats_sl[batch_idx][mask] = feats_iter_cls + feats_sl = feats_sl.reshape(batch_size, h, w, num_channels) + feats_sl = feats_sl.permute(0, 3, 1, 2).contiguous() + feats_sl = self.correlate_net(inputs, feats_sl) + if hasattr(self, 'bottleneck'): + feats_sl = self.bottleneck(torch.cat([feats_il, feats_sl], dim=1)) + return feats_sl + + +@MODELS.register_module() +class ISNetHead(BaseDecodeHead): + """ISNet: Integrate Image-Level and Semantic-Level + Context for Semantic Segmentation + + This head is the implementation of `ISNet` + `_. + + Args: + transform_channels (int): Output channels of key/query transform. + concat_input (bool): whether to concat input feature. + with_shortcut (bool): whether to use shortcut connection. + shortcut_in_channels (int): Input channels of shortcut. + shortcut_feat_channels (int): Output channels of shortcut. + dropout_ratio (float): Ratio of dropout. 
+ """ + + def __init__(self, transform_channels, concat_input, with_shortcut, + shortcut_in_channels, shortcut_feat_channels, dropout_ratio, + **kwargs): + super().__init__(**kwargs) + + self.in_channels = self.in_channels[-1] + + self.bottleneck = ConvModule( + self.in_channels, + self.channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.ilc_net = ImageLevelContext( + feats_channels=self.channels, + transform_channels=transform_channels, + concat_input=concat_input, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.slc_net = SemanticLevelContext( + feats_channels=self.channels, + transform_channels=transform_channels, + concat_input=concat_input, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.decoder_stage1 = nn.Sequential( + ConvModule( + self.channels, + self.channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.Dropout2d(dropout_ratio), + nn.Conv2d( + self.channels, + self.num_classes, + kernel_size=1, + stride=1, + padding=0, + bias=True), + ) + + if with_shortcut: + self.shortcut = ConvModule( + shortcut_in_channels, + shortcut_feat_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.decoder_stage2 = nn.Sequential( + ConvModule( + self.channels + shortcut_feat_channels, + self.channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.Dropout2d(dropout_ratio), + nn.Conv2d( + self.channels, + self.num_classes, + kernel_size=1, + stride=1, + padding=0, + bias=True), + ) + else: + self.decoder_stage2 = nn.Sequential( + nn.Dropout2d(dropout_ratio), + nn.Conv2d( + self.channels, + self.num_classes, + kernel_size=1, + stride=1, + padding=0, + bias=True), + ) + + self.conv_seg = None + self.dropout = None + + def forward(self, inputs): + x = self._transform_inputs(inputs) + feats = self.bottleneck(x[-1]) + + feats_il = self.ilc_net(feats) + + preds_stage1 = self.decoder_stage1(feats) + preds_stage1 = resize( + preds_stage1, + size=feats.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + + feats_sl = self.slc_net(feats, preds_stage1, feats_il) + + if hasattr(self, 'shortcut'): + shortcut_out = self.shortcut(x[0]) + feats_sl = resize( + feats_sl, + size=shortcut_out.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + feats_sl = torch.cat([feats_sl, shortcut_out], dim=1) + preds_stage2 = self.decoder_stage2(feats_sl) + + return preds_stage1, preds_stage2 + + def loss_by_feat(self, seg_logits: Tensor, + batch_data_samples: SampleList) -> dict: + seg_label = self._stack_batch_gt(batch_data_samples) + loss = dict() + + if self.sampler is not None: + seg_weight = self.sampler.sample(seg_logits[-1], seg_label) + else: + seg_weight = None + seg_label = seg_label.squeeze(1) + + for seg_logit, loss_decode in zip(seg_logits, self.loss_decode): + seg_logit = resize( + input=seg_logit, + size=seg_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + loss[loss_decode.name] = loss_decode( + seg_logit, + seg_label, + seg_weight, + ignore_index=self.ignore_index) + + loss['acc_seg'] = accuracy( + seg_logits[-1], seg_label, ignore_index=self.ignore_index) + return loss + + def predict_by_feat(self, seg_logits: Tensor, + batch_img_metas: List[dict]) -> Tensor: + _, seg_logits_stage2 = 
diff --git a/projects/mapillary_dataset/README.md b/projects/mapillary_dataset/README.md
new file mode 100644
index 0000000000..2b3099522e
--- /dev/null
+++ b/projects/mapillary_dataset/README.md
@@ -0,0 +1,85 @@
+# Mapillary Vistas Dataset
+
+Support **`Mapillary Vistas Dataset`**
+
+## Description
+
+Author: AI-Tianlong
+
+This project implements support for the **`Mapillary Vistas Dataset`**.
+
+### Dataset preparing
+
+Prepare the `Mapillary Vistas Dataset` following the [Mapillary Vistas Dataset Preparing Guide](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/mapillary_dataset/docs/en/user_guides/2_dataset_prepare.md).
+
+```none
+  mmsegmentation
+  ├── mmseg
+  ├── tools
+  ├── configs
+  ├── data
+  │   ├── mapillary
+  │   │   ├── training
+  │   │   │   ├── images
+  │   │   │   ├── v1.2
+  │   │   │   │   ├── instances
+  │   │   │   │   ├── labels
+  │   │   │   │   ├── labels_mask
+  │   │   │   │   └── panoptic
+  │   │   │   ├── v2.0
+  │   │   │   │   ├── instances
+  │   │   │   │   ├── labels
+  │   │   │   │   ├── labels_mask
+  │   │   │   │   ├── panoptic
+  │   │   │   │   └── polygons
+  │   │   ├── validation
+  │   │   │   ├── images
+  │   │   │   ├── v1.2
+  │   │   │   │   ├── instances
+  │   │   │   │   ├── labels
+  │   │   │   │   ├── labels_mask
+  │   │   │   │   └── panoptic
+  │   │   │   ├── v2.0
+  │   │   │   │   ├── instances
+  │   │   │   │   ├── labels
+  │   │   │   │   ├── labels_mask
+  │   │   │   │   ├── panoptic
+  │   │   │   │   └── polygons
+```
+
+### Training commands with `deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py`
+
+```bash
+# Dataset train commands
+# run at the `mmsegmentation` root folder
+bash tools/dist_train.sh projects/mapillary_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py 4
+```
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+  - [x] Finish the code
+
+  - [x] Basic docstrings & proper citation
+
+  - [ ] Test-time correctness
+
+  - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+  - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+  - [ ] Type hints and docstrings
+
+  - [ ] Unit tests
+
+  - [ ] Code polishing
+
+  - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
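After running the converter from the preparing guide, a converted v1.2 mask can be sanity-checked with a few lines; the file name below is a placeholder:

```python
import numpy as np
from PIL import Image

# v1.2 masks should only contain the train IDs 0..65 ('Unlabeled' is 65)
mask = np.array(
    Image.open('data/mapillary/training/v1.2/labels_mask/example.png'))
assert mask.ndim == 2
assert set(np.unique(mask)) <= set(range(66))
```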
diff --git a/projects/mapillary_dataset/configs/_base_/datasets/mapillary_v1_2.py b/projects/mapillary_dataset/configs/_base_/datasets/mapillary_v1_2.py new file mode 100644 index 0000000000..a0e7d14b52 --- /dev/null +++ b/projects/mapillary_dataset/configs/_base_/datasets/mapillary_v1_2.py @@ -0,0 +1,69 @@ +# dataset settings +dataset_type = 'MapillaryDataset_v1_2' +data_root = 'data/mapillary/' +crop_size = (512, 1024) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict( + type='RandomResize', + scale=(2048, 1024), + ratio_range=(0.5, 2.0), + keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='PackSegInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=(2048, 1024), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') +] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] +train_dataloader = dict( + batch_size=2, + num_workers=4, + persistent_workers=True, + sampler=dict(type='InfiniteSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='training/images', + seg_map_path='training/v1.2/labels_mask'), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='validation/images', + seg_map_path='validation/v1.2/labels_mask'), + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator diff --git a/projects/mapillary_dataset/configs/_base_/datasets/mapillary_v2_0.py b/projects/mapillary_dataset/configs/_base_/datasets/mapillary_v2_0.py new file mode 100644 index 0000000000..7332d43fad --- /dev/null +++ b/projects/mapillary_dataset/configs/_base_/datasets/mapillary_v2_0.py @@ -0,0 +1,69 @@ +# dataset settings +dataset_type = 'MapillaryDataset_v2_0' +data_root = 'data/mapillary/' +crop_size = (512, 1024) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict( + type='RandomResize', + scale=(2048, 1024), + ratio_range=(0.5, 2.0), + keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='PackSegInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=(2048, 1024), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') +] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + 
type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] +train_dataloader = dict( + batch_size=2, + num_workers=4, + persistent_workers=True, + sampler=dict(type='InfiniteSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='training/images', + seg_map_path='training/v2.0/labels_mask'), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='validation/images', + seg_map_path='validation/v2.0/labels_mask'), + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator diff --git a/projects/mapillary_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py b/projects/mapillary_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py new file mode 100644 index 0000000000..6f7ad65ed8 --- /dev/null +++ b/projects/mapillary_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py @@ -0,0 +1,103 @@ +_base_ = ['./_base_/datasets/mapillary_v1_2.py'] # v 1.2 labels +# _base_ = ['./_base_/datasets/mapillary_v2_0.py'] # v2.0 labels +custom_imports = dict(imports=[ + 'projects.mapillary_dataset.mmseg.datasets.mapillary_v1_2', + 'projects.mapillary_dataset.mmseg.datasets.mapillary_v2_0', +]) + +norm_cfg = dict(type='SyncBN', requires_grad=True) +data_preprocessor = dict( + type='SegDataPreProcessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_val=0, + seg_pad_val=255, + size=(512, 1024)) + +model = dict( + type='EncoderDecoder', + data_preprocessor=data_preprocessor, + pretrained=None, + backbone=dict( + type='ResNet', + depth=101, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='DepthwiseSeparableASPPHead', + in_channels=2048, + in_index=3, + channels=512, + dilations=(1, 12, 24, 36), + c1_in_channels=256, + c1_channels=48, + dropout_ratio=0.1, + num_classes=66, # v1.2 + # num_classes=124, # v2.0 + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=66, # v1.2 + # num_classes=124, # v2.0 + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + train_cfg=dict(), + test_cfg=dict(mode='whole')) +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl')) +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='SegLocalVisualizer', + vis_backends=[dict(type='LocalVisBackend')], + name='visualizer') +log_processor = dict(by_epoch=False) +log_level = 'INFO' +load_from = None +resume = False +tta_model = dict(type='SegTTAModel') +optimizer 
= dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
+    clip_grad=None)
+param_scheduler = [
+    dict(
+        type='PolyLR',
+        eta_min=0.0001,
+        power=0.9,
+        begin=0,
+        end=240000,
+        by_epoch=False)
+]
+train_cfg = dict(
+    type='IterBasedTrainLoop', max_iters=240000, val_interval=24000)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+default_hooks = dict(
+    timer=dict(type='IterTimerHook'),
+    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
+    param_scheduler=dict(type='ParamSchedulerHook'),
+    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=24000),
+    sampler_seed=dict(type='DistSamplerSeedHook'),
+    visualization=dict(type='SegVisualizationHook'))
diff --git a/projects/mapillary_dataset/docs/en/user_guides/2_dataset_prepare.md b/projects/mapillary_dataset/docs/en/user_guides/2_dataset_prepare.md
new file mode 100644
index 0000000000..405e533156
--- /dev/null
+++ b/projects/mapillary_dataset/docs/en/user_guides/2_dataset_prepare.md
@@ -0,0 +1,117 @@
+## Prepare datasets
+
+It is recommended to symlink the dataset root to `$MMSEGMENTATION/data`.
+If your folder structure is different, you may need to change the corresponding paths in config files.
+
+```none
+mmsegmentation
+├── mmseg
+├── tools
+├── configs
+├── data
+│   ├── mapillary
+│   │   ├── training
+│   │   │   ├── images
+│   │   │   ├── v1.2
+│   │   │   │   ├── instances
+│   │   │   │   ├── labels
+│   │   │   │   ├── labels_mask
+│   │   │   │   └── panoptic
+│   │   │   ├── v2.0
+│   │   │   │   ├── instances
+│   │   │   │   ├── labels
+│   │   │   │   ├── labels_mask
+│   │   │   │   ├── panoptic
+│   │   │   │   └── polygons
+│   │   ├── validation
+│   │   │   ├── images
+│   │   │   ├── v1.2
+│   │   │   │   ├── instances
+│   │   │   │   ├── labels
+│   │   │   │   ├── labels_mask
+│   │   │   │   └── panoptic
+│   │   │   ├── v2.0
+│   │   │   │   ├── instances
+│   │   │   │   ├── labels
+│   │   │   │   ├── labels_mask
+│   │   │   │   ├── panoptic
+│   │   │   │   └── polygons
+```
+
+## Mapillary Vistas Datasets
+
+- The dataset can be downloaded [here](https://www.mapillary.com/dataset/vistas) after registration.
+- Assuming you have put the dataset zip file in `mmsegmentation/data`
+- Please run the following commands to unzip the dataset.
+  ```bash
+  cd data
+  mkdir mapillary
+  unzip -d mapillary An-ZjB1Zm61yAZG0ozTymz8I8NqI4x0MrYrh26dq7kPgfu8vf9ImrdaOAVOFYbJ2pNAgUnVGBmbue9lTgdBOb5BbKXIpFs0fpYWqACbrQDChAA2fdX0zS9PcHu7fY8c-FOvyBVxPNYNFQuM.zip
+  ```
+- After unzipping, you will get the Mapillary Vistas Dataset in the following structure.
+  ```none
+  ├── data
+  │   ├── mapillary
+  │   │   ├── training
+  │   │   │   ├── images
+  │   │   │   ├── v1.2
+  │   │   │   │   ├── instances
+  │   │   │   │   ├── labels
+  │   │   │   │   └── panoptic
+  │   │   │   ├── v2.0
+  │   │   │   │   ├── instances
+  │   │   │   │   ├── labels
+  │   │   │   │   ├── panoptic
+  │   │   │   │   └── polygons
+  │   │   ├── validation
+  │   │   │   ├── images
+  │   │   │   ├── v1.2
+  │   │   │   │   ├── instances
+  │   │   │   │   ├── labels
+  │   │   │   │   └── panoptic
+  │   │   │   ├── v2.0
+  │   │   │   │   ├── instances
+  │   │   │   │   ├── labels
+  │   │   │   │   ├── panoptic
+  │   │   │   │   └── polygons
+  ```
+- Run the following commands to convert the RGB labels to mask labels
+  ```bash
+  # --nproc optional, default 1; whether to use multiprocessing
+  # --version optional, 'v1.2', 'v2.0' or 'all', default 'all'; choose which version of labels to convert
+  # run this command in the 'mmsegmentation/projects/mapillary_dataset' folder
+  cd mmsegmentation/projects/mapillary_dataset
+  python tools/dataset_converters/mapillary.py ../../data/mapillary --nproc 8 --version all
+  ```
+  After that, you will get this structure
+  ```none
+  mmsegmentation
+  ├── mmseg
+  ├── tools
+  ├── configs
+  ├── data
+  │   ├── mapillary
+  │   │   ├── training
+  │   │   │   ├── images
+  │   │   │   ├── v1.2
+  │   │   │   │   ├── instances
+  │   │   │   │   ├── labels
+  │   │   │   │   ├── labels_mask
+  │   │   │   │   └── panoptic
+  │   │   │   ├── v2.0
+  │   │   │   │   ├── instances
+  │   │   │   │   ├── labels
+  │   │   │   │   ├── labels_mask
+  │   │   │   │   ├── panoptic
+  │   │   │   │   └── polygons
+  │   │   ├── validation
+  │   │   │   ├── images
+  │   │   │   ├── v1.2
+  │   │   │   │   ├── instances
+  │   │   │   │   ├── labels
+  │   │   │   │   ├── labels_mask
+  │   │   │   │   └── panoptic
+  │   │   │   ├── v2.0
+  │   │   │   │   ├── instances
+  │   │   │   │   ├── labels
+  │   │   │   │   ├── labels_mask
+  │   │   │   │   ├── panoptic
+  │   │   │   │   └── polygons
+  ```
diff --git a/projects/mapillary_dataset/mmseg/datasets/mapillary_v1_2.py b/projects/mapillary_dataset/mmseg/datasets/mapillary_v1_2.py
new file mode 100644
index 0000000000..975d07b24e
--- /dev/null
+++ b/projects/mapillary_dataset/mmseg/datasets/mapillary_v1_2.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmseg.datasets.basesegdataset import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class MapillaryDataset_v1_2(BaseSegDataset):
+    """Mapillary Vistas Dataset.
+
+    Dataset paper link:
+    http://ieeexplore.ieee.org/document/8237796/
+
+    v1.2 contains 66 object classes
+    (37 instance-specific).
+
+    v2.0 contains 124 object classes
+    (70 instance-specific, 46 stuff, 8 void or crowd).
+
+    The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
+    fixed to '.png' for Mapillary Vistas Dataset.
+ """ + METAINFO = dict( + classes=('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', + 'Barrier', 'Wall', 'Bike Lane', 'Crosswalk - Plain', + 'Curb Cut', 'Parking', 'Pedestrian Area', 'Rail Track', + 'Road', 'Service Lane', 'Sidewalk', 'Bridge', 'Building', + 'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist', + 'Other Rider', 'Lane Marking - Crosswalk', + 'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow', + 'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench', + 'Bike Rack', 'Billboard', 'Catch Basin', 'CCTV Camera', + 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', + 'Phone Booth', 'Pothole', 'Street Light', 'Pole', + 'Traffic Sign Frame', 'Utility Pole', 'Traffic Light', + 'Traffic Sign (Back)', 'Traffic Sign (Front)', 'Trash Can', + 'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan', 'Motorcycle', + 'On Rails', 'Other Vehicle', 'Trailer', 'Truck', + 'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled'), + palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153], + [180, 165, 180], [90, 120, 150], [102, 102, 156], + [128, 64, 255], [140, 140, 200], [170, 170, 170], + [250, 170, 160], [96, 96, 96], + [230, 150, 140], [128, 64, 128], [110, 110, 110], + [244, 35, 232], [150, 100, 100], [70, 70, 70], [150, 120, 90], + [220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200], + [200, 128, 128], [255, 255, 255], [64, 170, + 64], [230, 160, 50], + [70, 130, 180], [190, 255, 255], [152, 251, 152], + [107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30], + [100, 140, 180], [220, 220, 220], [220, 128, 128], + [222, 40, 40], [100, 170, 30], [40, 40, 40], [33, 33, 33], + [100, 128, 160], [142, 0, 0], [70, 100, 150], [210, 170, 100], + [153, 153, 153], [128, 128, 128], [0, 0, 80], [250, 170, 30], + [192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32], + [150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90], + [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110], + [0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, + 10], [0, 0, 0]]) + + def __init__(self, + img_suffix='.jpg', + seg_map_suffix='.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) diff --git a/projects/mapillary_dataset/mmseg/datasets/mapillary_v2_0.py b/projects/mapillary_dataset/mmseg/datasets/mapillary_v2_0.py new file mode 100644 index 0000000000..9c67a8b212 --- /dev/null +++ b/projects/mapillary_dataset/mmseg/datasets/mapillary_v2_0.py @@ -0,0 +1,114 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.datasets.basesegdataset import BaseSegDataset +from mmseg.registry import DATASETS + + +@DATASETS.register_module() +class MapillaryDataset_v2_0(BaseSegDataset): + """Mapillary Vistas Dataset. + + Dataset paper link: + http://ieeexplore.ieee.org/document/8237796/ + + v1.2 contain 66 object classes. + (37 instance-specific) + + v2.0 contain 124 object classes. + (70 instance-specific, 46 stuff, 8 void or crowd). + + The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is + fixed to '.png' for Mapillary Vistas Dataset. 
+ """ + METAINFO = dict( + classes=( + 'Bird', 'Ground Animal', 'Ambiguous Barrier', 'Concrete Block', + 'Curb', 'Fence', 'Guard Rail', 'Barrier', 'Road Median', + 'Road Side', 'Lane Separator', 'Temporary Barrier', 'Wall', + 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Driveway', + 'Parking', 'Parking Aisle', 'Pedestrian Area', 'Rail Track', + 'Road', 'Road Shoulder', 'Service Lane', 'Sidewalk', + 'Traffic Island', 'Bridge', 'Building', 'Garage', 'Tunnel', + 'Person', 'Person Group', 'Bicyclist', 'Motorcyclist', + 'Other Rider', 'Lane Marking - Dashed Line', + 'Lane Marking - Straight Line', 'Lane Marking - Zigzag Line', + 'Lane Marking - Ambiguous', 'Lane Marking - Arrow (Left)', + 'Lane Marking - Arrow (Other)', 'Lane Marking - Arrow (Right)', + 'Lane Marking - Arrow (Split Left or Straight)', + 'Lane Marking - Arrow (Split Right or Straight)', + 'Lane Marking - Arrow (Straight)', 'Lane Marking - Crosswalk', + 'Lane Marking - Give Way (Row)', + 'Lane Marking - Give Way (Single)', + 'Lane Marking - Hatched (Chevron)', + 'Lane Marking - Hatched (Diagonal)', 'Lane Marking - Other', + 'Lane Marking - Stop Line', 'Lane Marking - Symbol (Bicycle)', + 'Lane Marking - Symbol (Other)', 'Lane Marking - Text', + 'Lane Marking (only) - Dashed Line', + 'Lane Marking (only) - Crosswalk', 'Lane Marking (only) - Other', + 'Lane Marking (only) - Test', 'Mountain', 'Sand', 'Sky', 'Snow', + 'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench', 'Bike Rack', + 'Catch Basin', 'CCTV Camera', 'Fire Hydrant', 'Junction Box', + 'Mailbox', 'Manhole', 'Parking Meter', 'Phone Booth', 'Pothole', + 'Signage - Advertisement', 'Signage - Ambiguous', 'Signage - Back', + 'Signage - Information', 'Signage - Other', 'Signage - Store', + 'Street Light', 'Pole', 'Pole Group', 'Traffic Sign Frame', + 'Utility Pole', 'Traffic Cone', 'Traffic Light - General (Single)', + 'Traffic Light - Pedestrians', 'Traffic Light - General (Upright)', + 'Traffic Light - General (Horizontal)', 'Traffic Light - Cyclists', + 'Traffic Light - Other', 'Traffic Sign - Ambiguous', + 'Traffic Sign (Back)', 'Traffic Sign - Direction (Back)', + 'Traffic Sign - Direction (Front)', 'Traffic Sign (Front)', + 'Traffic Sign - Parking', 'Traffic Sign - Temporary (Back)', + 'Traffic Sign - Temporary (Front)', 'Trash Can', 'Bicycle', 'Boat', + 'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', + 'Trailer', 'Truck', 'Vehicle Group', 'Wheeled Slow', 'Water Valve', + 'Car Mount', 'Dynamic', 'Ego Vehicle', 'Ground', 'Static', + 'Unlabeled'), + palette=[[165, 42, 42], [0, 192, 0], [250, 170, 31], [250, 170, 32], + [196, 196, 196], [190, 153, 153], [180, 165, 180], + [90, 120, 150], [250, 170, 33], [250, 170, 34], + [128, 128, 128], [250, 170, 35], [102, 102, 156], + [128, 64, 255], [140, 140, 200], [170, 170, 170], + [250, 170, 36], [250, 170, 160], [250, 170, 37], [96, 96, 96], + [230, 150, 140], [128, 64, 128], [110, 110, 110], + [110, 110, 110], [244, 35, 232], [128, 196, + 128], [150, 100, 100], + [70, 70, 70], [150, 150, 150], [150, 120, 90], [220, 20, 60], + [220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200], + [255, 255, 255], [255, 255, 255], [250, 170, 29], + [250, 170, 28], [250, 170, 26], [250, 170, + 25], [250, 170, 24], + [250, 170, 22], [250, 170, 21], [250, 170, + 20], [255, 255, 255], + [250, 170, 19], [250, 170, 18], [250, 170, + 12], [250, 170, 11], + [255, 255, 255], [255, 255, 255], [250, 170, 16], + [250, 170, 15], [250, 170, 15], [255, 255, 255], + [255, 255, 255], [255, 255, 255], [255, 255, 255], + [64, 170, 64], 
[230, 160, 50], + [70, 130, 180], [190, 255, 255], [152, 251, 152], + [107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30], + [100, 140, 180], [220, 128, 128], [222, 40, + 40], [100, 170, 30], + [40, 40, 40], [33, 33, 33], [100, 128, 160], [20, 20, 255], + [142, 0, 0], [70, 100, 150], [250, 171, 30], [250, 172, 30], + [250, 173, 30], [250, 174, 30], [250, 175, + 30], [250, 176, 30], + [210, 170, 100], [153, 153, 153], [153, 153, 153], + [128, 128, 128], [0, 0, 80], [210, 60, 60], [250, 170, 30], + [250, 170, 30], [250, 170, 30], [250, 170, + 30], [250, 170, 30], + [250, 170, 30], [192, 192, 192], [192, 192, 192], + [192, 192, 192], [220, 220, 0], [220, 220, 0], [0, 0, 196], + [192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32], + [150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90], + [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110], + [0, 0, 70], [0, 0, 142], [0, 0, 192], [170, 170, 170], + [32, 32, 32], [111, 74, 0], [120, 10, 10], [81, 0, 81], + [111, 111, 0], [0, 0, 0]]) + + def __init__(self, + img_suffix='.jpg', + seg_map_suffix='.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) diff --git a/projects/mapillary_dataset/tools/dataset_converters/mapillary.py b/projects/mapillary_dataset/tools/dataset_converters/mapillary.py new file mode 100644 index 0000000000..3ccb2d67b3 --- /dev/null +++ b/projects/mapillary_dataset/tools/dataset_converters/mapillary.py @@ -0,0 +1,245 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from functools import partial + +import mmcv +import numpy as np +from mmengine.utils import (mkdir_or_exist, scandir, track_parallel_progress, + track_progress) + +colormap_v1_2 = np.array([[165, 42, 42], [0, 192, 0], [196, 196, 196], + [190, 153, 153], [180, 165, 180], [90, 120, 150], + [102, 102, 156], [128, 64, 255], [140, 140, 200], + [170, 170, 170], [250, 170, 160], [96, 96, 96], + [230, 150, 140], [128, 64, 128], [110, 110, 110], + [244, 35, 232], [150, 100, 100], [70, 70, 70], + [150, 120, 90], [220, 20, 60], [255, 0, 0], + [255, 0, 100], [255, 0, 200], [200, 128, 128], + [255, 255, 255], [64, 170, 64], [230, 160, 50], + [70, 130, 180], [190, 255, 255], [152, 251, 152], + [107, 142, 35], [0, 170, 30], [255, 255, 128], + [250, 0, 30], [100, 140, 180], [220, 220, 220], + [220, 128, 128], [222, 40, 40], [100, 170, 30], + [40, 40, 40], [33, 33, 33], [100, 128, 160], + [142, 0, 0], [70, 100, 150], [210, 170, 100], + [153, 153, 153], [128, 128, 128], [0, 0, 80], + [250, 170, 30], [192, 192, 192], [220, 220, 0], + [140, 140, 20], [119, 11, 32], [150, 0, 255], + [0, 60, 100], [0, 0, 142], [0, 0, 90], [0, 0, 230], + [0, 80, 100], [128, 64, 64], [0, 0, 110], [0, 0, 70], + [0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0]]) + +colormap_v2_0 = np.array([[165, 42, 42], [0, 192, 0], [250, 170, 31], + [250, 170, 32], [196, 196, 196], [190, 153, 153], + [180, 165, 180], [90, 120, 150], [250, 170, 33], + [250, 170, 34], [128, 128, 128], [250, 170, 35], + [102, 102, 156], [128, 64, 255], [140, 140, 200], + [170, 170, 170], [250, 170, 36], [250, 170, 160], + [250, 170, 37], [96, 96, 96], [230, 150, 140], + [128, 64, 128], [110, 110, 110], [110, 110, 110], + [244, 35, 232], [128, 196, 128], [150, 100, 100], + [70, 70, 70], [150, 150, 150], [150, 120, 90], + [220, 20, 60], [220, 20, 60], [255, 0, 0], + [255, 0, 100], [255, 0, 200], [255, 255, 255], + [255, 255, 255], [250, 170, 29], [250, 170, 28], + [250, 170, 26], [250, 170, 25], [250, 170, 24], + 
[250, 170, 22], [250, 170, 21], [250, 170, 20],
+                          [255, 255, 255], [250, 170, 19], [250, 170, 18],
+                          [250, 170, 12], [250, 170, 11], [255, 255, 255],
+                          [255, 255, 255], [250, 170, 16], [250, 170, 15],
+                          [250, 170, 15], [255, 255, 255], [255, 255, 255],
+                          [255, 255, 255], [255, 255, 255], [64, 170, 64],
+                          [230, 160, 50], [70, 130, 180], [190, 255, 255],
+                          [152, 251, 152], [107, 142, 35], [0, 170, 30],
+                          [255, 255, 128], [250, 0, 30], [100, 140, 180],
+                          [220, 128, 128], [222, 40, 40], [100, 170, 30],
+                          [40, 40, 40], [33, 33, 33], [100, 128, 160],
+                          [20, 20, 255], [142, 0, 0], [70, 100, 150],
+                          [250, 171, 30], [250, 172, 30], [250, 173, 30],
+                          [250, 174, 30], [250, 175, 30], [250, 176, 30],
+                          [210, 170, 100], [153, 153, 153], [153, 153, 153],
+                          [128, 128, 128], [0, 0, 80], [210, 60, 60],
+                          [250, 170, 30], [250, 170, 30], [250, 170, 30],
+                          [250, 170, 30], [250, 170, 30], [250, 170, 30],
+                          [192, 192, 192], [192, 192, 192], [192, 192, 192],
+                          [220, 220, 0], [220, 220, 0], [0, 0, 196],
+                          [192, 192, 192], [220, 220, 0], [140, 140, 20],
+                          [119, 11, 32], [150, 0, 255], [0, 60, 100],
+                          [0, 0, 142], [0, 0, 90], [0, 0, 230], [0, 80, 100],
+                          [128, 64, 64], [0, 0, 110], [0, 0, 70], [0, 0, 142],
+                          [0, 0, 192], [170, 170, 170], [32, 32, 32],
+                          [111, 74, 0], [120, 10, 10], [81, 0, 81],
+                          [111, 111, 0], [0, 0, 0]])
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Convert Mapillary dataset to mmsegmentation format')
+    parser.add_argument('dataset_path', help='Mapillary folder path')
+    parser.add_argument(
+        '--version',
+        default='all',
+        help="Mapillary labels version: 'v1.2', 'v2.0' or 'all'")
+    parser.add_argument('-o', '--out_dir', help='output path')
+    parser.add_argument(
+        '--nproc', default=1, type=int, help='number of processes')
+    args = parser.parse_args()
+    return args
+
+
+def mapillary_colormap2label(colormap: np.ndarray) -> np.ndarray:
+    """Create an array of shape (256^3,) that maps each encoded palette
+    color to its label value.
+
+    For example, label 0 ('Bird') has the palette [165, 42, 42]:
+    (165 * 256 + 42) * 256 + 42 = 10824234 (this is the array index), so
+    `colormap2label[10824234] = 0`.
+
+    During conversion, an RGB pixel value of [165, 42, 42] is encoded to
+    10824234, and `colormap2label[10824234]` quickly yields its label
+    value 0. Converting a whole image this way with matrix operations is
+    very fast.
+
+    Args:
+        colormap (np.ndarray): Mapillary Vistas Dataset palette.
+
+    Returns:
+        np.ndarray: An array whose values are mask labels and whose
+            indexes are the encoded palette colors.
+    """
+    colormap2label = np.zeros(256**3, dtype=np.longlong)
+    for i, colormap_ in enumerate(colormap):
+        colormap2label[(colormap_[0] * 256 + colormap_[1]) * 256 +
+                       colormap_[2]] = i
+    return colormap2label
+
+
+def mapillary_masklabel(rgb_label: np.ndarray,
+                        colormap2label: np.ndarray) -> np.ndarray:
+    """Compute the mask label of an image through the `colormap2label`
+    array obtained from `mapillary_colormap2label(colormap)`.
+
+    Args:
+        rgb_label (np.ndarray): An RGB label image.
+        colormap2label (np.ndarray): Obtained from
+            mapillary_colormap2label(colormap).
+
+    Returns:
+        np.ndarray: The mask labels array.
+    """
+    colormap_ = rgb_label.astype('uint32')
+    idx = np.array((colormap_[:, :, 0] * 256 + colormap_[:, :, 1]) * 256 +
+                   colormap_[:, :, 2]).astype('uint32')
+    return colormap2label[idx]
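The docstring's arithmetic can be verified directly, assuming the converter module above is importable:

```python
# Worked example of the color hashing: 'Bird' is label 0 with
# palette [165, 42, 42].
key = (165 * 256 + 42) * 256 + 42
assert key == 10824234
colormap2label = mapillary_colormap2label(colormap_v1_2)
assert colormap2label[key] == 0  # index of 'Bird' in the v1.2 palette
```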
+
+
+def RGB2Mask(rgb_label_path: str, colormap2label: np.ndarray) -> None:
+    """Convert one Mapillary RGB label to a mask label.
+
+    The Mapillary Vistas Dataset provides 8-bit color-palette
+    class-specific labels for semantic segmentation, while semantic
+    segmentation needs single channel mask labels.
+
+    This function converts the Mapillary RGB labels
+    {training,validation}/{v1.2,v2.0}/labels to mask labels stored in
+    {training,validation}/{v1.2,v2.0}/labels_mask.
+
+    Args:
+        rgb_label_path (str): The absolute path of an RGB label image.
+        colormap2label (np.ndarray): Obtained from
+            mapillary_colormap2label(colormap).
+    """
+    rgb_label = mmcv.imread(rgb_label_path, channel_order='rgb')
+
+    masks_label = mapillary_masklabel(rgb_label, colormap2label)
+
+    mmcv.imwrite(
+        masks_label.astype(np.uint8),
+        rgb_label_path.replace('labels', 'labels_mask'))
+
+
+def main():
+    colormap2label_v1_2 = mapillary_colormap2label(colormap_v1_2)
+    colormap2label_v2_0 = mapillary_colormap2label(colormap_v2_0)
+
+    dataset_path = args.dataset_path
+    if args.out_dir is None:
+        out_dir = dataset_path
+    else:
+        out_dir = args.out_dir
+
+    RGB_labels_path = []
+    RGB_labels_v1_2_path = []
+    RGB_labels_v2_0_path = []
+    print('Scanning label paths...')
+    for label_path in scandir(dataset_path, suffix='.png', recursive=True):
+        if 'labels' in label_path:
+            rgb_label_path = osp.join(dataset_path, label_path)
+            RGB_labels_path.append(rgb_label_path)
+            if 'v1.2' in label_path:
+                RGB_labels_v1_2_path.append(rgb_label_path)
+            elif 'v2.0' in label_path:
+                RGB_labels_v2_0_path.append(rgb_label_path)
+
+    if args.version == 'all':
+        print(f'Totally found {len(RGB_labels_path)} {args.version} RGB labels')
+    elif args.version == 'v1.2':
+        print(f'Found {len(RGB_labels_v1_2_path)} {args.version} RGB labels')
+    elif args.version == 'v2.0':
+        print(f'Found {len(RGB_labels_v2_0_path)} {args.version} RGB labels')
+    print('Making directories...')
+    mkdir_or_exist(osp.join(out_dir, 'training', 'v1.2', 'labels_mask'))
+    mkdir_or_exist(osp.join(out_dir, 'validation', 'v1.2', 'labels_mask'))
+    mkdir_or_exist(osp.join(out_dir, 'training', 'v2.0', 'labels_mask'))
+    mkdir_or_exist(osp.join(out_dir, 'validation', 'v2.0', 'labels_mask'))
+    print('Directories have been made...')
+
+    if args.nproc > 1:
+        if args.version == 'all':
+            print('Converting v1.2 ....')
+            track_parallel_progress(
+                partial(RGB2Mask, colormap2label=colormap2label_v1_2),
+                RGB_labels_v1_2_path,
+                nproc=args.nproc)
+            print('Converting v2.0 ....')
+            track_parallel_progress(
+                partial(RGB2Mask, colormap2label=colormap2label_v2_0),
+                RGB_labels_v2_0_path,
+                nproc=args.nproc)
+        elif args.version == 'v1.2':
+            print('Converting v1.2 ....')
+            track_parallel_progress(
+                partial(RGB2Mask, colormap2label=colormap2label_v1_2),
+                RGB_labels_v1_2_path,
+                nproc=args.nproc)
+        elif args.version == 'v2.0':
+            print('Converting v2.0 ....')
+            track_parallel_progress(
+                partial(RGB2Mask, colormap2label=colormap2label_v2_0),
+                RGB_labels_v2_0_path,
+                nproc=args.nproc)
+
+    else:
+        if args.version == 'all':
+            print('Converting v1.2 ....')
+            track_progress(
+                partial(RGB2Mask, colormap2label=colormap2label_v1_2),
+                RGB_labels_v1_2_path)
+            print('Converting v2.0 ....')
+            track_progress(
+                partial(RGB2Mask, colormap2label=colormap2label_v2_0),
+                RGB_labels_v2_0_path)
+        elif args.version == 'v1.2':
+            print('Converting v1.2 ....')
+            track_progress(
+                partial(RGB2Mask, colormap2label=colormap2label_v1_2),
+                RGB_labels_v1_2_path)
+        elif args.version == 'v2.0':
+            print('Converting v2.0 ....')
+            track_progress(
+                partial(RGB2Mask, colormap2label=colormap2label_v2_0),
+                RGB_labels_v2_0_path)
+
+    print('Converted Mapillary Vistas Dataset RGB labels to mask labels!')
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    main()
diff --git a/requirements/mminstall.txt b/requirements/mminstall.txt
index d27af8dd0f..2c8e9d6a22 100644
---
diff --git a/requirements/mminstall.txt b/requirements/mminstall.txt index d27af8dd0f..2c8e9d6a22 100644 --- a/requirements/mminstall.txt +++ b/requirements/mminstall.txt @@ -1,4 +1,4 @@ mmcls>=1.0.0rc0 -mmcv>=2.0.0rc3,<2.1.0 -mmdet>=3.0.0rc4 +mmcv==2.0.0rc3 +mmdet==3.0.0rc5 mmengine>=0.1.0,<1.0.0 diff --git a/requirements/readthedocs.txt b/requirements/readthedocs.txt index af6029b9ad..1b5d8443b4 100644 --- a/requirements/readthedocs.txt +++ b/requirements/readthedocs.txt @@ -1,5 +1,6 @@ -mmcv>=2.0.0rc0 -mmengine +mmcv>=2.0.0rc1,<2.1.0 +mmengine>=0.1.0,<1.0.0 prettytable +scipy torch torchvision diff --git a/tests/data/pseudo_synapse_dataset/ann_dir/case0005_slice000.png b/tests/data/pseudo_synapse_dataset/ann_dir/case0005_slice000.png new file mode 100644 index 0000000000..a22059b58e Binary files /dev/null and b/tests/data/pseudo_synapse_dataset/ann_dir/case0005_slice000.png differ diff --git a/tests/data/pseudo_synapse_dataset/ann_dir/case0005_slice001.png b/tests/data/pseudo_synapse_dataset/ann_dir/case0005_slice001.png new file mode 100644 index 0000000000..a22059b58e Binary files /dev/null and b/tests/data/pseudo_synapse_dataset/ann_dir/case0005_slice001.png differ diff --git a/tests/data/pseudo_synapse_dataset/img_dir/case0005_slice000.jpg b/tests/data/pseudo_synapse_dataset/img_dir/case0005_slice000.jpg new file mode 100644 index 0000000000..51609926b4 Binary files /dev/null and b/tests/data/pseudo_synapse_dataset/img_dir/case0005_slice000.jpg differ diff --git a/tests/data/pseudo_synapse_dataset/img_dir/case0005_slice001.jpg b/tests/data/pseudo_synapse_dataset/img_dir/case0005_slice001.jpg new file mode 100644 index 0000000000..e285b8c7f0 Binary files /dev/null and b/tests/data/pseudo_synapse_dataset/img_dir/case0005_slice001.jpg differ diff --git a/tests/test_datasets/test_dataset.py b/tests/test_datasets/test_dataset.py index b90fc81737..c768f09ade 100644 --- a/tests/test_datasets/test_dataset.py +++ b/tests/test_datasets/test_dataset.py @@ -9,7 +9,7 @@ from mmseg.datasets import (ADE20KDataset, BaseSegDataset, CityscapesDataset, COCOStuffDataset, DecathlonDataset, ISPRSDataset, LIPDataset, LoveDADataset, PascalVOCDataset, - PotsdamDataset, iSAIDDataset) + PotsdamDataset, SynapseDataset, iSAIDDataset) from mmseg.registry import DATASETS from mmseg.utils import get_classes, get_palette @@ -220,6 +220,19 @@ def test_vaihingen(): assert len(test_dataset) == 1 +def test_synapse(): + test_dataset = SynapseDataset( + pipeline=[], + data_prefix=dict( + img_path=osp.join( + osp.dirname(__file__), + '../data/pseudo_synapse_dataset/img_dir'), + seg_map_path=osp.join( + osp.dirname(__file__), + '../data/pseudo_synapse_dataset/ann_dir'))) + assert len(test_dataset) == 2 + + def test_isaid(): test_dataset = iSAIDDataset( pipeline=[], diff --git a/tests/test_datasets/test_loading.py b/tests/test_datasets/test_loading.py index 29a594b4a2..3d5569682a 100644 --- a/tests/test_datasets/test_loading.py +++ b/tests/test_datasets/test_loading.py @@ -144,6 +144,43 @@ def test_load_seg_custom_classes(self): assert gt_array.dtype == np.uint8 np.testing.assert_array_equal(gt_array, true_mask) + # test with removing a class and reducing zero label simultaneously + results = dict( + img_path=img_path, + seg_map_path=gt_path, + # since reduce_zero_label is True, there are only 4 real classes. + # if the full set of classes is ["A", "B", "C", "D"], the + # following label map simulates the dataset option + # classes=["A", "C", "D"] which removes class "B".
+ label_map={ + 0: 0, + 1: 255, # simulate removing class 1 + 2: 1, + 3: 2 + }, + reduce_zero_label=True, # reduce zero label + seg_fields=[]) + + load_imgs = LoadImageFromFile() + results = load_imgs(copy.deepcopy(results)) + + # reduce zero label + load_anns = LoadAnnotations() + results = load_anns(copy.deepcopy(results)) + + gt_array = results['gt_seg_map'] + + true_mask = np.ones_like(gt_array) * 255 # all zeros get mapped to 255 + true_mask[2:4, 2:4] = 0 # 1s are reduced to class 0, which stays class 0 + true_mask[2:4, 6:8] = 255 # 2s are reduced to class 1, which is removed + true_mask[6:8, 2:4] = 1 # 3s are reduced to class 2, mapped to class 1 + true_mask[6:8, 6:8] = 2 # 4s are reduced to class 3, mapped to class 2 + + assert results['seg_fields'] == ['gt_seg_map'] + assert gt_array.shape == (10, 10) + assert gt_array.dtype == np.uint8 + np.testing.assert_array_equal(gt_array, true_mask)
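The expected true_mask above encodes an order of operations that is easy to misread: zero-label reduction is applied first, and the label map is then applied to the *reduced* values. A minimal NumPy sketch of those semantics (toy 2x2 annotation, same label map as the test; this mirrors the behavior the test asserts rather than mmseg internals):

```python
import numpy as np

# Raw annotation values 0..3; class 0 is "unlabeled" and gets reduced away.
seg = np.array([[0, 1], [2, 3]], dtype=np.uint8)

# Step 1: reduce zero label -- 0 becomes 255 (ignore index), k becomes k - 1.
reduced = seg.astype(np.int32) - 1
reduced[reduced == -1] = 255

# Step 2: apply the label map from the test, dropping reduced class 1.
label_map = {0: 0, 1: 255, 2: 1, 3: 2}
mapped = reduced.copy()
for old, new in label_map.items():
    mapped[reduced == old] = new

print(mapped)  # [[255   0]
               #  [255   1]]
```

Reading it back: raw 0 is ignored (255), raw 1 survives as class 0, raw 2 is removed (255), and raw 3 becomes class 1, exactly the pattern true_mask spells out block by block.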
+ + # test no custom classes + results = dict( + img_path=img_path, diff --git a/tests/test_datasets/test_transform.py b/tests/test_datasets/test_transform.py index 2c18b8e027..906b3c27e8 100644 --- a/tests/test_datasets/test_transform.py +++ b/tests/test_datasets/test_transform.py @@ -8,7 +8,9 @@ from PIL import Image from mmseg.datasets.transforms import * # noqa -from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop +from mmseg.datasets.transforms import (LoadBiomedicalData, + LoadBiomedicalImageFromFile, + PhotoMetricDistortion, RandomCrop) from mmseg.registry import TRANSFORMS from mmseg.utils import register_all_modules @@ -183,6 +185,68 @@ def test_flip(): assert np.equal(original_seg, results['gt_semantic_seg']).all() + +def test_random_rotate_flip(): + with pytest.raises(AssertionError): + transform = dict(type='RandomRotFlip', flip_prob=1.5) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict(type='RandomRotFlip', rotate_prob=1.5) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict(type='RandomRotFlip', degree=[20, 20, 20]) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict(type='RandomRotFlip', degree=-20) + TRANSFORMS.build(transform) + + transform = dict( + type='RandomRotFlip', flip_prob=1.0, rotate_prob=0, degree=20) + rot_flip_module = TRANSFORMS.build(transform) + + results = dict() + img = mmcv.imread( + osp.join( + osp.dirname(__file__), + '../data/pseudo_synapse_dataset/img_dir/case0005_slice000.jpg'), + 'color') + original_img = copy.deepcopy(img) + seg = np.array( + Image.open( + osp.join( + osp.dirname(__file__), + '../data/pseudo_synapse_dataset/ann_dir/case0005_slice000.png') + )) + original_seg = copy.deepcopy(seg) + results['img'] = img + results['gt_semantic_seg'] = seg + results['seg_fields'] = ['gt_semantic_seg'] + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 + + result_flip = rot_flip_module(results) + assert original_img.shape == result_flip['img'].shape + assert original_seg.shape == result_flip['gt_semantic_seg'].shape + + transform = dict( + type='RandomRotFlip', flip_prob=0, rotate_prob=1.0, degree=20) + rot_flip_module = TRANSFORMS.build(transform) + + result_rotate = rot_flip_module(results) + assert original_img.shape == result_rotate['img'].shape + assert original_seg.shape == result_rotate['gt_semantic_seg'].shape + + assert str(transform) == "{'type': 'RandomRotFlip'," \ " 'flip_prob': 0," \ " 'rotate_prob': 1.0," \ " 'degree': 20}" + + def test_pad(): # test assertion if both size_divisor and size is None with pytest.raises(AssertionError): @@ -258,7 +322,7 @@ def test_random_crop(): results = pipeline(results) assert results['img'].shape[:2] == (h - 20, w - 20) - assert results['img_shape'][:2] == (h - 20, w - 20) + assert results['img_shape'] == (h - 20, w - 20) assert results['gt_semantic_seg'].shape[:2] == (h - 20, w - 20) @@ -729,7 +793,7 @@ def test_generate_edge(): results['img_shape'] = seg_map.shape results = transform(results) - assert np.all(results['gt_edge'] == np.array([ + assert np.all(results['gt_edge_map'] == np.array([ [0, 0, 0, 1, 0], [0, 0, 1, 1, 1], [0, 1, 1, 1, 0], @@ -778,3 +842,310 @@ def test_biomedical3d_random_crop(): assert crop_results['img'].shape[1:] == (d - 20, h - 20, w - 20) assert crop_results['img_shape'] == (d - 20, h - 20, w - 20) assert crop_results['gt_seg_map'].shape == (d - 20, h - 20, w - 20) + + +def test_biomedical_gaussian_noise(): + # test assertion for invalid prob + with pytest.raises(AssertionError): + transform = dict(type='BioMedicalGaussianNoise', prob=1.5) + TRANSFORMS.build(transform) + + # test assertion for invalid std + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalGaussianNoise', prob=0.2, mean=0.5, std=-0.5) + TRANSFORMS.build(transform) + + transform = dict(type='BioMedicalGaussianNoise', prob=1.0) + noise_module = TRANSFORMS.build(transform) + assert str(noise_module) == 'BioMedicalGaussianNoise'\ '(prob=1.0, ' \ 'mean=0.0, ' \ 'std=0.1)' + + results = dict( img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz')) + transform = LoadBiomedicalImageFromFile() + results = transform(copy.deepcopy(results)) + original_img = copy.deepcopy(results['img']) + results = noise_module(results) + assert original_img.shape == results['img'].shape + + +def test_biomedical_gaussian_blur(): + # test assertion for invalid prob + with pytest.raises(AssertionError): + transform = dict(type='BioMedicalGaussianBlur', prob=-1.5) + TRANSFORMS.build(transform) + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalGaussianBlur', prob=1.0, sigma_range=0.6) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(0.6)) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(15, 8, 9)) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalGaussianBlur', prob=1.0, sigma_range='0.16') + TRANSFORMS.build(transform) + + transform = dict( + type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(0.7, 0.8)) + smooth_module = TRANSFORMS.build(transform) + assert str( + smooth_module + ) == 'BioMedicalGaussianBlur(prob=1.0, ' \ 'prob_per_channel=0.5, '\ 'sigma_range=(0.7, 0.8), ' \ 'different_sigma_per_channel=True, '\ 'different_sigma_per_axis=True)' + + transform = dict(type='BioMedicalGaussianBlur', prob=1.0) + smooth_module = TRANSFORMS.build(transform) + assert str( + smooth_module + ) == 'BioMedicalGaussianBlur(prob=1.0, ' \ 'prob_per_channel=0.5, '\ 'sigma_range=(0.5, 1.0), ' \ 'different_sigma_per_channel=True, '\ 
'different_sigma_per_axis=True)' + + results = dict( + img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz')) + transform = LoadBiomedicalImageFromFile() + results = transform(copy.deepcopy(results)) + original_img = copy.deepcopy(results['img']) + results = smooth_module(results) + assert original_img.shape == results['img'].shape + # smoothing must not widen the value range: the max should not increase + # and the min should not decrease + assert original_img.max() >= results['img'].max() + assert original_img.min() <= results['img'].min() + + transform = dict( + type='BioMedicalGaussianBlur', + prob=1.0, + different_sigma_per_axis=False) + smooth_module = TRANSFORMS.build(transform) + + results = dict( + img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz')) + transform = LoadBiomedicalImageFromFile() + results = transform(copy.deepcopy(results)) + original_img = copy.deepcopy(results['img']) + results = smooth_module(results) + assert original_img.shape == results['img'].shape + # smoothing must not widen the value range + assert original_img.max() >= results['img'].max() + assert original_img.min() <= results['img'].min() + + +def test_BioMedicalRandomGamma(): + + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalRandomGamma', prob=-1, gamma_range=(0.7, 1.2)) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalRandomGamma', prob=1.2, gamma_range=(0.7, 1.2)) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalRandomGamma', prob=1.0, gamma_range=(0.7)) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalRandomGamma', + prob=1.0, + gamma_range=(0.7, 0.2, 0.3)) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalRandomGamma', + prob=1.0, + gamma_range=(0.7, 2), + invert_image=1) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalRandomGamma', + prob=1.0, + gamma_range=(0.7, 2), + per_channel=1) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalRandomGamma', + prob=1.0, + gamma_range=(0.7, 2), + retain_stats=1) + TRANSFORMS.build(transform) + + test_img = osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz') + results = dict(img_path=test_img) + transform = LoadBiomedicalImageFromFile() + results = transform(copy.deepcopy(results)) + origin_img = results['img'] + transform2 = dict( + type='BioMedicalRandomGamma', + prob=1.0, + gamma_range=(0.7, 2), + ) + transform2 = TRANSFORMS.build(transform2) + results = transform2(results) + transformed_img = results['img'] + assert origin_img.shape == transformed_img.shape + + +def test_BioMedical3DPad(): + # test assertion for invalid pad_shape.
+ with pytest.raises(AssertionError): + transform = dict(type='BioMedical3DPad', pad_shape=None) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict(type='BioMedical3DPad', pad_shape=[256, 256]) + TRANSFORMS.build(transform) + + data_info1 = dict(img=np.random.random((8, 6, 4, 4))) + + transform = dict(type='BioMedical3DPad', pad_shape=(6, 6, 6)) + transform = TRANSFORMS.build(transform) + results = transform(copy.deepcopy(data_info1)) + assert results['img'].shape[1:] == (6, 6, 6) + assert results['pad_shape'] == (6, 6, 6) + + transform = dict(type='BioMedical3DPad', pad_shape=(4, 6, 6)) + transform = TRANSFORMS.build(transform) + results = transform(copy.deepcopy(data_info1)) + assert results['img'].shape[1:] == (6, 6, 6) + assert results['pad_shape'] == (6, 6, 6) + + data_info2 = dict( + img=np.random.random((8, 6, 4, 4)), + gt_seg_map=np.random.randint(0, 2, (6, 4, 4))) + + transform = dict(type='BioMedical3DPad', pad_shape=(6, 6, 6)) + transform = TRANSFORMS.build(transform) + results = transform(copy.deepcopy(data_info2)) + assert results['img'].shape[1:] == (6, 6, 6) + assert results['gt_seg_map'].shape[1:] == (6, 6, 6) + assert results['pad_shape'] == (6, 6, 6) + + transform = dict(type='BioMedical3DPad', pad_shape=(4, 6, 6)) + transform = TRANSFORMS.build(transform) + results = transform(copy.deepcopy(data_info2)) + assert results['img'].shape[1:] == (6, 6, 6) + assert results['gt_seg_map'].shape[1:] == (6, 6, 6) + assert results['pad_shape'] == (6, 6, 6) + + +def test_biomedical_3d_flip(): + # test assertion for invalid prob + with pytest.raises(AssertionError): + transform = dict(type='BioMedical3DRandomFlip', prob=1.5, axes=(0, 1)) + transform = TRANSFORMS.build(transform) + + # test assertion for invalid direction + with pytest.raises(AssertionError): + transform = dict(type='BioMedical3DRandomFlip', prob=1, axes=(0, 1, 3)) + transform = TRANSFORMS.build(transform) + + # test flip axes are (0, 1, 2) + transform = dict(type='BioMedical3DRandomFlip', prob=1, axes=(0, 1, 2)) + transform = TRANSFORMS.build(transform) + + # test with random 3d data + results = dict() + results['img_path'] = 'Null' + results['img_shape'] = (1, 16, 16, 16) + results['img'] = np.random.randn(1, 16, 16, 16) + results['gt_seg_map'] = np.random.randint(0, 4, (16, 16, 16)) + + original_img = results['img'].copy() + original_seg = results['gt_seg_map'].copy() + + # flip first time + results = transform(results) + with pytest.raises(AssertionError): + assert np.equal(original_img, results['img']).all() + with pytest.raises(AssertionError): + assert np.equal(original_seg, results['gt_seg_map']).all() + + # flip second time + results = transform(results) + assert np.equal(original_img, results['img']).all() + assert np.equal(original_seg, results['gt_seg_map']).all() + + # test with actual data and flip axes are (0, 1) + # load biomedical 3d img and seg + data_prefix = osp.join(osp.dirname(__file__), '../data') + input_results = dict(img_path=osp.join(data_prefix, 'biomedical.npy')) + biomedical_loader = LoadBiomedicalData(with_seg=True) + data = biomedical_loader(copy.deepcopy(input_results)) + results = data.copy() + + original_img = data['img'].copy() + original_seg = data['gt_seg_map'].copy() + + # test flip axes are (0, 1) + transform = dict(type='BioMedical3DRandomFlip', prob=1, axes=(0, 1)) + transform = TRANSFORMS.build(transform) + + # flip first time + results = transform(results) + with pytest.raises(AssertionError): + assert np.equal(original_img, 
results['img']).all() + with pytest.raises(AssertionError): + assert np.equal(original_seg, results['gt_seg_map']).all() + + # flip second time + results = transform(results) + assert np.equal(original_img, results['img']).all() + assert np.equal(original_seg, results['gt_seg_map']).all() + + # test transform with flip axes = (1) + transform = dict(type='BioMedical3DRandomFlip', prob=1, axes=(1, )) + transform = TRANSFORMS.build(transform) + results = data.copy() + results = transform(results) + results = transform(results) + assert np.equal(original_img, results['img']).all() + assert np.equal(original_seg, results['gt_seg_map']).all() + + # test transform with swap_label_pairs + transform = dict( + type='BioMedical3DRandomFlip', + prob=1, + axes=(1, 2), + swap_label_pairs=[(0, 1)]) + transform = TRANSFORMS.build(transform) + results = data.copy() + results = transform(results) + + with pytest.raises(AssertionError): + assert np.equal(original_seg, results['gt_seg_map']).all() + + # swap twice + results = transform(results) + assert np.equal(original_img, results['img']).all() + assert np.equal(original_seg, results['gt_seg_map']).all() diff --git a/tools/dataset_converters/synapse.py b/tools/dataset_converters/synapse.py new file mode 100644 index 0000000000..42dac6b7ef --- /dev/null +++ b/tools/dataset_converters/synapse.py @@ -0,0 +1,155 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp + +import nibabel as nib +import numpy as np +from mmengine.utils import mkdir_or_exist +from PIL import Image + + +def read_files_from_txt(txt_path): + with open(txt_path) as f: + files = f.readlines() + files = [file.strip() for file in files] + return files + + +def read_nii_file(nii_path): + img = nib.load(nii_path).get_fdata() + return img + + +def split_3d_image(img): + c, _, _ = img.shape + res = [] + for i in range(c): + res.append(img[i, :, :]) + return res + + +def label_mapping(label): + """Map labels following the TransUNet paper setting, which keeps only + 9 classes: 'background', 'aorta', 'gallbladder', 'left_kidney', + 'right_kidney', 'liver', 'pancreas', 'spleen' and 'stomach'. All other + foreground classes in the original dataset are set to background. + + More details can be found here: https://arxiv.org/abs/2102.04306 + """ + mapped_label = np.zeros_like(label) + mapped_label[label == 8] = 1 + mapped_label[label == 4] = 2 + mapped_label[label == 3] = 3 + mapped_label[label == 2] = 4 + mapped_label[label == 6] = 5 + mapped_label[label == 11] = 6 + mapped_label[label == 1] = 7 + mapped_label[label == 7] = 8 + return mapped_label + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert synapse dataset to mmsegmentation format') + parser.add_argument( + '--dataset-path', type=str, help='synapse dataset path.') + parser.add_argument( + '--save-path', + default='data/synapse', + type=str, + help='save path of the dataset.') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + dataset_path = args.dataset_path + save_path = args.save_path + + if not osp.exists(dataset_path): + raise ValueError('The dataset path does not exist. ' + 'Please enter a correct dataset path.') + if not osp.exists(osp.join(dataset_path, 'img')) \ + or not osp.exists(osp.join(dataset_path, 'label')): + raise FileNotFoundError('The dataset structure is incorrect. 
' + 'Please check your dataset.') + + train_id = read_files_from_txt(osp.join(dataset_path, 'train.txt')) + train_id = [idx[3:7] for idx in train_id] + + test_id = read_files_from_txt(osp.join(dataset_path, 'val.txt')) + test_id = [idx[3:7] for idx in test_id] + + mkdir_or_exist(osp.join(save_path, 'img_dir/train')) + mkdir_or_exist(osp.join(save_path, 'img_dir/val')) + mkdir_or_exist(osp.join(save_path, 'ann_dir/train')) + mkdir_or_exist(osp.join(save_path, 'ann_dir/val')) + + # It follows data preparation pipeline from here: + # https://github.com/Beckschen/TransUNet/tree/main/datasets + for i, idx in enumerate(train_id): + img_3d = read_nii_file( + osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz')) + label_3d = read_nii_file( + osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz')) + + img_3d = np.clip(img_3d, -125, 275) + img_3d = (img_3d + 125) / 400 + img_3d *= 255 + img_3d = np.transpose(img_3d, [2, 0, 1]) + img_3d = np.flip(img_3d, 2) + + label_3d = np.transpose(label_3d, [2, 0, 1]) + label_3d = np.flip(label_3d, 2) + label_3d = label_mapping(label_3d) + + for c in range(img_3d.shape[0]): + img = img_3d[c] + label = label_3d[c] + + img = Image.fromarray(img).convert('RGB') + label = Image.fromarray(label).convert('L') + img.save( + osp.join( + save_path, 'img_dir/train', 'case' + idx.zfill(4) + + '_slice' + str(c).zfill(3) + '.jpg')) + label.save( + osp.join( + save_path, 'ann_dir/train', 'case' + idx.zfill(4) + + '_slice' + str(c).zfill(3) + '.png')) + + for i, idx in enumerate(test_id): + img_3d = read_nii_file( + osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz')) + label_3d = read_nii_file( + osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz')) + + img_3d = np.clip(img_3d, -125, 275) + img_3d = (img_3d + 125) / 400 + img_3d *= 255 + img_3d = np.transpose(img_3d, [2, 0, 1]) + img_3d = np.flip(img_3d, 2) + + label_3d = np.transpose(label_3d, [2, 0, 1]) + label_3d = np.flip(label_3d, 2) + label_3d = label_mapping(label_3d) + + for c in range(img_3d.shape[0]): + img = img_3d[c] + label = label_3d[c] + + img = Image.fromarray(img).convert('RGB') + label = Image.fromarray(label).convert('L') + img.save( + osp.join( + save_path, 'img_dir/val', 'case' + idx.zfill(4) + + '_slice' + str(c).zfill(3) + '.jpg')) + label.save( + osp.join( + save_path, 'ann_dir/val', 'case' + idx.zfill(4) + + '_slice' + str(c).zfill(3) + '.png')) + + +if __name__ == '__main__': + main()
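To make the converter's two per-slice transformations concrete, here is a minimal sketch with toy values (the Hounsfield-unit slice and raw label slice below are invented for illustration, not real Synapse data); the windowing constants and the label dictionary come directly from the converter above:

```python
import numpy as np

# Toy CT slice in Hounsfield units.
slice_hu = np.array([[-1000.0, -125.0], [75.0, 400.0]])

# The converter's soft-tissue window: clip to [-125, 275], then rescale
# linearly to [0, 255] (the window is 400 HU wide).
windowed = np.clip(slice_hu, -125, 275)
windowed = (windowed + 125) / 400 * 255
print(windowed)  # [[  0.    0. ]
                 #  [127.5 255. ]]

# Toy raw label slice: original Synapse ids 8 (kept as class 1) and 5
# (not in the TransUNet subset, so it falls back to background).
raw_label = np.array([[8, 5], [0, 2]])
mapped = np.zeros_like(raw_label)
for src, dst in {8: 1, 4: 2, 3: 3, 2: 4, 6: 5, 11: 6, 1: 7, 7: 8}.items():
    mapped[raw_label == src] = dst
print(mapped)  # [[1 0]
               #  [0 4]]
```

Anything outside the clip window saturates to 0 or 255, and any original label id absent from the mapping dictionary stays 0 (background), which is exactly how the script discards the foreground classes TransUNet does not use.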