scripts adjusted to run under newer versions of tensorflow #201

Open: wants to merge 1 commit into master
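The change is mechanical and consistent across both files: each script now calls `tf.compat.v1.disable_v2_behavior()` once at import time, and every TF1 symbol that TF 2.x removed from the top-level namespace is reached through the `tf.compat.v1` shim instead. A minimal sketch of the pattern, assuming a TF 2.x install (the shapes and layer below are illustrative, not taken from the PR):

```python
import tensorflow as tf

# Opt out of eager execution and the other 2.x defaults so the original
# graph/Session-style code keeps working under a TF 2.x install.
tf.compat.v1.disable_v2_behavior()

# Symbols dropped from the top-level namespace move under tf.compat.v1, e.g.
#   tf.placeholder    -> tf.compat.v1.placeholder
#   tf.layers.conv2d  -> tf.compat.v1.layers.conv2d
#   tf.Session        -> tf.compat.v1.Session
x = tf.compat.v1.placeholder(tf.float32, shape=[None, 256, 256, 3])
y = tf.compat.v1.layers.conv2d(x, filters=64, kernel_size=4, strides=2, padding="same")

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
```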
104 changes: 53 additions & 51 deletions pix2pix.py
@@ -13,6 +13,8 @@
import math
import time

+tf.compat.v1.disable_v2_behavior()
+
parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", help="path to folder containing images")
parser.add_argument("--mode", required=True, choices=["train", "test", "export"])
@@ -93,27 +95,27 @@ def augment(image, brightness):

def discrim_conv(batch_input, out_channels, stride):
padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT")
-return tf.layers.conv2d(padded_input, out_channels, kernel_size=4, strides=(stride, stride), padding="valid", kernel_initializer=tf.random_normal_initializer(0, 0.02))
+return tf.compat.v1.layers.conv2d(padded_input, out_channels, kernel_size=4, strides=(stride, stride), padding="valid", kernel_initializer=tf.random_normal_initializer(0, 0.02))


def gen_conv(batch_input, out_channels):
# [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
initializer = tf.random_normal_initializer(0, 0.02)
if a.separable_conv:
-return tf.layers.separable_conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer)
+return tf.compat.v1.layers.separable_conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer)
else:
-return tf.layers.conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer)
+return tf.compat.v1.layers.conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer)


def gen_deconv(batch_input, out_channels):
# [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
initializer = tf.random_normal_initializer(0, 0.02)
if a.separable_conv:
_b, h, w, _c = batch_input.shape
-resized_input = tf.image.resize_images(batch_input, [h * 2, w * 2], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
-return tf.layers.separable_conv2d(resized_input, out_channels, kernel_size=4, strides=(1, 1), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer)
+resized_input = tf.image.resize(batch_input, [h * 2, w * 2], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
+return tf.compat.v1.layers.separable_conv2d(resized_input, out_channels, kernel_size=4, strides=(1, 1), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer)
else:
-return tf.layers.conv2d_transpose(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer)
+return tf.compat.v1.layers.conv2d_transpose(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer)


def lrelu(x, a):
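One change in the hunk above is not a pure `compat.v1` rename: `tf.image.resize_images` in `gen_deconv` becomes `tf.image.resize`, the TF 2.x replacement for the removed v1 function. A small sketch of the renamed call, with made-up shapes:

```python
import tensorflow as tf

# tf.image.resize replaces the removed tf.image.resize_images; this mirrors
# the nearest-neighbor upscaling in gen_deconv (shapes are made up here).
x = tf.zeros([1, 128, 128, 64])
up = tf.image.resize(x, [256, 256], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
```

The v2 resize has slightly different dtype and argument defaults than the v1 function (it drops `align_corners`, for instance), but for float tensors with an explicit method, as here, the behavior should match.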
@@ -129,7 +131,7 @@ def lrelu(x, a):


def batchnorm(inputs):
-return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True, gamma_initializer=tf.random_normal_initializer(1.0, 0.02))
+return tf.compat.v1.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True, gamma_initializer=tf.random_normal_initializer(1.0, 0.02))


def check_image(image):
@@ -255,8 +257,8 @@ def get_name(path):
input_paths = sorted(input_paths)

with tf.name_scope("load_images"):
-path_queue = tf.train.string_input_producer(input_paths, shuffle=a.mode == "train")
-reader = tf.WholeFileReader()
+path_queue = tf.compat.v1.train.string_input_producer(input_paths, shuffle=a.mode == "train")
+reader = tf.compat.v1.WholeFileReader()
paths, contents = reader.read(path_queue)
raw_input = decode(contents)
raw_input = tf.image.convert_image_dtype(raw_input, dtype=tf.float32)
@@ -296,9 +298,9 @@ def transform(image):

# area produces a nice downscaling, but does nearest neighbor for upscaling
# assume we're going to be doing downscaling here
-r = tf.image.resize_images(r, [a.scale_size, a.scale_size], method=tf.image.ResizeMethod.AREA)
+r = tf.image.resize(r, [a.scale_size, a.scale_size], method=tf.image.ResizeMethod.AREA)

-offset = tf.cast(tf.floor(tf.random_uniform([2], 0, a.scale_size - CROP_SIZE + 1, seed=seed)), dtype=tf.int32)
+offset = tf.cast(tf.floor(tf.compat.v1.random_uniform([2], 0, a.scale_size - CROP_SIZE + 1, seed=seed)), dtype=tf.int32)
if a.scale_size > CROP_SIZE:
r = tf.image.crop_to_bounding_box(r, offset[0], offset[1], CROP_SIZE, CROP_SIZE)
elif a.scale_size < CROP_SIZE:
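Same story in the hunk above: `tf.random_uniform` no longer exists at the top level, so it becomes `tf.compat.v1.random_uniform`; in native TF 2.x code the spelling would be `tf.random.uniform`. A sketch of the two spellings with a made-up maxval:

```python
import tensorflow as tf

# Only the namespace differs; both draw uniform samples in [0, 10).
a = tf.compat.v1.random_uniform([2], 0, 10, seed=1)  # v1-compat spelling, as in this PR
b = tf.random.uniform([2], 0, 10, seed=1)            # native TF 2.x spelling
```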
@@ -311,7 +313,7 @@ def transform(image):
with tf.name_scope("target_images"):
target_images = transform(targets)

-paths_batch, inputs_batch, targets_batch = tf.train.batch([paths, input_images, target_images], batch_size=a.batch_size)
+paths_batch, inputs_batch, targets_batch = tf.compat.v1.train.batch([paths, input_images, target_images], batch_size=a.batch_size)
steps_per_epoch = int(math.ceil(len(input_paths) / a.batch_size))

return Examples(
@@ -327,7 +329,7 @@ def create_generator(generator_inputs, generator_outputs_channels):
layers = []

# encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
with tf.variable_scope("encoder_1"):
with tf.compat.v1.variable_scope("encoder_1"):
output = gen_conv(generator_inputs, a.ngf)
layers.append(output)

@@ -342,7 +344,7 @@ def create_generator(generator_inputs, generator_outputs_channels):
]

for out_channels in layer_specs:
with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
with tf.compat.v1.variable_scope("encoder_%d" % (len(layers) + 1)):
rectified = lrelu(layers[-1], 0.2)
# [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
convolved = gen_conv(rectified, out_channels)
@@ -362,7 +364,7 @@ def create_generator(generator_inputs, generator_outputs_channels):
num_encoder_layers = len(layers)
for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
skip_layer = num_encoder_layers - decoder_layer - 1
with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
with tf.compat.v1.variable_scope("decoder_%d" % (skip_layer + 1)):
if decoder_layer == 0:
# first decoder layer doesn't have skip connections
# since it is directly connected to the skip_layer
@@ -376,12 +378,12 @@ def create_generator(generator_inputs, generator_outputs_channels):
output = batchnorm(output)

if dropout > 0.0:
-output = tf.nn.dropout(output, keep_prob=1 - dropout)
+output = tf.compat.v1.nn.dropout(output, keep_prob=1 - dropout)

layers.append(output)

# decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
with tf.variable_scope("decoder_1"):
with tf.compat.v1.variable_scope("decoder_1"):
input = tf.concat([layers[-1], layers[0]], axis=3)
rectified = tf.nn.relu(input)
output = gen_deconv(rectified, generator_outputs_channels)
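The dropout line above deserves a note: TF 2.x's `tf.nn.dropout` replaced `keep_prob` with `rate`, so the PR routes the call through `tf.compat.v1.nn.dropout`, which still accepts `keep_prob`. A sketch of the equivalence on a toy tensor (not from the PR):

```python
import tensorflow as tf

x = tf.ones([4, 8])
dropout = 0.5

# The compat call keeps the old keep_prob argument...
y_v1 = tf.compat.v1.nn.dropout(x, keep_prob=1 - dropout)
# ...while the native 2.x call expresses the same thing as rate = 1 - keep_prob.
y_v2 = tf.nn.dropout(x, rate=dropout)
```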
@@ -400,7 +402,7 @@ def create_discriminator(discrim_inputs, discrim_targets):
input = tf.concat([discrim_inputs, discrim_targets], axis=3)

# layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
with tf.variable_scope("layer_1"):
with tf.compat.v1.variable_scope("layer_1"):
convolved = discrim_conv(input, a.ndf, stride=2)
rectified = lrelu(convolved, 0.2)
layers.append(rectified)
@@ -409,7 +411,7 @@ def create_discriminator(discrim_inputs, discrim_targets):
# layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
# layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
for i in range(n_layers):
with tf.variable_scope("layer_%d" % (len(layers) + 1)):
with tf.compat.v1.variable_scope("layer_%d" % (len(layers) + 1)):
out_channels = a.ndf * min(2**(i+1), 8)
stride = 1 if i == n_layers - 1 else 2 # last layer here has stride 1
convolved = discrim_conv(layers[-1], out_channels, stride=stride)
@@ -418,60 +420,60 @@ def create_discriminator(discrim_inputs, discrim_targets):
layers.append(rectified)

# layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
with tf.variable_scope("layer_%d" % (len(layers) + 1)):
with tf.compat.v1.variable_scope("layer_%d" % (len(layers) + 1)):
convolved = discrim_conv(rectified, out_channels=1, stride=1)
output = tf.sigmoid(convolved)
layers.append(output)

return layers[-1]

with tf.variable_scope("generator"):
with tf.compat.v1.variable_scope("generator"):
out_channels = int(targets.get_shape()[-1])
outputs = create_generator(inputs, out_channels)

# create two copies of discriminator, one for real pairs and one for fake pairs
# they share the same underlying variables
with tf.name_scope("real_discriminator"):
with tf.variable_scope("discriminator"):
with tf.compat.v1.variable_scope("discriminator"):
# 2x [batch, height, width, channels] => [batch, 30, 30, 1]
predict_real = create_discriminator(inputs, targets)

with tf.name_scope("fake_discriminator"):
with tf.variable_scope("discriminator", reuse=True):
with tf.compat.v1.variable_scope("discriminator", reuse=True):
# 2x [batch, height, width, channels] => [batch, 30, 30, 1]
predict_fake = create_discriminator(inputs, outputs)

with tf.name_scope("discriminator_loss"):
-# minimizing -tf.log will try to get inputs to 1
+# minimizing -tf.compat.v1.log will try to get inputs to 1
# predict_real => 1
# predict_fake => 0
-discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))
+discrim_loss = tf.reduce_mean(-(tf.compat.v1.log(predict_real + EPS) + tf.compat.v1.log(1 - predict_fake + EPS)))

with tf.name_scope("generator_loss"):
# predict_fake => 1
# abs(targets - outputs) => 0
-gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
+gen_loss_GAN = tf.reduce_mean(-tf.compat.v1.log(predict_fake + EPS))
gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))
gen_loss = gen_loss_GAN * a.gan_weight + gen_loss_L1 * a.l1_weight

with tf.name_scope("discriminator_train"):
-discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
-discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
+discrim_tvars = [var for var in tf.compat.v1.trainable_variables() if var.name.startswith("discriminator")]
+discrim_optim = tf.compat.v1.train.AdamOptimizer(a.lr, a.beta1)
discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)

with tf.name_scope("generator_train"):
with tf.control_dependencies([discrim_train]):
-gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
-gen_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
+gen_tvars = [var for var in tf.compat.v1.trainable_variables() if var.name.startswith("generator")]
+gen_optim = tf.compat.v1.train.AdamOptimizer(a.lr, a.beta1)
gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
gen_train = gen_optim.apply_gradients(gen_grads_and_vars)

ema = tf.train.ExponentialMovingAverage(decay=0.99)
update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1])

-global_step = tf.train.get_or_create_global_step()
-incr_global_step = tf.assign(global_step, global_step+1)
+global_step = tf.compat.v1.train.get_or_create_global_step()
+incr_global_step = tf.compat.v1.assign(global_step, global_step+1)

return Model(
predict_real=predict_real,
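The training hunk above stays fully in graph mode (`tf.compat.v1.train.AdamOptimizer` with `compute_gradients`/`apply_gradients`), which is consistent with the rest of the Session-based script. Purely for orientation, a rough sketch of the same update in idiomatic TF 2.x; this is not part of the PR, the variable and loss are toys, and 0.0002/0.5 are assumed to be the repo's `--lr`/`--beta1` defaults:

```python
import tensorflow as tf

opt = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
w = tf.Variable(1.0)

def train_step():
    # Eager-mode equivalent of compute_gradients + apply_gradients.
    with tf.GradientTape() as tape:
        loss = tf.square(w - 3.0)  # toy loss standing in for gen_loss
    grads = tape.gradient(loss, [w])
    opt.apply_gradients(zip(grads, [w]))
    return loss
```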
@@ -537,7 +539,7 @@ def main():
if a.seed is None:
a.seed = random.randint(0, 2**31 - 1)

-tf.set_random_seed(a.seed)
+tf.compat.v1.set_random_seed(a.seed)
np.random.seed(a.seed)
random.seed(a.seed)

@@ -570,8 +572,8 @@ def main():
if a.lab_colorization:
raise Exception("export not supported for lab_colorization")

-input = tf.placeholder(tf.string, shape=[1])
-input_data = tf.decode_base64(input[0])
+input = tf.compat.v1.placeholder(tf.string, shape=[1])
+input_data = tf.compat.v1.decode_base64(input[0])
input_image = tf.image.decode_png(input_data)

# remove alpha channel if present
@@ -583,7 +585,7 @@ def main():
input_image.set_shape([CROP_SIZE, CROP_SIZE, 3])
batch_input = tf.expand_dims(input_image, axis=0)

with tf.variable_scope("generator"):
with tf.compat.v1.variable_scope("generator"):
batch_output = deprocess(create_generator(preprocess(batch_input), 3))

output_image = tf.image.convert_image_dtype(batch_output, dtype=tf.uint8)[0]
@@ -593,25 +595,25 @@ def main():
output_data = tf.image.encode_jpeg(output_image, quality=80)
else:
raise Exception("invalid filetype")
-output = tf.convert_to_tensor([tf.encode_base64(output_data)])
+output = tf.convert_to_tensor([tf.compat.v1.encode_base64(output_data)])

-key = tf.placeholder(tf.string, shape=[1])
+key = tf.compat.v1.placeholder(tf.string, shape=[1])
inputs = {
"key": key.name,
"input": input.name
}
tf.add_to_collection("inputs", json.dumps(inputs))
tf.compat.v1.add_to_collection("inputs", json.dumps(inputs))
outputs = {
"key": tf.identity(key).name,
"output": output.name,
}
tf.add_to_collection("outputs", json.dumps(outputs))
tf.compat.v1.add_to_collection("outputs", json.dumps(outputs))

-init_op = tf.global_variables_initializer()
-restore_saver = tf.train.Saver()
-export_saver = tf.train.Saver()
+init_op = tf.compat.v1.global_variables_initializer()
+restore_saver = tf.compat.v1.train.Saver()
+export_saver = tf.compat.v1.train.Saver()

-with tf.Session() as sess:
+with tf.compat.v1.Session() as sess:
sess.run(init_op)
print("loading model from checkpoint")
checkpoint = tf.train.latest_checkpoint(a.checkpoint)
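The export path above is the part that genuinely needs `disable_v2_behavior()`: placeholders and `Session.run` only exist in graph mode, so the `compat.v1` renames alone would not be enough. A self-contained sketch of the same pattern on a toy input (not from the PR):

```python
import tensorflow as tf

tf.compat.v1.disable_v2_behavior()  # placeholders require graph mode

inp = tf.compat.v1.placeholder(tf.string, shape=[1])
decoded = tf.compat.v1.decode_base64(inp[0])

with tf.compat.v1.Session() as sess:
    # "aGVsbG8" is web-safe base64 for "hello"
    print(sess.run(decoded, feed_dict={inp: ["aGVsbG8"]}))  # b'hello'
```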
@@ -654,7 +656,7 @@ def convert(image):
if a.aspect_ratio != 1.0:
# upscale to correct aspect ratio
size = [CROP_SIZE, int(round(CROP_SIZE * a.aspect_ratio))]
-image = tf.image.resize_images(image, size=size, method=tf.image.ResizeMethod.BICUBIC)
+image = tf.image.resize(image, size=size, method=tf.image.ResizeMethod.BICUBIC)

return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True)

@@ -696,19 +698,19 @@ def convert(image):
tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)

-for var in tf.trainable_variables():
+for var in tf.compat.v1.trainable_variables():
tf.summary.histogram(var.op.name + "/values", var)

for grad, var in model.discrim_grads_and_vars + model.gen_grads_and_vars:
tf.summary.histogram(var.op.name + "/gradients", grad)

with tf.name_scope("parameter_count"):
-parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])
+parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.compat.v1.trainable_variables()])

-saver = tf.train.Saver(max_to_keep=1)
+saver = tf.compat.v1.train.Saver(max_to_keep=1)

logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None
-sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None)
+sv = tf.compat.v1.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None)
with sv.managed_session() as sess:
print("parameter_count =", sess.run(parameter_count))

@@ -747,8 +749,8 @@ def should(freq):
options = None
run_metadata = None
if should(a.trace_freq):
-options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
-run_metadata = tf.RunMetadata()
+options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
+run_metadata = tf.compat.v1.RunMetadata()

fetches = {
"train": model.train,
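One maintenance caveat: `string_input_producer`, `train.batch`, and `Supervisor` are deprecated even under `compat.v1` and only work with v2 behavior disabled, so this PR keeps them alive rather than modernizing them. The TF 2.x-native replacement would be a `tf.data` pipeline, roughly like the hypothetical sketch below (`input_paths` and `batch_size` are stand-ins for the script's values; this is not part of the PR):

```python
import tensorflow as tf

input_paths = ["a.png", "b.png"]  # stand-in for the script's glob results
batch_size = 1                    # stand-in for a.batch_size

def load(path):
    contents = tf.io.read_file(path)
    image = tf.image.convert_image_dtype(tf.io.decode_png(contents), tf.float32)
    return path, image

# Hypothetical tf.data equivalent of string_input_producer + train.batch.
dataset = (
    tf.data.Dataset.from_tensor_slices(input_paths)
    .shuffle(len(input_paths))
    .map(load, num_parallel_calls=tf.data.AUTOTUNE)  # tf.data.experimental.AUTOTUNE on older 2.x
    .batch(batch_size)
)
```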
10 changes: 6 additions & 4 deletions tools/process.py
@@ -14,6 +14,8 @@
import time
import multiprocessing

+tf.compat.v1.disable_v2_behavior()
+
edge_pool = None


@@ -265,12 +267,12 @@ def main():
edge_pool = multiprocessing.Pool(a.workers)

if a.workers == 1:
-with tf.Session() as sess:
+with tf.compat.v1.Session() as sess:
for src_path, dst_path in zip(src_paths, dst_paths):
process(src_path, dst_path)
complete()
else:
-queue = tf.train.input_producer(zip(src_paths, dst_paths), shuffle=False, num_epochs=1)
+queue = tf.compat.v1.train.input_producer(zip(src_paths, dst_paths), shuffle=False, num_epochs=1)
dequeue_op = queue.dequeue()

def worker(coord):
@@ -286,8 +288,8 @@ def worker(coord):
complete()

# init epoch counter for the queue
-local_init_op = tf.local_variables_initializer()
-with tf.Session() as sess:
+local_init_op = tf.compat.v1.local_variables_initializer()
+with tf.compat.v1.Session() as sess:
sess.run(local_init_op)

coord = tf.train.Coordinator()
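Finally, a quick way to confirm the shim is active before running either script under a TF 2.x install (a hypothetical smoke test, not part of the PR):

```python
import tensorflow as tf

print(tf.__version__)          # expect something like 2.x.y
tf.compat.v1.disable_v2_behavior()
print(tf.executing_eagerly())  # should now print False
```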