-
Notifications
You must be signed in to change notification settings - Fork 10
/
colorTransferCV2.py
127 lines (102 loc) · 3.08 KB
/
colorTransferCV2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
Normalize the stain colours of an image patch to those of a target image using the method of:
E. Reinhard, M. Ashikhmin, B. Gooch, and P. Shirley, ‘Color transfer between images’, IEEE Computer Graphics and Applications, vol. 21, no. 5, pp. 34–41, Sep. 2001.
"""
from __future__ import division
import cv2 as cv
import numpy as np
### Helper functions ###
def lab_split(I):
    """
    Convert an RGB uint8 image to LAB and return its channels as float32 arrays.

    OpenCV's 8-bit LAB encoding stores L scaled to [0, 255] and a/b offset by
    +128; this undoes that so L lies in [0, 100] and a/b are centred on zero.

    :param I: RGB image, uint8.
    :return: tuple ``(L, a, b)`` of float32 single-channel arrays.
    """
    lab = cv.cvtColor(I, cv.COLOR_RGB2LAB).astype(np.float32)
    lightness, chan_a, chan_b = cv.split(lab)
    # Undo the 8-bit encoding: L * 255/100 and a/b + 128.
    lightness /= 2.55
    chan_a -= 128.0
    chan_b -= 128.0
    return lightness, chan_a, chan_b
def merge_back(I1, I2, I3):
    """
    Merge separate LAB channels back into an RGB uint8 image.

    This is the inverse of ``lab_split``: L is rescaled from [0, 100] to
    OpenCV's 8-bit range and a/b are shifted back by +128. The input arrays
    are NOT modified; the rescaling is done on new arrays.

    :param I1: L channel, nominally in [0, 100].
    :param I2: a channel, centred on zero.
    :param I3: b channel, centred on zero.
    :return: RGB image, uint8.
    """
    lab = cv.merge((I1 * 2.55, I2 + 128.0, I3 + 128.0))
    # Round before casting: a bare astype truncates, so e.g. L=100 in float64
    # (100 * 2.55 == 254.99999999999997) would become 254 instead of 255 and
    # every value would be biased downwards.
    lab = np.clip(np.rint(lab), 0, 255).astype(np.uint8)
    return cv.cvtColor(lab, cv.COLOR_LAB2RGB)
def fix_L_img(I):
    """
    Rescale a lightness (L) channel from [0, 100] back to uint8 [0, 255].

    Applies the same L scaling as ``merge_back`` (multiply by 2.55), but
    returns only the L channel with a trailing channel axis added, not an
    RGB image. The input array is not modified.

    :param I: L channel as a float array, nominally in [0, 100].
    :return: uint8 array of shape ``I.shape + (1,)``.
    """
    # Round before casting: a bare astype truncates, so L=100 in float64
    # (100 * 2.55 == 254.99999999999997) would otherwise map to 254.
    scaled = np.clip(np.rint(I * 2.55), 0, 255).astype(np.uint8)
    return scaled[..., None]
def get_mean_std(I):
    """
    Compute the per-channel mean and standard deviation of an image in LAB space.

    :param I: RGB image, uint8.
    :return: ``(means, stds)``, each a 3-tuple ordered (L, a, b) holding the
        values returned by ``cv.meanStdDev`` for that channel.
    """
    per_channel = [cv.meanStdDev(channel) for channel in lab_split(I)]
    means = tuple(mean for mean, _ in per_channel)
    stds = tuple(std for _, std in per_channel)
    return means, stds
def standardize_brightness(I, percentile=90):
    """
    Rescale intensities so that the given percentile maps to 255.

    Values are scaled by ``255 / p`` (``p`` = the percentile of all values in
    ``I``), then clipped to [0, 255] and cast (truncated) to uint8.

    :param I: image array (typically RGB uint8).
    :param percentile: percentile of the values in ``I`` that is mapped to
        255. Defaults to 90.
    :return: uint8 array with the same shape as ``I``.
    """
    p = np.percentile(I, percentile)
    if p <= 0:
        # No positive brightness reference to scale by (e.g. a mostly black
        # image). Dividing would produce inf/NaN, so return the image clipped
        # to uint8 unchanged.
        return np.clip(I, 0, 255).astype(np.uint8)
    return np.clip(I * 255.0 / p, 0, 255).astype(np.uint8)
### Main classes ###
class StainNormalizerLAB(object):
    """
    Reinhard-style stain normalizer operating on all three LAB channels.

    ``fit`` records the per-channel LAB mean and standard deviation of a
    (brightness-standardized) target image. Calling the object maps each LAB
    channel of an input image onto those statistics and returns the result
    as an RGB uint8 image.
    """

    def __init__(self):
        # Per-channel (L, a, b) statistics of the target; set by fit().
        self.target_means = None
        self.target_stds = None

    def fit(self, target):
        """
        Record the LAB statistics of ``target`` after brightness standardization.

        :param target: RGB image, uint8.
        """
        target = standardize_brightness(target)
        means, stds = get_mean_std(target)
        self.target_means = means
        self.target_stds = stds

    def __call__(self, I):
        """
        Normalize ``I`` to the fitted target statistics.

        :param I: RGB image, uint8.
        :return: normalized RGB image, uint8.
        """
        I = standardize_brightness(I)
        channels = lab_split(I)
        normalized = []
        for k, channel in enumerate(channels):
            # Same statistics get_mean_std() would compute, but reusing the
            # split above instead of converting to LAB a second time.
            mean, std = cv.meanStdDev(channel)
            std = std.item()
            # A constant channel (e.g. a uniform patch) has std 0; dividing
            # would give inf and then 0 * inf = NaN. Any finite scale maps a
            # constant channel onto the target mean, so use 0.
            scale = self.target_stds[k] / std if std > 0 else 0.0
            normalized.append((channel - mean) * scale + self.target_means[k])
        return np.asarray(merge_back(*normalized))
class StainNormalizerL(object):
    """
    Reinhard-style normalizer that adjusts only the LAB lightness (L) channel.

    ``fit`` records LAB statistics of a (brightness-standardized) target
    image. All three channels are stored, but only L is used. Calling the
    object returns the normalized L channel as a uint8 array with a trailing
    channel axis, not an RGB image.
    """

    def __init__(self):
        # Per-channel (L, a, b) statistics of the target; set by fit().
        self.target_means = None
        self.target_stds = None

    def fit(self, target):
        """
        Record the LAB statistics of ``target`` after brightness standardization.

        :param target: RGB image, uint8.
        """
        target = standardize_brightness(target)
        means, stds = get_mean_std(target)
        self.target_means = means
        self.target_stds = stds

    def __call__(self, I):
        """
        Normalize the lightness of ``I`` to the fitted target statistics.

        :param I: RGB image, uint8.
        :return: normalized L channel, uint8, with a trailing channel axis.
        """
        I = standardize_brightness(I)
        # Only L is needed; compute its statistics from this single split
        # rather than converting to LAB again for all three channels.
        lightness = lab_split(I)[0]
        mean, std = cv.meanStdDev(lightness)
        std = std.item()
        # A constant L channel has std 0; dividing would give inf and then
        # 0 * inf = NaN. Any finite scale maps it onto the target mean.
        scale = self.target_stds[0] / std if std > 0 else 0.0
        norm = (lightness - mean) * scale + self.target_means[0]
        return fix_L_img(norm)