Apply TensorFlow 2.0's v1 compatibility layer so the code works with TF2 #615

Status: Open · wants to merge 1 commit into base: main
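The commit replaces TF1-only call sites across the server package with their tf.compat.v1 or relocated TF2 equivalents. A condensed sketch of the two-part strategy (file and path names below are illustrative, not part of the diff):

# (a) done once in cli/__init__.py: force TF1 graph-mode semantics process-wide
import tensorflow.compat.v1 as tf_v1
tf_v1.disable_v2_behavior()

# (b) library modules keep `import tensorflow as tf` and call the v1 endpoints explicitly
import tensorflow as tf
config = tf.compat.v1.ConfigProto(device_count={'GPU': 0})        # removed top-level APIs
with tf.io.gfile.GFile('/tmp/bert_config.json', 'r') as f:        # relocated APIs (tf.gfile -> tf.io.gfile)
    bert_config_text = f.read()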
6 changes: 3 additions & 3 deletions server/bert_serving/server/__init__.py
@@ -491,8 +491,8 @@ def get_estimator(self, tf):
from tensorflow.python.estimator.model_fn import EstimatorSpec

def model_fn(features, labels, mode, params):
- with tf.gfile.GFile(self.graph_path, 'rb') as f:
- graph_def = tf.GraphDef()
+ with tf.io.gfile.GFile(self.graph_path, 'rb') as f:
+ graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())

input_names = ['input_ids', 'input_mask', 'input_type_ids']
@@ -506,7 +506,7 @@ def model_fn(features, labels, mode, params):
'encodes': output[0]
})

- config = tf.ConfigProto(device_count={'GPU': 0 if self.device_id < 0 else 1})
+ config = tf.compat.v1.ConfigProto(device_count={'GPU': 0 if self.device_id < 0 else 1})
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = self.gpu_memory_fraction
config.log_device_placement = False
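For context, a minimal standalone sketch of the frozen-graph loading pattern that model_fn relies on after this change (graph_path is a hypothetical file; the real value comes from self.graph_path):

import tensorflow as tf

graph_path = '/tmp/bert_graph.pb'            # hypothetical serialized GraphDef
with tf.io.gfile.GFile(graph_path, 'rb') as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())

graph = tf.Graph()
with graph.as_default():
    # import the frozen nodes into a fresh graph for inference
    tf.compat.v1.import_graph_def(graph_def, name='')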
53 changes: 26 additions & 27 deletions server/bert_serving/server/bert/modeling.py
@@ -88,7 +88,7 @@ def from_dict(cls, json_object):
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
- with tf.gfile.GFile(json_file, "r") as reader:
+ with tf.io.gfile.GFile(json_file, "r") as reader:
text = reader.read()
return cls.from_dict(json.loads(text))

@@ -119,7 +119,7 @@ class BertModel(object):
model = modeling.BertModel(config=config, is_training=True,
input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)

- label_embeddings = tf.get_variable(...)
+ label_embeddings = tf.compat.v1.get_variable(...)
pooled_output = model.get_pooled_output()
logits = tf.matmul(pooled_output, label_embeddings)
...
@@ -169,8 +169,8 @@ def __init__(self,
if token_type_ids is None:
token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

- with tf.variable_scope(scope, default_name="bert"):
- with tf.variable_scope("embeddings"):
+ with tf.compat.v1.variable_scope(scope, default_name="bert"):
+ with tf.compat.v1.variable_scope("embeddings"):
# Perform embedding lookup on the word ids.
(self.embedding_output, self.embedding_table) = embedding_lookup(
input_ids=input_ids,
@@ -194,7 +194,7 @@ def __init__(self,
max_position_embeddings=config.max_position_embeddings,
dropout_prob=config.hidden_dropout_prob)

with tf.variable_scope("encoder"):
with tf.compat.v1.variable_scope("encoder"):
# This converts a 2D mask of shape [batch_size, seq_length] to a 3D
# mask of shape [batch_size, seq_length, seq_length] which is used
# for the attention scores.
@@ -222,12 +222,12 @@ def __init__(self,
# [batch_size, hidden_size]. This is necessary for segment-level
# (or segment-pair-level) classification tasks where we need a fixed
# dimensional representation of the segment.
with tf.variable_scope("pooler"):
with tf.compat.v1.variable_scope("pooler"):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token. We assume that this has been pre-trained
first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
# https://github.com/google-research/bert/issues/43#issuecomment-435980269
- self.pooled_output = tf.layers.dense(
+ self.pooled_output = tf.compat.v1.layers.dense(
first_token_tensor,
config.hidden_size,
activation=tf.tanh,
@@ -275,7 +275,7 @@ def gelu(input_tensor):
Returns:
`input_tensor` with the GELU activation applied.
"""
- cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0)))
+ cdf = 0.5 * (1.0 + tf.math.erf(input_tensor / tf.sqrt(2.0)))
return input_tensor * cdf


@@ -363,8 +363,7 @@ def dropout(input_tensor, dropout_prob):

def layer_norm(input_tensor, name=None):
"""Run layer normalization on the last dimension of the tensor."""
- return tf.contrib.layers.layer_norm(
- inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
+ return tf.keras.layers.LayerNormalization(axis=-1)(input_tensor)
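Note: tf.contrib.layers.layer_norm created variables named LayerNorm/gamma and LayerNorm/beta, which is what the pre-trained BERT checkpoints store, while tf.keras.layers.LayerNormalization names its weights after the layer (e.g. layer_normalization/gamma). As a result, init_from_checkpoint may leave these weights out of the assignment map and randomly initialized. If that turns out to matter, a hand-rolled v1-style replacement that preserves the original names could look like the sketch below (an illustration, not part of this PR; the 1e-12 epsilon is assumed to match contrib's default):

def layer_norm(input_tensor, name=None):
    # keeps the LayerNorm/gamma and LayerNorm/beta names used by BERT checkpoints
    with tf.compat.v1.variable_scope(name, default_name="LayerNorm"):
        width = input_tensor.shape[-1]
        gamma = tf.compat.v1.get_variable("gamma", [width], initializer=tf.ones_initializer())
        beta = tf.compat.v1.get_variable("beta", [width], initializer=tf.zeros_initializer())
        mean, variance = tf.nn.moments(input_tensor, axes=[-1], keepdims=True)
        return (input_tensor - mean) * tf.math.rsqrt(variance + 1e-12) * gamma + beta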


def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
@@ -376,7 +375,7 @@ def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):

def create_initializer(initializer_range=0.02):
"""Creates a `truncated_normal_initializer` with the given range."""
- return tf.truncated_normal_initializer(stddev=initializer_range)
+ return tf.compat.v1.truncated_normal_initializer(stddev=initializer_range)


def embedding_lookup(input_ids,
@@ -409,7 +408,7 @@ def embedding_lookup(input_ids,
if input_ids.shape.ndims == 2:
input_ids = tf.expand_dims(input_ids, axis=[-1])

- embedding_table = tf.get_variable(
+ embedding_table = tf.compat.v1.get_variable(
name=word_embedding_name,
shape=[vocab_size, embedding_size],
initializer=create_initializer(initializer_range))
@@ -482,7 +481,7 @@ def embedding_postprocessor(input_tensor,
if token_type_ids is None:
raise ValueError("`token_type_ids` must be specified if"
"`use_token_type` is True.")
- token_type_table = tf.get_variable(
+ token_type_table = tf.compat.v1.get_variable(
name=token_type_embedding_name,
shape=[token_type_vocab_size, width],
initializer=create_initializer(initializer_range))
@@ -496,7 +495,7 @@ def embedding_postprocessor(input_tensor,
output += token_type_embeddings

if use_position_embeddings:
- full_position_embeddings = tf.get_variable(
+ full_position_embeddings = tf.compat.v1.get_variable(
name=position_embedding_name,
shape=[max_position_embeddings, width],
initializer=create_initializer(initializer_range))
@@ -675,23 +674,23 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
to_tensor_2d = reshape_to_matrix(to_tensor)

# `query_layer` = [B*F, N*H]
- query_layer = tf.layers.dense(
+ query_layer = tf.compat.v1.layers.dense(
from_tensor_2d,
num_attention_heads * size_per_head,
activation=query_act,
name="query",
kernel_initializer=create_initializer(initializer_range))

# `key_layer` = [B*T, N*H]
- key_layer = tf.layers.dense(
+ key_layer = tf.compat.v1.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=key_act,
name="key",
kernel_initializer=create_initializer(initializer_range))

# `value_layer` = [B*T, N*H]
- value_layer = tf.layers.dense(
+ value_layer = tf.compat.v1.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=value_act,
@@ -836,12 +835,12 @@ def transformer_model(input_tensor,

all_layer_outputs = []
for layer_idx in range(num_hidden_layers):
with tf.variable_scope("layer_%d" % layer_idx):
with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
layer_input = prev_output

with tf.variable_scope("attention"):
with tf.compat.v1.variable_scope("attention"):
attention_heads = []
with tf.variable_scope("self"):
with tf.compat.v1.variable_scope("self"):
attention_head = attention_layer(
from_tensor=layer_input,
to_tensor=layer_input,
@@ -866,25 +865,25 @@

# Run a linear projection of `hidden_size` then add a residual
# with `layer_input`.
with tf.variable_scope("output"):
attention_output = tf.layers.dense(
with tf.compat.v1.variable_scope("output"):
attention_output = tf.compat.v1.layers.dense(
attention_output,
hidden_size,
kernel_initializer=create_initializer(initializer_range))
attention_output = dropout(attention_output, hidden_dropout_prob)
attention_output = layer_norm(attention_output + layer_input)

# The activation is only applied to the "intermediate" hidden layer.
with tf.variable_scope("intermediate"):
intermediate_output = tf.layers.dense(
with tf.compat.v1.variable_scope("intermediate"):
intermediate_output = tf.compat.v1.layers.dense(
attention_output,
intermediate_size,
activation=intermediate_act_fn,
kernel_initializer=create_initializer(initializer_range))

# Down-project back to `hidden_size` then add the residual.
with tf.variable_scope("output"):
layer_output = tf.layers.dense(
with tf.compat.v1.variable_scope("output"):
layer_output = tf.compat.v1.layers.dense(
intermediate_output,
hidden_size,
kernel_initializer=create_initializer(initializer_range))
@@ -991,7 +990,7 @@ def assert_rank(tensor, expected_rank, name=None):

actual_rank = tensor.shape.ndims
if actual_rank not in expected_rank_dict:
- scope_name = tf.get_variable_scope().name
+ scope_name = tf.compat.v1.get_variable_scope().name
raise ValueError(
"For the tensor `%s` in scope `%s`, the actual rank "
"`%d` (shape = %s) is not equal to the expected rank `%s`" %
8 changes: 4 additions & 4 deletions server/bert_serving/server/bert/optimization.py
@@ -64,9 +64,9 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

if use_tpu:
- optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
+ optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)

- tvars = tf.trainable_variables()
+ tvars = tf.compat.v1.trainable_variables()
grads = tf.gradients(loss, tvars)

# This is how the model was pre-trained.
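tf.compat.v1.trainable_variables() and tf.gradients still assume graph mode; a tiny standalone sketch of the pattern under TF2 (assumes eager execution has been disabled, as the CLI now does):

import tensorflow as tf

tf.compat.v1.disable_eager_execution()
x = tf.compat.v1.get_variable('x', shape=[], initializer=tf.zeros_initializer())
loss = tf.square(x - 3.0)
tvars = tf.compat.v1.trainable_variables()   # -> [x]
grads = tf.gradients(loss, tvars)            # symbolic graph-mode gradients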
@@ -110,13 +110,13 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None):

param_name = self._get_variable_name(param.name)

- m = tf.get_variable(
+ m = tf.compat.v1.get_variable(
name=param_name + "/adam_m",
shape=param.shape.as_list(),
dtype=tf.float32,
trainable=False,
initializer=tf.zeros_initializer())
- v = tf.get_variable(
+ v = tf.compat.v1.get_variable(
name=param_name + "/adam_v",
shape=param.shape.as_list(),
dtype=tf.float32,
2 changes: 1 addition & 1 deletion server/bert_serving/server/bert/tokenization.py
@@ -72,7 +72,7 @@ def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
- with tf.gfile.GFile(vocab_file, "r") as reader:
+ with tf.io.gfile.GFile(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
4 changes: 4 additions & 0 deletions server/bert_serving/server/cli/__init__.py
@@ -1,3 +1,7 @@
+ import tensorflow.compat.v1 as tf
+ tf.disable_v2_behavior()


def main():
from bert_serving.server import BertServer
from bert_serving.server.helper import get_run_args
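These two added lines are the heart of the compatibility approach and need to run before any other TensorFlow graph or session is created. A quick standalone check of their effect (illustrative only):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

assert not tf.executing_eagerly()            # eager execution is off
with tf.Session() as sess:                   # v1-style sessions work again under TF2
    print(sess.run(tf.constant('graph mode restored')))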
38 changes: 19 additions & 19 deletions server/bert_serving/server/graph.py
@@ -44,7 +44,7 @@ def optimize_graph(args, logger=None):
tf = import_tf(verbose=args.verbose)
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference

- config = tf.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True)
+ config = tf.compat.v1.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True)

config_fp = os.path.join(args.model_dir, args.config_name)
init_checkpoint = os.path.join(args.tuned_model_dir or args.model_dir, args.ckpt_name)
@@ -56,16 +56,16 @@ def optimize_graph(args, logger=None):
logger.info(
'checkpoint%s: %s' % (
' (override by the fine-tuned model)' if args.tuned_model_dir else '', init_checkpoint))
- with tf.gfile.GFile(config_fp, 'r') as f:
+ with tf.io.gfile.GFile(config_fp, 'r') as f:
bert_config = modeling.BertConfig.from_dict(json.load(f))

logger.info('build graph...')
# input placeholders, not sure if they are friendly to XLA
- input_ids = tf.placeholder(tf.int32, (None, None), 'input_ids')
- input_mask = tf.placeholder(tf.int32, (None, None), 'input_mask')
- input_type_ids = tf.placeholder(tf.int32, (None, None), 'input_type_ids')
+ input_ids = tf.compat.v1.placeholder(tf.int32, (None, None), 'input_ids')
+ input_mask = tf.compat.v1.placeholder(tf.int32, (None, None), 'input_mask')
+ input_type_ids = tf.compat.v1.placeholder(tf.int32, (None, None), 'input_type_ids')

- jit_scope = tf.contrib.compiler.jit.experimental_jit_scope if args.xla else contextlib.suppress
+ jit_scope = tf.xla.experimental.jit_scope if args.xla else contextlib.suppress

with jit_scope():
input_tensors = [input_ids, input_mask, input_type_ids]
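tf.compat.v1.placeholder only works in graph mode, and tf.xla.experimental.jit_scope is where the old contrib jit scope lives in TF2. A minimal standalone sketch of the pattern used here (use_xla stands in for args.xla):

import contextlib
import tensorflow as tf

tf.compat.v1.disable_eager_execution()       # placeholders require graph mode
use_xla = False                              # stand-in for args.xla
jit_scope = tf.xla.experimental.jit_scope if use_xla else contextlib.suppress

input_ids = tf.compat.v1.placeholder(tf.int32, (None, None), 'input_ids')
with jit_scope():
    doubled = tf.identity(input_ids * 2, name='doubled')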
@@ -81,36 +81,36 @@

if args.pooling_strategy == PoolingStrategy.CLASSIFICATION:
hidden_size = model.pooled_output.shape[-1].value
- output_weights = tf.get_variable(
+ output_weights = tf.compat.v1.get_variable(
'output_weights', [args.num_labels, hidden_size],
- initializer=tf.truncated_normal_initializer(stddev=0.02))
+ initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02))

- output_bias = tf.get_variable(
+ output_bias = tf.compat.v1.get_variable(
'output_bias', [args.num_labels], initializer=tf.zeros_initializer())

if args.pooling_strategy == PoolingStrategy.REGRESSION:
hidden_size = model.pooled_output.shape[-1].value
- output_weights = tf.get_variable(
+ output_weights = tf.compat.v1.get_variable(
'output_weights', [1, hidden_size],
- initializer=tf.truncated_normal_initializer(stddev=0.02))
+ initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02))

- output_bias = tf.get_variable(
+ output_bias = tf.compat.v1.get_variable(
'output_bias', [1], initializer=tf.zeros_initializer())

- tvars = tf.trainable_variables()
+ tvars = tf.compat.v1.trainable_variables()

(assignment_map, initialized_variable_names
) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)

- tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
+ tf.compat.v1.train.init_from_checkpoint(init_checkpoint, assignment_map)

minus_mask = lambda x, m: x - tf.expand_dims(1.0 - m, axis=-1) * 1e30
mul_mask = lambda x, m: x * tf.expand_dims(m, axis=-1)
masked_reduce_max = lambda x, m: tf.reduce_max(minus_mask(x, m), axis=1)
masked_reduce_mean = lambda x, m: tf.reduce_sum(mul_mask(x, m), axis=1) / (
tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10)

with tf.variable_scope("pooling"):
with tf.compat.v1.variable_scope("pooling"):
if len(args.pooling_layer) == 1:
encoder_layer = model.all_encoder_layers[args.pooling_layer[0]]
else:
@@ -156,12 +156,12 @@ def optimize_graph(args, logger=None):

pooled = tf.identity(pooled, 'final_encodes')
output_tensors = [pooled]
- tmp_g = tf.get_default_graph().as_graph_def()
+ tmp_g = tf.compat.v1.get_default_graph().as_graph_def()

- with tf.Session(config=config) as sess:
+ with tf.compat.v1.Session(config=config) as sess:
logger.info('load parameters from checkpoint...')

- sess.run(tf.global_variables_initializer())
+ sess.run(tf.compat.v1.global_variables_initializer())
dtypes = [n.dtype for n in input_tensors]
logger.info('optimize...')
tmp_g = optimize_for_inference(
@@ -177,7 +177,7 @@

tmp_file = tempfile.NamedTemporaryFile('w', delete=False, dir=args.graph_tmp_dir).name
logger.info('write graph to a tmp file: %s' % tmp_file)
- with tf.gfile.GFile(tmp_file, 'wb') as f:
+ with tf.io.gfile.GFile(tmp_file, 'wb') as f:
f.write(tmp_g.SerializeToString())
return tmp_file, bert_config
except Exception: