-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add dithering augmentation #1545
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,6 +61,7 @@ | |
"gray_to_rgb", | ||
"unsharp_mask", | ||
"MAX_VALUES_BY_DTYPE", | ||
"dither", | ||
] | ||
|
||
TWO = 2 | ||
|
@@ -1423,3 +1424,67 @@ def spatter( | |
raise ValueError("Unsupported spatter mode: " + str(mode)) | ||
|
||
return img * 255 | ||
|
||
@clipped | ||
@preserve_shape | ||
def dither(img: np.ndarray, nc: int) -> np.ndarray: | ||
img = img.copy() | ||
height = np.shape(img)[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think better to use class method |
||
is_rgb = True if len(np.shape(img)) == 3 else False | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better to use is_rgb_image function from albumentations/augmentations/utils |
||
|
||
for y in range(height): | ||
oldrow = img[y].copy() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we don't need clone there |
||
quant_errors = [] | ||
|
||
if is_rgb: | ||
# Turn into one list of length `width`, per channel (R, G, B) | ||
# | ||
# Use `tolist()` since operating on individual elements of an ndarray | ||
# is very slow compared to a normal list. | ||
channels = np.transpose(oldrow).tolist() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why we need to convert to list? Why do not use np.ndarray |
||
|
||
for ch in channels: | ||
ch, qe = _apply_dithering_to_channel(ch, nc) | ||
quant_errors.append(qe) | ||
|
||
# Transpose back to one list containing all channels | ||
# and replace the row | ||
img[y] = np.transpose(channels) | ||
quant_errors = np.transpose(quant_errors) | ||
else: | ||
img[y], quant_errors = _apply_dithering_to_channel(img[y].tolist(), nc) | ||
|
||
if y < height - 1: | ||
zero_or_zeros = 0 if np.shape(quant_errors[-1]) == () else np.zeros_like(quant_errors[-1]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just use |
||
r1 = np.roll(quant_errors, -1, axis=0) * (3 / 16) | ||
r1[-1] = zero_or_zeros | ||
r2 = np.roll(quant_errors, 1, axis=0) / 16 | ||
r2[0] = zero_or_zeros | ||
updated_row = r1 + r2 + np.array(quant_errors) * (5 / 16) | ||
img[y + 1] = (img[y + 1] + updated_row).astype(img.dtype) | ||
return img | ||
|
||
|
||
def _apply_dithering_to_channel(ch, nc): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typing |
||
width = len(ch) | ||
|
||
# We want to build the quant error list while | ||
# iterating, as otherwise we'll base the error | ||
# on the original value instead of the value it | ||
# got after being affected by error propagation. | ||
quant_error = [0] * width | ||
|
||
for x in range(width - 1): | ||
oldval = ch[x] | ||
newval = round(oldval * (nc - 1)) / (nc - 1) | ||
ch[x] = newval | ||
quant_error[x] = oldval - newval | ||
ch[x + 1] += quant_error[x] * (7 / 16) | ||
Comment on lines
+1477
to
+1482
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's vectorize this new_val = np.round(quant_error[:-1] * (nc - 1)) * (1 / (nc - 1))
quant_error = ch[:-1] - new_val
ch[:-1] = new_val
ch[1:] += quant_error * (7 / 16) |
||
# Process the last pixel as a separate case (no propagation | ||
# to the right). | ||
oldval = ch[width - 1] | ||
newval = round(oldval * (nc - 1)) / (nc - 1) | ||
ch[width - 1] = newval | ||
quant_error[-1] = oldval - newval | ||
|
||
return ch, quant_error |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -74,6 +74,7 @@ | |
"UnsharpMask", | ||
"PixelDropout", | ||
"Spatter", | ||
"Dither", | ||
] | ||
|
||
HUNDRED = 100 | ||
|
@@ -2793,3 +2794,44 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A | |
|
||
def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str, str, str]: | ||
return "mean", "std", "gauss_sigma", "intensity", "cutout_threshold", "mode", "color" | ||
|
||
class Dither(ImageOnlyTransform): | ||
""" | ||
Apply dither transform. Dither is an intentionally applied form of noise used to randomize quantization error, | ||
preventing large-scale patterns such as color banding in images. | ||
Args: | ||
nc int: the number of colour choices per channel, | ||
e.g. if nc = 2 we only have 0 and 1, and if nc = 4 we have 0, 0.33, 0.67 and 1 etc | ||
Default value is 2 (the pixel can either be on or off). | ||
|
||
always_apply (bool): If `True`, the transform will always be applied, regardless of `p`. | ||
Default is `False`. | ||
p (float): The probability that the transform will be applied. Default is 0.5. | ||
Targets: | ||
image | ||
|
||
Image types: | ||
uint8, float32 | ||
|
||
References : | ||
https://en.wikipedia.org/wiki/Dither | ||
https://en.wikipedia.org/wiki/Floyd%E2%80%93Steinberg_dithering | ||
""" | ||
|
||
def __init__( | ||
self, | ||
nc: int = 2, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nc is not really intuituve name.
|
||
always_apply: bool = False, | ||
p: float = 0.5, | ||
): | ||
super().__init__(always_apply, p) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nv is positive, we sjould have a check here |
||
self.nc = nc | ||
|
||
def apply(self, img: np.ndarray, nc: int = 2, **params) -> np.ndarray: | ||
return F.dither(img, nc=nc) | ||
|
||
def get_params(self): | ||
return {"nc": random.randint(2, 256)} | ||
|
||
def get_transform_init_args(self): | ||
return {"nc": self.nc} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like beetter to use empty:
It is much faster