Skip to content

Commit

Permalink
[Feature] add HSI-Drive dataset (#3365)
Browse files Browse the repository at this point in the history
## Motivation

The motivation is to add a hyperspectral dataset [HSI Drive
2.0](https://ipaccess.ehu.eus/HSI-Drive/) to the dataset registry which
would be, as far as I know, the first hyperspectral database of
mmsegmentation. This database has been presented in [HSI-Drive v2.0:
More Data for New Challenges in Scene Understanding for Autonomous
Driving](https://ieeexplore.ieee.org/document/10371793) and the initival
v1 was presented in [HSI-Drive: A Dataset for the Research of
Hyperspectral Image Processing Applied to Autonomous Driving
Systems](https://ieeexplore.ieee.org/document/9575298)

## Modification

I have created/modified the following aspects:
- READMEs: `README.md` and `README_zh-CN.md` (sorry if translation is
not accurate).
- Example project: `projects/hsidrive20_dataset` has been created and
filled for users to know how to work with this database.
- Documentation: `docs/en/user_guides/2_dataset_prepare.md` and
`docs/zh_cn/user_guides/2_dataset_prepare.md` (sorry if translation is
not accurate) have been updated for users to know how to download and
configure the dataset.
- Database related files: `mmseg/datasets/__init__.py`,
`mmseg/datasets/hsi_drive.py` and `configs/_base_/datasets/hsi_drive.py`
where the dataset is described and also prepared for
training/validation/test.
- Transforms related files:
`mmsegmentation/mmseg/datasets/transforms/loading.py` to *include
support for loading images from .npy files* such as the hyperspectral
images of this dataset.
- Training config with well-known neural network:
`configs/unet/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py` for people
to train a standard neural network with this dataset.
- Tests: added necessary files under
`tests/data/pseudo_hsidrive20_dataset`.

**Important:** I have also modified `.pre-commit-config.yaml` to ignore
HSI error in codespell.

## BC-breaking (Optional)

No.

## Use cases (Optional)

A train example has been added under `projects/hsidrive20_dataset` and
documentation has been updated as it is explained in Modification
section.

## Checklist

1. Pre-commit or other linting tools are used to fix the potential lint
issues.
2. The modification is covered by complete unit tests. If not, please
add more unit test to ensure the correctness.
3. If the modification has potential influence on downstream projects,
this PR should be tested with downstream projects, like MMDet or
MMDet3D.
4. The documentation has been modified accordingly, like docstring or
example tutorials.

Regarding 1. I don't know how to solve this problem. Could you help me,
please? This causes 2 checks not to be successful.

---------

Co-authored-by: xiexinch <xiexinch@outlook.com>
  • Loading branch information
jonGuti13 and xiexinch committed Jan 10, 2024
1 parent 6a709be commit 7a392ad
Show file tree
Hide file tree
Showing 23 changed files with 576 additions and 2 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Expand Up @@ -37,6 +37,7 @@ repos:
rev: v2.2.1
hooks:
- id: codespell
args: [--ignore-words-list=hsi]
- repo: https://github.com/myint/docformatter
rev: v1.3.1
hooks:
Expand Down
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -339,6 +339,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#levir-cd">LEVIR-CD</a></li>
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#bdd100K">BDD100K</a></li>
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#nyu">NYU</a></li>
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#hsi-drive-2.0">HSIDrive20</a></li>
</ul>
</td>
<td>
Expand Down
1 change: 1 addition & 0 deletions README_zh-CN.md
Expand Up @@ -328,6 +328,7 @@ MMSegmentation v1.x 在 0.x 版本的基础上有了显著的提升,提供了
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#levir-cd">LEVIR-CD</a></li>
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#bdd100K">BDD100K</a></li>
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#nyu">NYU</a></li>
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#hsi-drive-2.0">HSIDrive20</a></li>
</ul>
</td>
<td>
Expand Down
53 changes: 53 additions & 0 deletions configs/_base_/datasets/hsi_drive.py
@@ -0,0 +1,53 @@
train_pipeline = [
dict(type='LoadImageFromNpyFile'),
dict(type='LoadAnnotations'),
dict(type='RandomCrop', crop_size=(192, 384)),
dict(type='PackSegInputs')
]

test_pipeline = [
dict(type='LoadImageFromNpyFile'),
dict(type='RandomCrop', crop_size=(192, 384)),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]

train_dataloader = dict(
batch_size=4,
num_workers=1,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type='HSIDrive20Dataset',
data_root='data/HSIDrive20',
data_prefix=dict(
img_path='images/training', seg_map_path='annotations/training'),
pipeline=train_pipeline))

val_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type='HSIDrive20Dataset',
data_root='data/HSIDrive20',
data_prefix=dict(
img_path='images/validation',
seg_map_path='annotations/validation'),
pipeline=test_pipeline))

test_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type='HSIDrive20Dataset',
data_root='data/HSIDrive20',
data_prefix=dict(
img_path='images/test', seg_map_path='annotations/test'),
pipeline=test_pipeline))

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'], ignore_index=0)
test_evaluator = val_evaluator
36 changes: 36 additions & 0 deletions configs/unet/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py
@@ -0,0 +1,36 @@
_base_ = [
'../_base_/models/fcn_unet_s5-d16.py', '../_base_/datasets/hsi_drive.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
crop_size = (192, 384)
data_preprocessor = dict(
type='SegDataPreProcessor',
size=crop_size,
mean=None,
std=None,
bgr_to_rgb=None,
pad_val=0,
seg_pad_val=255)

model = dict(
data_preprocessor=data_preprocessor,
backbone=dict(in_channels=25),
decode_head=dict(
ignore_index=0,
num_classes=11,
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0,
avg_non_ignore=True)),
auxiliary_head=dict(
ignore_index=0,
num_classes=11,
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0,
avg_non_ignore=True)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
52 changes: 52 additions & 0 deletions docs/en/user_guides/2_dataset_prepare.md
Expand Up @@ -205,6 +205,15 @@ mmsegmentation
│ │ ├── annotations
│ │ │ ├── train
│ │ │ ├── test
│ ├── HSIDrive20
│ │ ├── images
│ │ │ ├── train
│ │ │ ├── validation
│ │ │ ├── test
│ │ ├── annotations
│ │ │ ├── train
│ │ │ ├── validation
│ │ │ ├── test
```

## Download dataset via MIM
Expand Down Expand Up @@ -752,3 +761,46 @@ mmsegmentation
```bash
python tools/dataset_converters/nyu.py nyu.zip
```

## HSI Drive 2.0

- You could download HSI Drive 2.0 dataset from [here](https://ipaccess.ehu.eus/HSI-Drive/#download) after just sending an email to gded@ehu.eus with the subject "download HSI-Drive". You will receive a password to uncompress the files.

- After download, unzip by the following instructions:

```bash
7z x -p"password" ./HSI_Drive_v2_0_Phyton.zip

mv ./HSIDrive20 path_to_mmsegmentation/data
mv ./HSI_Drive_v2_0_release_notes_Python_version.md path_to_mmsegmentation/data
mv ./image_numbering.pdf path_to_mmsegmentation/data
```

- After unzip, you get

```none
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│ ├── HSIDrive20
│ │ ├── images
│ │ │ ├── training
│ │ │ ├── validation
│ │ │ ├── test
│ │ ├── annotations
│ │ │ ├── training
│ │ │ ├── validation
│ │ │ ├── test
│ │ ├── images_MF
│ │ │ ├── training
│ │ │ ├── validation
│ │ │ ├── test
│ │ ├── RGB
│ │ ├── training_filenames.txt
│ │ ├── validation_filenames.txt
│ │ ├── test_filenames.txt
│ ├── HSI_Drive_v2_0_release_notes_Python_version.md
│ ├── image_numbering.pdf
```
52 changes: 52 additions & 0 deletions docs/zh_cn/user_guides/2_dataset_prepare.md
Expand Up @@ -205,6 +205,15 @@ mmsegmentation
│ │ ├── annotations
│ │ │ ├── train
│ │ │ ├── test
│ ├── HSIDrive20
│ │ ├── images
│ │ │ ├── train
│ │ │ ├── validation
│ │ │ ├── test
│ │ ├── annotations
│ │ │ ├── train
│ │ │ ├── validation
│ │ │ ├── test
```

## 用 MIM 下载数据集
Expand Down Expand Up @@ -748,3 +757,46 @@ mmsegmentation
```bash
python tools/dataset_converters/nyu.py nyu.zip
```

## HSI Drive 2.0

- 您可以从以下位置下载 HSI Drive 2.0 数据集 [here](https://ipaccess.ehu.eus/HSI-Drive/#download) 刚刚向 gded@ehu.eus 发送主题为“下载 HSI-Drive”的电子邮件后 您将收到解压缩文件的密码.

- 下载后,按照以下说明解压:

```bash
7z x -p"password" ./HSI_Drive_v2_0_Phyton.zip

mv ./HSIDrive20 path_to_mmsegmentation/data
mv ./HSI_Drive_v2_0_release_notes_Python_version.md path_to_mmsegmentation/data
mv ./image_numbering.pdf path_to_mmsegmentation/data
```

- 解压后得到:

```none
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│ ├── HSIDrive20
│ │ ├── images
│ │ │ ├── training
│ │ │ ├── validation
│ │ │ ├── test
│ │ ├── annotations
│ │ │ ├── training
│ │ │ ├── validation
│ │ │ ├── test
│ │ ├── images_MF
│ │ │ ├── training
│ │ │ ├── validation
│ │ │ ├── test
│ │ ├── RGB
│ │ ├── training_filenames.txt
│ │ ├── validation_filenames.txt
│ │ ├── test_filenames.txt
│ ├── HSI_Drive_v2_0_release_notes_Python_version.md
│ ├── image_numbering.pdf
```
3 changes: 2 additions & 1 deletion mmseg/datasets/__init__.py
Expand Up @@ -12,6 +12,7 @@
from .drive import DRIVEDataset
from .dsdl import DSDLSegDataset
from .hrf import HRFDataset
from .hsi_drive import HSIDrive20Dataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .levir import LEVIRCDDataset
Expand Down Expand Up @@ -60,5 +61,5 @@
'MapillaryDataset_v2', 'Albu', 'LEVIRCDDataset',
'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile',
'ConcatCDInput', 'BaseCDDataset', 'DSDLSegDataset', 'BDD100KDataset',
'NYUDataset'
'NYUDataset', 'HSIDrive20Dataset'
]
42 changes: 42 additions & 0 deletions mmseg/datasets/hsi_drive.py
@@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.datasets import BaseSegDataset
from mmseg.registry import DATASETS

classes_exp = ('unlabelled', 'road', 'road marks', 'vegetation',
'painted metal', 'sky', 'concrete', 'pedestrian', 'water',
'unpainted metal', 'glass')
palette_exp = [[0, 0, 0], [77, 77, 77], [255, 255, 255], [0, 255, 0],
[255, 0, 0], [0, 0, 255], [102, 51, 0], [255, 255, 0],
[0, 207, 250], [255, 166, 0], [0, 204, 204]]


@DATASETS.register_module()
class HSIDrive20Dataset(BaseSegDataset):
"""HSI-Drive v2.0 (https://ieeexplore.ieee.org/document/10371793), the
updated version of HSI-Drive
(https://ieeexplore.ieee.org/document/9575298), is a structured dataset for
the research and development of automated driving systems (ADS) supported
by hyperspectral imaging (HSI). It contains per-pixel manually annotated
images selected from videos recorded in real driving conditions and has
been organized according to four parameters: season, daytime, road type,
and weather conditions.
The video sequences have been captured with a small-size 25-band VNIR
(Visible-NearlnfraRed) snapshot hyperspectral camera mounted on a driving
automobile. As a consequence, you need to modify the in_channels parameter
of your model from 3 (RGB images) to 25 (HSI images) as it is done in
configs/unet/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py
Apart from the abovementioned articles, additional information is provided
in the website (https://ipaccess.ehu.eus/HSI-Drive/) from where you can
download the dataset and also visualize some examples of segmented videos.
"""

METAINFO = dict(classes=classes_exp, palette=palette_exp)

def __init__(self,
img_suffix='.npy',
seg_map_suffix='.png',
**kwargs) -> None:
super().__init__(
img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
67 changes: 67 additions & 0 deletions mmseg/datasets/transforms/loading.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from pathlib import Path
from typing import Dict, Optional, Union

import mmcv
Expand Down Expand Up @@ -702,3 +703,69 @@ def __repr__(self):
f'to_float32={self.to_float32}, '
f'backend_args={self.backend_args})')
return repr_str


@TRANSFORMS.register_module()
class LoadImageFromNpyFile(LoadImageFromFile):
"""Load an image from ``results['img_path']``.
Required Keys:
- img_path
Modified Keys:
- img
- img_shape
- ori_shape
Args:
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
"""

def transform(self, results: dict) -> Optional[dict]:
"""Functions to load image.
Args:
results (dict): Result dict from
:class:`mmengine.dataset.BaseDataset`.
Returns:
dict: The dict contains loaded image and meta information.
"""

filename = results['img_path']

try:
if Path(filename).suffix in ['.npy', '.npz']:
img = np.load(filename)
else:
if self.file_client_args is not None:
file_client = fileio.FileClient.infer_client(
self.file_client_args, filename)
img_bytes = file_client.get(filename)
else:
img_bytes = fileio.get(
filename, backend_args=self.backend_args)
img = mmcv.imfrombytes(
img_bytes,
flag=self.color_type,
backend=self.imdecode_backend)
except Exception as e:
if self.ignore_empty:
return None
else:
raise e

# in some cases, images are not read successfully, the img would be
# `None`, refer to https://github.com/open-mmlab/mmpretrain/issues/1427
assert img is not None, f'failed to load image: {filename}'
if self.to_float32:
img = img.astype(np.float32)

results['img'] = img
results['img_shape'] = img.shape[:2]
results['ori_shape'] = img.shape[:2]
return results

0 comments on commit 7a392ad

Please sign in to comment.