/
test.py
88 lines (69 loc) · 3.44 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from input_data import *
from model import *
import cv2 as cv
import tensorflow as tf
import os
import numpy as np
def main():
    """Evaluate every saved checkpoint on the test set and report the best one.

    For each checkpoint found under CHECKPOINT_PATH: restores its weights,
    averages the testing MSE over a fixed number of batches, and — whenever a
    new best (lowest) MSE is seen — writes side-by-side sample patches
    (bicubic input | network inference | ground truth) to INFERENCE_SAVE_PATH.
    """
    ckpt_state = tf.train.get_checkpoint_state(CHECKPOINT_PATH)
    if not ckpt_state or not ckpt_state.model_checkpoint_path:
        print('No check point files are found!')
        return
    ckpt_files = ckpt_state.all_model_checkpoint_paths
    num_ckpt = len(ckpt_files)
    if num_ckpt < 1:
        print('No check point files are found!')
        return

    # Build the evaluation graph once; every checkpoint is restored into it.
    low_res_holder = tf.placeholder(tf.float32, shape=[BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, NUM_CHANNELS])
    high_res_holder = tf.placeholder(tf.float32, shape=[BATCH_SIZE, LABEL_SIZE, LABEL_SIZE, NUM_CHANNELS])
    inferences = create_model(low_res_holder)
    testing_loss = s_mse_loss(inferences, high_res_holder, name='testing_loss')
    low_res_batch, high_res_batch = generate_test_queue(TEST_PATH)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    saver = tf.train.Saver(tf.global_variables())
    tf.train.start_queue_runners(sess=sess)

    # Hoisted out of the nested loops: create the output directory once.
    if not os.path.exists(INFERENCE_SAVE_PATH):
        os.mkdir(INFERENCE_SAVE_PATH)

    num_eval_batches = 50  # batches averaged per checkpoint (was a magic 50)
    cnt = 0
    best_mse = float('inf')
    best_ckpt = ''
    for ckpt_file in ckpt_files:
        # BUG FIX: restore the checkpoint BEFORE measuring its loss. The
        # original code only restored inside the `if mse < best_mse` branch,
        # so the first MSE was computed on freshly initialized random weights
        # and later MSEs on the previously restored model — the comparison
        # never reflected the checkpoint being scored.
        saver.restore(sess, ckpt_file)
        mse = 0.0
        for _ in range(num_eval_batches):
            low_res_images, high_res_images = sess.run([low_res_batch, high_res_batch])
            feed_dict = {low_res_holder: low_res_images, high_res_holder: high_res_images}
            mse += sess.run(testing_loss, feed_dict=feed_dict)
        mse /= num_eval_batches
        print('Model: %s. MSE: %.3f' % (ckpt_file, mse))
        if mse < best_mse:
            best_mse = mse
            best_ckpt = ckpt_file
            print('=========================================================')
            print('=========================================================')
            print('Using models of ' + ckpt_file + ' to generate some patches.')
            # Weights are already restored for this checkpoint (see above).
            for k in range(4):
                low_res_images, high_res_images = sess.run([low_res_batch, high_res_batch])
                feed_dict = {low_res_holder: low_res_images, high_res_holder: high_res_images}
                inference_patches = sess.run(inferences, feed_dict=feed_dict)
                for i in range(BATCH_SIZE):
                    low_res_input = low_res_images[i, ...]   # INPUT_SIZE x INPUT_SIZE
                    ground_truth = high_res_images[i, ...]   # LABEL_SIZE x LABEL_SIZE
                    inference = inference_patches[i, ...]
                    # Center-crop label and upscaled input to the inference size
                    # (the network output is smaller due to 'valid' convolutions).
                    crop_begin = (ground_truth.shape[0] - inference.shape[0]) // 2
                    crop_end = crop_begin + inference.shape[0]
                    ground_truth = ground_truth[crop_begin: crop_end, crop_begin: crop_end, ...]
                    low_res_input = cv.resize(low_res_input, (LABEL_SIZE, LABEL_SIZE), interpolation=cv.INTER_CUBIC)
                    low_res_input = low_res_input[crop_begin: crop_end, crop_begin: crop_end, ...]
                    # Side-by-side comparison: bicubic | inference | ground truth.
                    patch_pair = np.hstack((low_res_input, inference, ground_truth))
                    patch_pair = tf.image.convert_image_dtype(patch_pair, tf.uint8, True)
                    save_name = 'inference_%d_%d_%d.png' % (k, i, cnt)
                    # Explicit os.path.join (the bare `join` relied on a
                    # wildcard import to be in scope).
                    cv.imwrite(os.path.join(INFERENCE_SAVE_PATH, save_name), patch_pair.eval(session=sess))
            # Offset the filename counter so patches from different best
            # models never collide.
            cnt = cnt + 1000
    print('Test Finished!')
    print('Best model: %s. MSE: %.3f' % (best_ckpt, best_mse))
# Run the checkpoint evaluation only when executed as a script.
if __name__ == '__main__':
    main()