Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add options to train and export TFLite compatible models #157

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
124 changes: 92 additions & 32 deletions pix2pix.py
Expand Up @@ -19,6 +19,7 @@
parser.add_argument("--output_dir", required=True, help="where to put output files")
parser.add_argument("--seed", type=int)
parser.add_argument("--checkpoint", default=None, help="directory with checkpoint to resume training from or use for testing")
parser.add_argument("--export_format", default="tf", choices=["tf", "tflite"])

parser.add_argument("--max_steps", type=int, help="number of training steps (0 to disable)")
parser.add_argument("--max_epochs", type=int, help="number of training epochs")
Expand All @@ -43,6 +44,7 @@
parser.add_argument("--beta1", type=float, default=0.5, help="momentum term of adam")
parser.add_argument("--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient")
parser.add_argument("--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient")
parser.add_argument("--norm_type", default="original", choices=["none", "original", "tflite_compatible"])

# export options
parser.add_argument("--output_filetype", default="png", choices=["png", "jpeg"])
Expand Down Expand Up @@ -129,7 +131,12 @@ def lrelu(x, a):


def batchnorm(inputs):
return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True, gamma_initializer=tf.random_normal_initializer(1.0, 0.02))
if a.norm_type == "none":
return inputs
if a.batch_size == 1 and a.norm_type != "original":
return tf.contrib.layers.instance_norm(inputs, epsilon=1e-5, param_initializers={'gamma': tf.random_normal_initializer(1.0, 0.02)})
return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=(a.mode == "train" or a.norm_type == "original"),
gamma_initializer=tf.random_normal_initializer(1.0, 0.02))


def check_image(image):
Expand Down Expand Up @@ -385,7 +392,10 @@ def create_generator(generator_inputs, generator_outputs_channels):
input = tf.concat([layers[-1], layers[0]], axis=3)
rectified = tf.nn.relu(input)
output = gen_deconv(rectified, generator_outputs_channels)
output = tf.tanh(output)
if a.export_format == "tflite":
output = 2 * tf.sigmoid(2 * output) - 1
else:
output = tf.tanh(output)
layers.append(output)

return layers[-1]
Expand Down Expand Up @@ -454,14 +464,19 @@ def create_discriminator(discrim_inputs, discrim_targets):
gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))
gen_loss = gen_loss_GAN * a.gan_weight + gen_loss_L1 * a.l1_weight

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
gen_update_ops = [var for var in update_ops if var.name.startswith("generator")]
discrim_update_ops = [var for var in update_ops if var.name.startswith("discriminator")]

with tf.name_scope("discriminator_train"):
discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)
with tf.control_dependencies(discrim_update_ops):
discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)

with tf.name_scope("generator_train"):
with tf.control_dependencies([discrim_train]):
with tf.control_dependencies(gen_update_ops + [discrim_train]):
gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
gen_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
Expand Down Expand Up @@ -570,41 +585,51 @@ def main():
if a.lab_colorization:
raise Exception("export not supported for lab_colorization")

input = tf.placeholder(tf.string, shape=[1])
input_data = tf.decode_base64(input[0])
input_image = tf.image.decode_png(input_data)
if a.export_format == "tflite":
input = tf.placeholder(tf.float32, shape=[1,256,256,3], name='TFLiteInput')
input_image = tf.reshape(input, [256,256,3])

else:
input = tf.placeholder(tf.string, shape=[1])
input_data = tf.decode_base64(input[0])
input_image = tf.image.decode_png(input_data)

# remove alpha channel if present
input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 4), lambda: input_image[:,:,:3], lambda: input_image)
# convert grayscale to RGB
input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 1), lambda: tf.image.grayscale_to_rgb(input_image), lambda: input_image)

# remove alpha channel if present
input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 4), lambda: input_image[:,:,:3], lambda: input_image)
# convert grayscale to RGB
input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 1), lambda: tf.image.grayscale_to_rgb(input_image), lambda: input_image)
input_image = tf.image.convert_image_dtype(input_image, dtype=tf.float32)
input_image.set_shape([CROP_SIZE, CROP_SIZE, 3])

input_image = tf.image.convert_image_dtype(input_image, dtype=tf.float32)
input_image.set_shape([CROP_SIZE, CROP_SIZE, 3])
batch_input = tf.expand_dims(input_image, axis=0)

with tf.variable_scope("generator"):
batch_output = deprocess(create_generator(preprocess(batch_input), 3))

output_image = tf.image.convert_image_dtype(batch_output, dtype=tf.uint8)[0]
if a.output_filetype == "png":
output_data = tf.image.encode_png(output_image)
elif a.output_filetype == "jpeg":
output_data = tf.image.encode_jpeg(output_image, quality=80)
inputs = {}
outputs = {}

if a.export_format == "tflite":
output = tf.identity(batch_output, 'TFLiteOutput')

else:
raise Exception("invalid filetype")
output = tf.convert_to_tensor([tf.encode_base64(output_data)])
output_image = tf.image.convert_image_dtype(batch_output, dtype=tf.uint8)[0]
if a.output_filetype == "png":
output_data = tf.image.encode_png(output_image)
elif a.output_filetype == "jpeg":
output_data = tf.image.encode_jpeg(output_image, quality=80)
else:
raise Exception("invalid filetype")
output = tf.convert_to_tensor([tf.encode_base64(output_data)])

key = tf.placeholder(tf.string, shape=[1])
inputs = {
"key": key.name,
"input": input.name
}
key = tf.placeholder(tf.string, shape=[1])
# No trailing comma here: it would store a 1-tuple, which json.dumps serialises as a list rather than the tensor name string.
inputs["key"] = key.name
outputs["key"] = tf.identity(key).name

inputs["input"] = input.name
# No trailing comma here: it would store a 1-tuple, which json.dumps serialises as a list rather than the tensor name string.
outputs["output"] = output.name
tf.add_to_collection("inputs", json.dumps(inputs))
outputs = {
"key": tf.identity(key).name,
"output": output.name,
}
tf.add_to_collection("outputs", json.dumps(outputs))

init_op = tf.global_variables_initializer()
Expand All @@ -620,6 +645,41 @@ def main():
export_saver.export_meta_graph(filename=os.path.join(a.output_dir, "export.meta"))
export_saver.save(sess, os.path.join(a.output_dir, "export"), write_meta_graph=False)

if a.export_format == "tflite":
from tensorflow.lite.python import lite
from tensorflow.core.framework import graph_pb2
from tensorflow.python.platform import gfile
from tensorflow.python.framework import dtypes
from tensorflow.python.tools.freeze_graph import freeze_graph
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference

print("freezing exported model")
freeze_graph(input_graph="", input_saver="", input_binary=True,
input_checkpoint=os.path.join(a.output_dir, "export"), output_node_names="TFLiteOutput",
restore_op_name="save/restore_all", filename_tensor_name="save/Const:0",
output_graph=os.path.join(a.output_dir, "frozen_graph.pb"), clear_devices=True,
initializer_nodes="", input_meta_graph=os.path.join(a.output_dir, "export.meta"))

print("optimizing frozen graph")
input_graph_def = graph_pb2.GraphDef()
with gfile.Open(os.path.join(a.output_dir, "frozen_graph.pb"), "rb") as f:
data = f.read()
input_graph_def.ParseFromString(data)

output_graph_def = optimize_for_inference(input_graph_def, ["TFLiteInput"], ["TFLiteOutput"],
dtypes.float32.as_datatype_enum, True)
with gfile.FastGFile(os.path.join(a.output_dir, "optimized_graph.pb"), "w") as f:
f.write(output_graph_def.SerializeToString())

print("converting optimized graph to tflite")
converter = lite.TFLiteConverter.from_frozen_graph(input_arrays=["TFLiteInput"],
output_arrays=["TFLiteOutput"],
graph_def_file=os.path.join(a.output_dir, "optimized_graph.pb"))
converter.post_training_quantize = True
output_data = converter.convert()
with open(os.path.join(a.output_dir, "converted_model.tflite"), "wb") as f:
f.write(output_data)

return

examples = load_examples()
Expand Down
76 changes: 76 additions & 0 deletions tools/query-export.py
@@ -0,0 +1,76 @@
import os
import argparse
from pprint import pprint

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

def _load_images(paths, size=(256, 256)):
    """Read image files into a float32 batch scaled to [0, 1].

    Args:
        paths: iterable of image file paths.
        size: (width, height) every image is resized to.

    Returns:
        numpy array of shape (len(paths), height, width, 3), RGB channel order.

    Raises:
        IOError: if a file cannot be read or decoded.
    """
    frames = []
    for path in paths:
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise IOError("could not read input image: " + path)
        # OpenCV decodes to BGR. The model output is treated as RGB when it is
        # written back (see _write_image), so feed the model RGB as well.
        frames.append(cv2.resize(bgr, size)[:, :, ::-1])
    return np.array(frames, dtype=np.float32) / 255.0


def _write_image(path, rgb_batch):
    """Write the first image of an RGB batch in [0, 1] to *path* with OpenCV."""
    print("Writing " + path)
    # Flip RGB -> BGR for cv2.imwrite and rescale to the 0-255 range.
    cv2.imwrite(path, rgb_batch[0][:, :, ::-1] * 255)


def _run_tflite(model_path, images):
    """Run *images* through a .tflite model and return the raw output batch."""
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    print(input_details)
    print(output_details)

    interpreter.set_tensor(input_details[0]['index'], images)
    interpreter.invoke()
    return interpreter.get_tensor(output_details[0]['index'])


def _run_frozen(graph_path, images):
    """Run *images* through a frozen GraphDef file and return the output batch."""
    # A fresh graph per run keeps tensor names from clashing with any other
    # model loaded in the same process.
    with tf.Graph().as_default() as graph, tf.Session() as sess:
        with gfile.GFile(graph_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

        output = graph.get_tensor_by_name("TFLiteOutput:0")
        return sess.run(output, {'TFLiteInput:0': images})


def _run_export(export_dir, images):
    """Run *images* through an exported meta graph plus checkpoint directory.

    The graph must contain the TFLiteInput/TFLiteOutput tensors; presumably
    that means it was exported with --export_format tflite (confirm against
    the exporter).

    Raises:
        IOError: if no checkpoint is found in *export_dir*.
    """
    # A fresh graph per run keeps tensor names from clashing with any other
    # model loaded in the same process.
    with tf.Graph().as_default() as graph, tf.Session() as sess:
        with gfile.GFile(os.path.join(export_dir, 'export.meta'), 'rb') as f:
            meta_graph_def = tf.MetaGraphDef()
            meta_graph_def.ParseFromString(f.read())
        tf.train.import_meta_graph(meta_graph_def)

        checkpoint = tf.train.latest_checkpoint(export_dir)
        if checkpoint is None:
            raise IOError("no checkpoint found in " + export_dir)
        restore_saver = tf.train.Saver()
        restore_saver.restore(sess, checkpoint)

        output = graph.get_tensor_by_name("TFLiteOutput:0")
        return sess.run(output, {'TFLiteInput:0': images})


def main():
    """Query an exported model with one image and write the result.

    Each of --tflite, --frozen and --export that is given is run in that
    order, and each run writes its result to --output.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--export", default=None, help="output_dir from --mode export run")
    parser.add_argument("--frozen", default=None, help="frozen graph pb file")
    parser.add_argument("--tflite", default=None, help="tflite file")
    parser.add_argument("--input", default='Untitled.png', help="input image")
    parser.add_argument("--output", default='out.png', help="output image")
    a = parser.parse_args()

    images = _load_images([a.input])

    if a.tflite:
        _write_image(a.output, _run_tflite(a.tflite, images))

    if a.frozen:
        _write_image(a.output, _run_frozen(a.frozen, images))

    if a.export:
        _write_image(a.output, _run_export(a.export, images))

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    main()