Skip to content

Commit

Permalink
Merge pull request #885 from danforthcenter/color-correction-other-dt…
Browse files Browse the repository at this point in the history
…ypes

support for other dtypes in color correction
  • Loading branch information
nfahlgren committed Apr 14, 2022
2 parents 12957d9 + 03dd635 commit c5f957a
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions plantcv/plantcv/transform/color_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,14 @@ def get_color_matrix(rgb_img, mask):
if len(np.shape(mask)) != 2:
fatal_error("Input mask is not an gray-scale image.")

img_dtype = rgb_img.dtype
# normalization value as max number if the type is unsigned int
max_val = 1.0
if img_dtype.kind == 'u':
max_val = np.iinfo(img_dtype).max

# convert to float and normalize to work with values between 0-1
rgb_img = rgb_img.astype(np.float64)/255
rgb_img = rgb_img.astype(np.float64)/max_val

# create empty color_matrix
color_matrix = np.zeros((len(np.unique(mask))-1, 4))
Expand Down Expand Up @@ -203,8 +209,13 @@ def apply_transformation_matrix(source_img, target_img, transformation_matrix):
# split transformation_matrix
red, green, blue, red2, green2, blue2, red3, green3, blue3 = np.split(transformation_matrix, 9, 1)

source_dtype = source_img.dtype
# normalization value as max number if the type is unsigned int
max_val = 1.0
if source_dtype.kind == 'u':
max_val = np.iinfo(source_dtype).max
# convert img to float to avoid integer overflow, normalize between 0-1
source_flt = source_img.astype(np.float64)/255
source_flt = source_img.astype(np.float64)/max_val
# find linear, square, and cubic values of source_img color channels
source_b, source_g, source_r = cv2.split(source_flt)
source_b2 = np.square(source_b)
Expand All @@ -226,17 +237,16 @@ def apply_transformation_matrix(source_img, target_img, transformation_matrix):
bgr = [b, g, r]
corrected_img = cv2.merge(bgr)

# return values of the image to the 0-255 range
corrected_img = 255*np.clip(corrected_img, 0, 1)
corrected_img = np.floor(corrected_img)
# cast back to unsigned int
corrected_img = corrected_img.astype(np.uint8)
# return values of the image to the original range
corrected_img = max_val*np.clip(corrected_img, 0, 1)
# cast back to original dtype (if uint the value defaults to the closest smaller integer)
corrected_img = corrected_img.astype(source_dtype)

# For debugging, create a horizontal view of source_img, corrected_img, and target_img to the plotting device
# plot horizontal comparison of source_img, corrected_img (with rounded elements) and target_img
# cast source_img back to unsigned int between 0-255 for visualization
source_flt = (255*source_flt).astype(np.uint8)
out_img = np.hstack([source_img, corrected_img, target_img])
# Change range of visualization image to 0-255 and convert to uin8
out_img = ((255.0/max_val)*out_img).astype(np.uint8)
_debug(visual=out_img, filename=os.path.join(params.debug_outdir, str(params.device) + '_corrected.png'))

# return corrected_img
Expand Down Expand Up @@ -416,7 +426,7 @@ def quick_color_check(target_matrix, source_matrix, num_chips):
scale_y_continuous, scale_color_manual, aes
import pandas as pd

# Scale matrices back to 0-255
# Scale matrices to 0-255
target_matrix = 255*target_matrix
source_matrix = 255*source_matrix

Expand Down

0 comments on commit c5f957a

Please sign in to comment.