Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dithering augmentation #1545

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ Pixel-level transforms will change just an input image and will leave any additi
- [ChannelShuffle](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ChannelShuffle)
- [ColorJitter](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ColorJitter)
- [Defocus](https://albumentations.ai/docs/api_reference/augmentations/blur/transforms/#albumentations.augmentations.blur.transforms.Defocus)
- [Dither](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Dither)
- [Downscale](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Downscale)
- [Emboss](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Emboss)
- [Equalize](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Equalize)
Expand Down
65 changes: 65 additions & 0 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"gray_to_rgb",
"unsharp_mask",
"MAX_VALUES_BY_DTYPE",
"dither",
]

TWO = 2
Expand Down Expand Up @@ -1423,3 +1424,67 @@ def spatter(
raise ValueError("Unsupported spatter mode: " + str(mode))

return img * 255

@clipped
@preserve_shape
def dither(img: np.ndarray, nc: int) -> np.ndarray:
img = img.copy()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like it would be better to use an empty array:

result = np.empty_like(img)
result[0] = img[0]

It is much faster

height = np.shape(img)[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is better to use the array attribute img.shape[0]. I think it's more readable.

is_rgb = True if len(np.shape(img)) == 3 else False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to use is_rgb_image function from albumentations/augmentations/utils


for y in range(height):
oldrow = img[y].copy()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need the copy here.

quant_errors = []

if is_rgb:
# Turn into one list of length `width`, per channel (R, G, B)
#
# Use `tolist()` since operating on individual elements of an ndarray
# is very slow compared to a normal list.
channels = np.transpose(oldrow).tolist()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to convert to a list? Why not use an np.ndarray, i.e. channels = oldrow.transpose()? The result would be the same, but faster.


for ch in channels:
ch, qe = _apply_dithering_to_channel(ch, nc)
quant_errors.append(qe)

# Transpose back to one list containing all channels
# and replace the row
img[y] = np.transpose(channels)
quant_errors = np.transpose(quant_errors)
else:
img[y], quant_errors = _apply_dithering_to_channel(img[y].tolist(), nc)

if y < height - 1:
zero_or_zeros = 0 if np.shape(quant_errors[-1]) == () else np.zeros_like(quant_errors[-1])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just use np.zeros_like(quant_errors[-1]); this if/else is useless.

r1 = np.roll(quant_errors, -1, axis=0) * (3 / 16)
r1[-1] = zero_or_zeros
r2 = np.roll(quant_errors, 1, axis=0) / 16
r2[0] = zero_or_zeros
updated_row = r1 + r2 + np.array(quant_errors) * (5 / 16)
img[y + 1] = (img[y + 1] + updated_row).astype(img.dtype)
return img


def _apply_dithering_to_channel(ch, nc):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typing

width = len(ch)

# We want to build the quant error list while
# iterating, as otherwise we'll base the error
# on the original value instead of the value it
# got after being affected by error propagation.
quant_error = [0] * width

for x in range(width - 1):
oldval = ch[x]
newval = round(oldval * (nc - 1)) / (nc - 1)
ch[x] = newval
quant_error[x] = oldval - newval
ch[x + 1] += quant_error[x] * (7 / 16)
Comment on lines +1477 to +1482
Copy link
Collaborator

@Dipet Dipet Feb 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's vectorize this

new_val = np.round(quant_error[:-1] * (nc - 1)) * (1 / (nc - 1))
quant_error = ch[:-1] - new_val
ch[:-1] = new_val
ch[1:] += quant_error * (7 / 16)

# Process the last pixel as a separate case (no propagation
# to the right).
oldval = ch[width - 1]
newval = round(oldval * (nc - 1)) / (nc - 1)
ch[width - 1] = newval
quant_error[-1] = oldval - newval

return ch, quant_error
42 changes: 42 additions & 0 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
"UnsharpMask",
"PixelDropout",
"Spatter",
"Dither",
]

HUNDRED = 100
Expand Down Expand Up @@ -2793,3 +2794,44 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A

def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str, str, str]:
return "mean", "std", "gauss_sigma", "intensity", "cutout_threshold", "mode", "color"

class Dither(ImageOnlyTransform):
"""
Apply dither transform. Dither is an intentionally applied form of noise used to randomize quantization error,
preventing large-scale patterns such as color banding in images.
Args:
nc int: the number of colour choices per channel,
e.g. if nc = 2 we only have 0 and 1, and if nc = 4 we have 0, 0.33, 0.67 and 1 etc
Default value is 2 (the pixel can either be on or off).

always_apply (bool): If `True`, the transform will always be applied, regardless of `p`.
Default is `False`.
p (float): The probability that the transform will be applied. Default is 0.5.
Targets:
image

Image types:
uint8, float32

References :
https://en.wikipedia.org/wiki/Dither
https://en.wikipedia.org/wiki/Floyd%E2%80%93Steinberg_dithering
"""

def __init__(
self,
nc: int = 2,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nc is not a really intuitive name.

num_colors would be better

always_apply: bool = False,
p: float = 0.5,
):
super().__init__(always_apply, p)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nc must be positive; we should have a check here.

self.nc = nc

def apply(self, img: np.ndarray, nc: int = 2, **params) -> np.ndarray:
    """Dither ``img`` with ``nc`` colour choices per channel.

    Args:
        img: Image to transform.
        nc: Number of colour choices per channel, forwarded to ``F.dither``.
        **params: Extra parameters; accepted but not used here.

    Returns:
        Whatever ``F.dither`` returns for ``img`` and ``nc``.
    """
    dithered = F.dither(img, nc=nc)
    return dithered

def get_params(self):
    """Return the per-call parameters for this transform.

    The ``nc`` value stored on the instance is passed through unchanged, so
    the number of colour choices configured on the transform is the one that
    is actually used, instead of being replaced by a random value.

    Returns:
        A dict with the single key ``"nc"`` mapped to ``self.nc``.
    """
    return {"nc": self.nc}

def get_transform_init_args(self):
    """Return the constructor arguments of this transform as a dict.

    Returns:
        A dict with the single key ``"nc"`` mapped to ``self.nc``.
    """
    init_args = {}
    init_args["nc"] = self.nc
    return init_args
32 changes: 32 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,3 +1082,35 @@ def test_brightness_contrast_adjust_equal(beta_by_max):
image_float = (image_float * 255).astype(int)

assert np.abs(image_int.astype(int) - image_float).max() <= 1

@pytest.mark.parametrize("nc", range(2, 17))
def test_dither_nc_palette(nc):
    """Dithering with ``nc`` colour choices yields at most ``nc`` distinct values."""
    source = np.random.randint(0, 256, [32, 32], dtype=np.uint8)
    normalized = source.astype(np.float32) / 255

    dithered = F.dither(normalized, nc)

    # Convert back to 8-bit before counting the distinct values.
    dithered *= 255
    as_uint8 = dithered.astype(np.uint8)
    distinct_values = np.unique(as_uint8)

    assert len(distinct_values) <= nc


def test_dither_grayscale():
    """A uniform 127-valued image dithered to two levels matches the expected pattern."""
    flat_gray = (np.ones([5, 5]) * 127).astype(np.float32) / 255

    dithered = F.dither(flat_gray, nc=2)

    dithered *= 255
    as_uint8 = dithered.astype(np.uint8)

    # The expected output alternates between the two levels along both axes.
    even_row = [0, 255, 0, 255, 0]
    odd_row = [255, 0, 255, 0, 255]
    expected = [even_row, odd_row, even_row, odd_row, even_row]
    assert np.array_equal(as_uint8, expected)
1 change: 1 addition & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ def test_augmentations_serialization(augmentation_cls, params, p, seed, image, m
"fill_value": 0,
},
],
[A.Dither, {"nc": 2}],
]

AUGMENTATION_CLS_EXCEPT = {
Expand Down
11 changes: 11 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,3 +1344,14 @@ def test_spatter_incorrect_color(unsupported_color, mode, message):
A.Spatter(mode=mode, color=unsupported_color)

assert str(exc_info.value).startswith(message)

def test_dither_uint8():
    """Applying ``Dither`` to a uint8 image must return a uint8 image."""
    # Build the input with np.full rather than the low-level np.ndarray
    # constructor: the latter leaves the buffer uninitialized, so the pixel
    # value would differ from run to run.
    image = np.full((1, 1), 127, dtype=np.uint8)
    image = A.Dither().apply(image)
    assert image.dtype == np.uint8


def test_dither_float32():
    """Applying ``Dither`` to a float32 image must return a float32 image."""
    # np.ndarray(shape=...) leaves the buffer uninitialized, so the single
    # pixel could hold any bit pattern, including NaN or inf, which the
    # rounding step of the dithering code cannot handle. Use a defined
    # mid-range value instead so the test is deterministic.
    image = np.full((1, 1), 0.5, dtype=np.float32)
    image = A.Dither().apply(image)
    assert image.dtype == np.float32