From ceee992a6e0d114923a397020ac5f100269f3d2f Mon Sep 17 00:00:00 2001 From: Vadim Markovtsev Date: Tue, 7 Mar 2017 13:53:08 +0100 Subject: [PATCH] Swivel: add multiple GPU support This speeds up the training accordingly. --- swivel/swivel.py | 269 +++++++++++++++++++++++++++++------------------ 1 file changed, 166 insertions(+), 103 deletions(-) diff --git a/swivel/swivel.py b/swivel/swivel.py index 48b67a8d3dedfe..45311834c094f8 100755 --- a/swivel/swivel.py +++ b/swivel/swivel.py @@ -52,7 +52,6 @@ """ from __future__ import print_function -import argparse import glob import math import os @@ -62,6 +61,7 @@ import numpy as np import tensorflow as tf +from tensorflow.python.client import device_lib flags = tf.app.flags @@ -85,13 +85,26 @@ flags.DEFINE_float('learning_rate', 1.0, 'Initial learning rate') flags.DEFINE_integer('num_concurrent_steps', 2, 'Number of threads to train with') +flags.DEFINE_integer('num_readers', 4, + 'Number of threads to read the input data and feed it') flags.DEFINE_float('num_epochs', 40, 'Number epochs to train for') -flags.DEFINE_float('per_process_gpu_memory_fraction', 0.25, - 'Fraction of GPU memory to use') +flags.DEFINE_float('per_process_gpu_memory_fraction', 0, + 'Fraction of GPU memory to use, 0 means allow_growth') +flags.DEFINE_integer('num_gpus', 0, + 'Number of GPUs to use, 0 means all available') FLAGS = flags.FLAGS +def log(message, *args, **kwargs): + tf.logging.info(message, *args, **kwargs) + + +def get_available_gpus(): + return [d.name for d in device_lib.list_local_devices() + if d.device_type == 'GPU'] + + def embeddings_with_init(vocab_size, embedding_dim, name): """Creates and initializes the embedding tensors.""" return tf.get_variable(name=name, @@ -130,7 +143,7 @@ def count_matrix_input(filenames, submatrix_rows, submatrix_cols): queued_global_row, queued_global_col, queued_count = tf.train.batch( [global_row, global_col, count], batch_size=1, - num_threads=4, + num_threads=FLAGS.num_readers, capacity=32) queued_global_row = tf.reshape(queued_global_row, [submatrix_rows]) @@ -164,16 +177,14 @@ def write_embeddings_to_disk(config, model, sess): # Row Embedding row_vocab_path = config.input_base_path + '/row_vocab.txt' row_embedding_output_path = config.output_base_path + '/row_embedding.tsv' - print('Writing row embeddings to:', row_embedding_output_path) - sys.stdout.flush() + log('Writing row embeddings to: %s', row_embedding_output_path) write_embedding_tensor_to_disk(row_vocab_path, row_embedding_output_path, sess, model.row_embedding) # Column Embedding col_vocab_path = config.input_base_path + '/col_vocab.txt' col_embedding_output_path = config.output_base_path + '/col_embedding.tsv' - print('Writing column embeddings to:', col_embedding_output_path) - sys.stdout.flush() + log('Writing column embeddings to: %s', col_embedding_output_path) write_embedding_tensor_to_disk(col_vocab_path, col_embedding_output_path, sess, model.col_embedding) @@ -186,8 +197,7 @@ def __init__(self, config): self._config = config # Create paths to input data files - print('Reading model from:', config.input_base_path) - sys.stdout.flush() + log('Reading model from: %s', config.input_base_path) count_matrix_files = glob.glob(config.input_base_path + '/shard-*.pb') row_sums_path = config.input_base_path + '/row_sums.txt' col_sums_path = config.input_base_path + '/col_sums.txt' @@ -198,93 +208,129 @@ def __init__(self, config): self.n_rows = len(row_sums) self.n_cols = len(col_sums) - print('Matrix dim: (%d,%d) SubMatrix dim: (%d,%d) ' % ( - self.n_rows, 
self.n_cols, config.submatrix_rows, config.submatrix_cols)) - sys.stdout.flush() + log('Matrix dim: (%d,%d) SubMatrix dim: (%d,%d)', + self.n_rows, self.n_cols, config.submatrix_rows, config.submatrix_cols) self.n_submatrices = (self.n_rows * self.n_cols / (config.submatrix_rows * config.submatrix_cols)) - print('n_submatrices: %d' % (self.n_submatrices)) - sys.stdout.flush() - - # ===== CREATE VARIABLES ====== - # embeddings - self.row_embedding = embeddings_with_init( - embedding_dim=config.embedding_size, - vocab_size=self.n_rows, - name='row_embedding') - self.col_embedding = embeddings_with_init( - embedding_dim=config.embedding_size, - vocab_size=self.n_cols, - name='col_embedding') - tf.summary.histogram('row_emb', self.row_embedding) - tf.summary.histogram('col_emb', self.col_embedding) - - matrix_log_sum = math.log(np.sum(row_sums) + 1) - row_bias_init = [math.log(x + 1) for x in row_sums] - col_bias_init = [math.log(x + 1) for x in col_sums] - self.row_bias = tf.Variable( - row_bias_init, trainable=config.trainable_bias) - self.col_bias = tf.Variable( - col_bias_init, trainable=config.trainable_bias) - tf.summary.histogram('row_bias', self.row_bias) - tf.summary.histogram('col_bias', self.col_bias) - - # ===== CREATE GRAPH ===== - - # Get input - global_row, global_col, count = count_matrix_input( - count_matrix_files, config.submatrix_rows, config.submatrix_cols) - - # Fetch embeddings. - selected_row_embedding = tf.nn.embedding_lookup( - self.row_embedding, global_row) - selected_col_embedding = tf.nn.embedding_lookup( - self.col_embedding, global_col) - - # Fetch biases. - selected_row_bias = tf.nn.embedding_lookup([self.row_bias], global_row) - selected_col_bias = tf.nn.embedding_lookup([self.col_bias], global_col) - - # Multiply the row and column embeddings to generate predictions. - predictions = tf.matmul( - selected_row_embedding, selected_col_embedding, transpose_b=True) - - # These binary masks separate zero from non-zero values. - count_is_nonzero = tf.to_float(tf.cast(count, tf.bool)) - count_is_zero = 1 - tf.to_float(tf.cast(count, tf.bool)) - - objectives = count_is_nonzero * tf.log(count + 1e-30) - objectives -= tf.reshape(selected_row_bias, [config.submatrix_rows, 1]) - objectives -= selected_col_bias - objectives += matrix_log_sum - - err = predictions - objectives - - # The confidence function scales the L2 loss based on the raw co-occurrence - # count. - l2_confidence = (config.confidence_base + config.confidence_scale * tf.pow( - count, config.confidence_exponent)) - - l2_loss = config.loss_multiplier * tf.reduce_sum( - 0.5 * l2_confidence * err * err * count_is_nonzero) - - sigmoid_loss = config.loss_multiplier * tf.reduce_sum( - tf.nn.softplus(err) * count_is_zero) - - self.loss = l2_loss + sigmoid_loss - - tf.summary.scalar("l2_loss", l2_loss) - tf.summary.scalar("sigmoid_loss", sigmoid_loss) - tf.summary.scalar("loss", self.loss) - - # Add optimizer. 
- self.global_step = tf.Variable(0, name='global_step') - opt = tf.train.AdagradOptimizer(config.learning_rate) - self.train_op = opt.minimize(self.loss, global_step=self.global_step) - self.saver = tf.train.Saver(sharded=True) + log('n_submatrices: %d', self.n_submatrices) + + with tf.device('/cpu:0'): + # ===== CREATE VARIABLES ====== + # Get input + global_row, global_col, count = count_matrix_input( + count_matrix_files, config.submatrix_rows, config.submatrix_cols) + + # Embeddings + self.row_embedding = embeddings_with_init( + embedding_dim=config.embedding_size, + vocab_size=self.n_rows, + name='row_embedding') + self.col_embedding = embeddings_with_init( + embedding_dim=config.embedding_size, + vocab_size=self.n_cols, + name='col_embedding') + tf.summary.histogram('row_emb', self.row_embedding) + tf.summary.histogram('col_emb', self.col_embedding) + + matrix_log_sum = math.log(np.sum(row_sums) + 1) + row_bias_init = [math.log(x + 1) for x in row_sums] + col_bias_init = [math.log(x + 1) for x in col_sums] + self.row_bias = tf.Variable( + row_bias_init, trainable=config.trainable_bias) + self.col_bias = tf.Variable( + col_bias_init, trainable=config.trainable_bias) + tf.summary.histogram('row_bias', self.row_bias) + tf.summary.histogram('col_bias', self.col_bias) + + # Add optimizer + l2_losses = [] + sigmoid_losses = [] + self.global_step = tf.Variable(0, name='global_step') + opt = tf.train.AdagradOptimizer(config.learning_rate) + + all_grads = [] + + devices = ['/gpu:%d' % i for i in range(FLAGS.num_gpus)] \ + if FLAGS.num_gpus > 0 else get_available_gpus() + self.devices_number = len(devices) + with tf.variable_scope(tf.get_variable_scope()): + for dev in devices: + with tf.device(dev): + with tf.name_scope(dev[1:].replace(':', '_')): + # ===== CREATE GRAPH ===== + # Fetch embeddings. + selected_row_embedding = tf.nn.embedding_lookup( + self.row_embedding, global_row) + selected_col_embedding = tf.nn.embedding_lookup( + self.col_embedding, global_col) + + # Fetch biases. + selected_row_bias = tf.nn.embedding_lookup( + [self.row_bias], global_row) + selected_col_bias = tf.nn.embedding_lookup( + [self.col_bias], global_col) + + # Multiply the row and column embeddings to generate predictions. + predictions = tf.matmul( + selected_row_embedding, selected_col_embedding, + transpose_b=True) + + # These binary masks separate zero from non-zero values. + count_is_nonzero = tf.to_float(tf.cast(count, tf.bool)) + count_is_zero = 1 - count_is_nonzero + + objectives = count_is_nonzero * tf.log(count + 1e-30) + objectives -= tf.reshape( + selected_row_bias, [config.submatrix_rows, 1]) + objectives -= selected_col_bias + objectives += matrix_log_sum + + err = predictions - objectives + + # The confidence function scales the L2 loss based on the raw + # co-occurrence count. 
+ l2_confidence = (config.confidence_base + + config.confidence_scale * tf.pow( + count, config.confidence_exponent)) + + l2_loss = config.loss_multiplier * tf.reduce_sum( + 0.5 * l2_confidence * err * err * count_is_nonzero) + l2_losses.append(tf.expand_dims(l2_loss, 0)) + + sigmoid_loss = config.loss_multiplier * tf.reduce_sum( + tf.nn.softplus(err) * count_is_zero) + sigmoid_losses.append(tf.expand_dims(sigmoid_loss, 0)) + + loss = l2_loss + sigmoid_loss + grads = opt.compute_gradients(loss) + all_grads.append(grads) + + with tf.device('/cpu:0'): + # ===== MERGE LOSSES ===== + l2_loss = tf.reduce_mean(tf.concat(l2_losses, 0), 0, name="l2_loss") + sigmoid_loss = tf.reduce_mean(tf.concat(sigmoid_losses, 0), 0, + name="sigmoid_loss") + self.loss = l2_loss + sigmoid_loss + average = tf.train.ExponentialMovingAverage(0.8, self.global_step) + loss_average_op = average.apply((self.loss,)) + tf.summary.scalar("l2_loss", l2_loss) + tf.summary.scalar("sigmoid_loss", sigmoid_loss) + tf.summary.scalar("loss", self.loss) + + # Apply the gradients to adjust the shared variables. + apply_gradient_ops = [] + for grads in all_grads: + apply_gradient_ops.append(opt.apply_gradients( + grads, global_step=self.global_step)) + + self.train_op = tf.group(loss_average_op, *apply_gradient_ops) + self.saver = tf.train.Saver(sharded=True) def main(_): + tf.logging.set_verbosity(tf.logging.INFO) + start_time = time.time() + # Create the output path. If this fails, it really ought to fail # now. :) if not os.path.isdir(FLAGS.output_base_path): @@ -295,8 +341,13 @@ def main(_): model = SwivelModel(FLAGS) # Create a session for running Ops on the Graph. - gpu_options = tf.GPUOptions( - per_process_gpu_memory_fraction=FLAGS.per_process_gpu_memory_fraction) + gpu_opts = {} + if FLAGS.per_process_gpu_memory_fraction > 0: + gpu_opts["per_process_gpu_memory_fraction"] = \ + FLAGS.per_process_gpu_memory_fraction + else: + gpu_opts["allow_growth"] = True + gpu_options = tf.GPUOptions(**gpu_opts) sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) # Run the Op to initialize the variables. 
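
To summarise the graph changes in the hunk above: the variables and the input queue stay on /cpu:0, each GPU builds its own tower that computes the loss and gradients against those shared variables, and every tower's gradients are applied back to the same variables on the CPU. A minimal sketch of that in-graph replication pattern, with a hypothetical build_tower_loss callable standing in for the Swivel-specific graph (variables and input ops are assumed to be created beforehand on the CPU, as the patch does):

    import tensorflow as tf

    def build_multi_gpu_train_op(devices, build_tower_loss, learning_rate=1.0):
        # One optimizer shared by all towers; each tower contributes its own
        # gradients, which are applied to the shared variables on the CPU.
        opt = tf.train.AdagradOptimizer(learning_rate)
        losses, tower_grads = [], []
        with tf.variable_scope(tf.get_variable_scope()):
            for dev in devices:
                with tf.device(dev), tf.name_scope(dev[1:].replace(':', '_')):
                    loss = build_tower_loss()          # closes over shared CPU variables
                    losses.append(tf.expand_dims(loss, 0))
                    tower_grads.append(opt.compute_gradients(loss))
        with tf.device('/cpu:0'):
            mean_loss = tf.reduce_mean(tf.concat(losses, 0), 0)
            apply_ops = [opt.apply_gradients(g) for g in tower_grads]
            train_op = tf.group(*apply_ops)
        return mean_loss, train_op

As in the patch, the towers' gradients are applied separately rather than averaged; the averaging happens only on the reported loss.
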
@@ -309,21 +360,32 @@ def main(_): # Calculate how many steps each thread should run n_total_steps = int(FLAGS.num_epochs * model.n_rows * model.n_cols) / ( FLAGS.submatrix_rows * FLAGS.submatrix_cols) - n_steps_per_thread = n_total_steps / FLAGS.num_concurrent_steps + n_steps_per_thread = n_total_steps / ( + FLAGS.num_concurrent_steps * model.devices_number) n_submatrices_to_train = model.n_submatrices * FLAGS.num_epochs t0 = [time.time()] + n_steps_between_status_updates = 100 + status_i = [0] + status_lock = threading.Lock() + msg = ('%%%dd/%%d submatrices trained (%%.1f%%%%), %%5.1f submatrices/sec |' + ' loss %%f') % len(str(n_submatrices_to_train)) def TrainingFn(): for _ in range(int(n_steps_per_thread)): - _, global_step = sess.run([model.train_op, model.global_step]) - n_steps_between_status_updates = 100 - if (global_step % n_steps_between_status_updates) == 0: + _, global_step, loss = sess.run(( + model.train_op, model.global_step, model.loss)) + + show_status = False + with status_lock: + new_i = global_step // n_steps_between_status_updates + if new_i > status_i[0]: + status_i[0] = new_i + show_status = True + if show_status: elapsed = float(time.time() - t0[0]) - print('%d/%d submatrices trained (%.1f%%), %.1f submatrices/sec' % ( - global_step, n_submatrices_to_train, + log(msg, global_step, n_submatrices_to_train, 100.0 * global_step / n_submatrices_to_train, - n_steps_between_status_updates / elapsed)) - sys.stdout.flush() + n_steps_between_status_updates / elapsed, loss) t0[0] = time.time() # Start training threads @@ -343,8 +405,9 @@ def TrainingFn(): # Write out vectors write_embeddings_to_disk(FLAGS, model, sess) - #Shutdown + # Shutdown sess.close() + log("Elapsed: %s", time.time() - start_time) if __name__ == '__main__':
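
A related, easy-to-miss behavioural change in main(): --per_process_gpu_memory_fraction now defaults to 0, which makes the session grow GPU memory on demand (allow_growth) instead of pre-allocating a fixed share of each device. The session-creation hunk condensed into a standalone helper, with the same semantics:

    import tensorflow as tf

    def make_session(per_process_gpu_memory_fraction=0.0):
        # A positive fraction reserves that share of every GPU up front;
        # 0 (the new default) switches to on-demand allocation.
        if per_process_gpu_memory_fraction > 0:
            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=per_process_gpu_memory_fraction)
        else:
            gpu_options = tf.GPUOptions(allow_growth=True)
        return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
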
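
The step accounting in the training loop changes as well: one sess.run of the grouped train_op now trains one submatrix per device, so the per-thread step count is divided by model.devices_number in addition to --num_concurrent_steps. With hypothetical sizes, chosen only to illustrate the arithmetic:

    # Hypothetical sizes, not taken from any real run.
    num_epochs, n_rows, n_cols = 40, 4096, 4096
    submatrix_rows = submatrix_cols = 4096
    num_concurrent_steps, devices_number = 2, 2

    n_total_steps = int(num_epochs * n_rows * n_cols) / (
        submatrix_rows * submatrix_cols)
    n_steps_per_thread = n_total_steps / (
        num_concurrent_steps * devices_number)
    # 40 submatrix updates in total, 10 sess.run calls per training thread.
    print(n_total_steps, n_steps_per_thread)
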