Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature] Add RescaleIntensity & ZNormalization & ClampIntensity transforms #2241

Open
wants to merge 2 commits into
base: dev-1.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 6 additions & 4 deletions mmseg/datasets/transforms/__init__.py
Expand Up @@ -3,15 +3,17 @@
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
LoadBiomedicalData, LoadBiomedicalImageFromFile,
LoadImageFromNDArray)
from .transforms import (CLAHE, AdjustGamma, GenerateEdge,
from .transforms import (CLAHE, AdjustGamma, ClampIntensity, GenerateEdge,
PhotoMetricDistortion, RandomCrop, RandomCutOut,
RandomMosaic, RandomRotate, Rerange, ResizeToMultiple,
RGB2Gray, SegRescale)
RandomMosaic, RandomRotate, Rerange, RescaleIntensity,
ResizeToMultiple, RGB2Gray, SegRescale,
ZNormalization)

__all__ = [
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge'
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'RescaleIntensity', 'ZNormalization', 'ClampIntensity'
]
198 changes: 197 additions & 1 deletion mmseg/datasets/transforms/transforms.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, Sequence, Tuple, Union
from typing import Dict, Iterable, Optional, Sequence, Tuple, Union

import cv2
import mmcv
Expand Down Expand Up @@ -1226,3 +1226,199 @@ def __repr__(self):
repr_str += f'edge_width={self.edge_width}, '
repr_str += f'ignore_index={self.ignore_index})'
return repr_str


@TRANSFORMS.register_module()
class RescaleIntensity(BaseTransform):
"""Rescale intensity.

# This class is modified from `MONAI`.

# https://github.com/Project-MONAI/MONAI/blob/dev/monai/transforms/intensity/array.py#L739
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License")

Required Keys:

- img

Modified Keys:

- img

Args:
in_min (float, optional): minimum intensity of source image.
Defaults to None.
in_max (float, optional): maximum intensity of source image.
Defaults to None.
out_min (float, optional): minimum intensity of target image.
Defaults to None.
out_max (float, optional): maximum intensity of target image.
Defaults to None.
"""

def __init__(self,
             in_min: Optional[float] = None,
             in_max: Optional[float] = None,
             out_min: Optional[float] = None,
             out_max: Optional[float] = None):
    """Record the source and target intensity bounds.

    Args:
        in_min (float, optional): minimum intensity of the source image.
            Defaults to None.
        in_max (float, optional): maximum intensity of the source image.
            Defaults to None.
        out_min (float, optional): minimum intensity of the target image.
            Defaults to None.
        out_max (float, optional): maximum intensity of the target image.
            Defaults to None.
    """
    # The bounds are stored verbatim; no validation happens here.
    self.in_min, self.in_max = in_min, in_max
    self.out_min, self.out_max = out_min, out_max

def rescale_intensity(self, img):
    """Linearly map ``img`` from the source range to the target range.

    When ``self.in_min`` / ``self.in_max`` is None, the corresponding bound
    is taken from ``img`` itself. These per-image bounds are kept in local
    variables so that they never leak into the transform's configuration:
    every image is rescaled with its own minimum and maximum.

    Args:
        img (np.ndarray): image to rescale.

    Returns:
        np.ndarray: the rescaled image. If the source range is empty
        (``in_max == in_min``) the image is only shifted by ``-in_min``
        (plus ``out_min`` when it is set). If ``out_min`` or ``out_max`` is
        None, the image is only normalized to the source range and not
        mapped onto a target range.
    """
    in_min = np.min(img) if self.in_min is None else self.in_min
    in_max = np.max(img) if self.in_max is None else self.in_max

    in_range = in_max - in_min
    if in_range == 0.0:
        # Degenerate source range: dividing would blow up, so only shift.
        shifted = img - in_min
        return shifted if self.out_min is None else shifted + self.out_min

    img = (img - in_min) / in_range
    if self.out_min is not None and self.out_max is not None:
        img = img * (self.out_max - self.out_min) + self.out_min

    return img

def transform(self, results: dict) -> dict:
    """Rescale ``results['img']`` and store the result under the same key.

    Args:
        results (dict): result dict containing the key ``img``.

    Returns:
        dict: the same dict object, with ``img`` replaced.
    """
    results['img'] = self.rescale_intensity(results['img'])
    return results

def __repr__(self):
    """Return the class name followed by the four configured bounds."""
    fields = ', '.join((
        f'in_min={self.in_min}',
        f'in_max={self.in_max}',
        f'out_min={self.out_min}',
        f'out_max={self.out_max}',
    ))
    return f'{self.__class__.__name__}({fields})'


@TRANSFORMS.register_module()
class ZNormalization(BaseTransform):
"""z_normalization.

# This class is modified from `MONAI`.
# https://github.com/Project-MONAI/MONAI/blob/dev/monai/transforms/intensity/array.py#L605
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License")

Required Keys:

- img

Modified Keys:

- img

Args:
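mean (float | Iterable[float], optional): the mean to subtract. If None,
it is computed from the image. Defaults to None.
std (float | Iterable[float], optional): the standard deviation to divide
by. If None, it is computed from the image. Defaults to None.
channel_wise (bool): if True, normalize each entry along the first axis
of the image separately, using one mean/std value per channel.
Defaults to False.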
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a docstring entry for `channel_wise`.

"""

def __init__(self,
             mean: Optional[Union[float, Iterable[float]]] = None,
             std: Optional[Union[float, Iterable[float]]] = None,
             channel_wise: bool = False) -> None:
    """Record the normalization statistics and the channel mode.

    Args:
        mean (float | Iterable[float], optional): the mean to subtract.
            Defaults to None.
        std (float | Iterable[float], optional): the standard deviation
            to divide by. Defaults to None.
        channel_wise (bool): whether to normalize channel by channel.
            Defaults to False.
    """
    # Values are stored verbatim; length checks happen at call time.
    self.mean, self.std = mean, std
    self.channel_wise = channel_wise

def _normalize(self, img: np.ndarray, mean=None, std=None):
slices = np.ones_like(img, dtype=bool)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is `slices = np.ones_like(img, dtype=bool)` for?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find that the "slices" is designed for the case that we only want to normalize the non-zero intensity of the image in MONAI. But I haven't added this function. In this version, the "slices" is not necessary. I will remove it. Thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find that the "slices" is designed for the case that we only want to normalize the non-zero intensity of the image in MONAI. But I haven't added this function. In this version, the "slices" is not necessary. I will remove it. Thanks!

OK, got it! Thanks. Shall we add this feature (normalizing only the non-zero intensities of the image) in the future? Do you think it is important for medical image segmentation?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my experience so far, the difference it makes is slight most of the time, but it is meaningful.
Adding it would be a good idea. Thanks!

if not slices.any():
return img

_mean = mean if mean is not None else np.mean(img[slices])
_std = std if std is not None else np.std(img[slices])

if np.isscalar(_std):
if _std == 0.0:
_std = 1.0
else:
_std = _std[slices]
_std[_std == 0.0] = 1.0

img[slices] = (img[slices] - _mean) / _std
return img

def znorm(self, img):
    """Z-normalize ``img`` as a whole or channel by channel.

    In channel-wise mode the per-channel results are assigned back into
    ``img``. An integer or boolean image would silently truncate those
    floating-point values, so such images are converted to ``float32``
    before normalizing. Floating-point images keep their dtype.

    Args:
        img (np.ndarray): image to normalize. In channel-wise mode the
            first axis is iterated as the channel axis.

    Returns:
        np.ndarray: the normalized image.

    Raises:
        ValueError: in channel-wise mode, if ``self.mean`` or ``self.std``
            is given and its length differs from the number of channels.
    """
    if isinstance(img, np.ndarray) and img.dtype.kind in 'biu':
        # bool / signed / unsigned buffers cannot hold z-scores.
        img = img.astype(np.float32)

    if not self.channel_wise:
        return self._normalize(img, self.mean, self.std)

    num_channels = len(img)
    for name, stat in (('mean', self.mean), ('std', self.std)):
        if stat is not None and len(stat) != num_channels:
            raise ValueError(f'img has {num_channels} channels, '
                             f'but {name} has {len(stat)}.')

    for i, channel in enumerate(img):
        img[i] = self._normalize(
            channel,
            mean=None if self.mean is None else self.mean[i],
            std=None if self.std is None else self.std[i],
        )

    return img

def transform(self, results: dict) -> dict:
    """Z-normalize ``results['img']`` and store it under the same key.

    Args:
        results (dict): result dict containing the key ``img``.

    Returns:
        dict: the same dict object, with ``img`` replaced.
    """
    results['img'] = self.znorm(results['img'])
    return results

def __repr__(self):
    """Return the class name followed by mean, std and channel_wise."""
    fields = ', '.join((
        f'mean={self.mean}',
        f'std={self.std}',
        f'channel_wise={self.channel_wise}',
    ))
    return f'{self.__class__.__name__}({fields})'


@TRANSFORMS.register_module()
class ClampIntensity(BaseTransform):
"""clamp intensity.

Required Keys:

- img

Modified Keys:

- img

Args:
t_min (float, optional): Minimum target intensity. Defaults to None.
t_max (float, optional): Maximum target intensity. Defaults to None.
"""

def __init__(self,
             t_min: Optional[float] = None,
             t_max: Optional[float] = None) -> None:
    """Record the lower and upper clamping bounds.

    Args:
        t_min (float, optional): minimum target intensity.
            Defaults to None.
        t_max (float, optional): maximum target intensity.
            Defaults to None.
    """
    self.t_min, self.t_max = t_min, t_max

def clamp(self, img):
    """Clip ``img`` to the configured ``[t_min, t_max]`` range.

    Args:
        img (np.ndarray): image to clamp.

    Returns:
        np.ndarray: the clipped image. A bound that is None is not applied.
        When both bounds are None (the constructor defaults) ``img`` is
        returned unchanged.
    """
    if self.t_min is None and self.t_max is None:
        # Nothing to clip against. NumPy releases that require at least one
        # bound raise ValueError here, so treat this as a no-op instead.
        return img
    return np.clip(img, self.t_min, self.t_max)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we can delete L1411-1412 since np.clip() could be directly used in def transform()?


def transform(self, results: dict) -> dict:
img = results['img']
img = self.clamp(img)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
img = self.clamp(img)
img = np.clip(img, self.t_min, self.t_max)

results['img'] = img
return results

def __repr__(self):
    """Return the class name followed by the two clamping bounds."""
    fields = ', '.join((f't_min={self.t_min}', f't_max={self.t_max}'))
    return f'{self.__class__.__name__}({fields})'
99 changes: 98 additions & 1 deletion tests/test_datasets/test_transform.py
Expand Up @@ -8,7 +8,8 @@
from PIL import Image

from mmseg.datasets.transforms import * # noqa
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop
from mmseg.datasets.transforms import (LoadBiomedicalImageFromFile,
PhotoMetricDistortion, RandomCrop)
from mmseg.registry import TRANSFORMS


Expand Down Expand Up @@ -706,3 +707,99 @@ def test_generate_edge():
[1, 1, 0, 0, 0],
[1, 0, 0, 0, 0],
]))


def test_rescale_intensity():
    """Check RescaleIntensity against hand-computed expectations."""
    nii_path = osp.join(
        osp.dirname(__file__), '../data', 'biomedical.nii.gz')
    loader = LoadBiomedicalImageFromFile()
    results = loader(copy.deepcopy(dict(img_path=nii_path)))
    img = results['img']

    def rescaled(**bounds):
        # Build the transform from config and run it on a private copy so
        # the loaded results stay untouched between cases.
        pipeline = TRANSFORMS.build(dict(type='RescaleIntensity', **bounds))
        return pipeline(copy.deepcopy(results))['img']

    # Fully specified source and target ranges.
    out = rescaled(in_min=20, in_max=108, out_min=50, out_max=80)
    assert np.allclose(out, ((img - 20) / 88) * 30 + 50)

    # Source range taken from the image itself.
    out = rescaled(in_min=None, in_max=None, out_min=50, out_max=80)
    lo, hi = np.min(img), np.max(img)
    assert np.allclose(out, ((img - lo) / (hi - lo)) * 30 + 50)

    # Empty source range without out_min: the image is only shifted.
    out = rescaled(in_min=108, in_max=108, out_min=None, out_max=80)
    assert np.allclose(out, img - 108)

    # Empty source range with out_min: shifted by out_min - in_min.
    out = rescaled(in_min=108, in_max=108, out_min=50, out_max=80)
    assert np.allclose(out, img - 58)


def test_z_normalization():
    """Check ZNormalization in whole-image, channel-wise and error cases."""
    nii_path = osp.join(
        osp.dirname(__file__), '../data', 'biomedical.nii.gz')
    loader = LoadBiomedicalImageFromFile()
    results = loader(copy.deepcopy(dict(img_path=nii_path)))

    # Whole-image statistics computed from the image itself.
    znorm = TRANSFORMS.build(dict(type='ZNormalization'))
    normed = znorm(copy.deepcopy(results))['img']
    source = results['img']
    assert np.allclose(normed, (source - np.mean(source)) / np.std(source))

    # Per-channel statistics on a random two-channel volume.
    results['img'] = np.random.randint(
        -1000, 1000, (2, 96, 96, 96)).astype(np.float32)
    znorm = TRANSFORMS.build(dict(type='ZNormalization', channel_wise=True))
    normed = znorm(copy.deepcopy(results))['img']
    expected = copy.deepcopy(results['img'])
    for c in range(2):
        expected[c] = (expected[c] - np.mean(expected[c])) / np.std(
            expected[c])
    assert np.allclose(normed, expected)

    # A mean or std whose length differs from the channel count is rejected.
    for stats in (dict(mean=[1]), dict(mean=[1, 2], std=[1])):
        znorm = TRANSFORMS.build(
            dict(type='ZNormalization', channel_wise=True, **stats))
        with pytest.raises(ValueError):
            znorm(copy.deepcopy(results))


def test_clamp_intensity():
    """Check that ClampIntensity matches np.clip on the sample volume."""
    nii_path = osp.join(
        osp.dirname(__file__), '../data', 'biomedical.nii.gz')
    loader = LoadBiomedicalImageFromFile()
    results = loader(copy.deepcopy(dict(img_path=nii_path)))

    clamp = TRANSFORMS.build(dict(type='ClampIntensity', t_min=50, t_max=80))
    clamped = clamp(copy.deepcopy(results))['img']
    assert np.allclose(clamped, np.clip(results['img'], 50, 80))