Don't get deterministic results with NGC 20.09 #29

Open
ornellamarciano opened this issue Oct 13, 2020 · 8 comments


ornellamarciano commented Oct 13, 2020

Prerequisites

I am using the latest TensorFlow NGC container, nvcr.io/nvidia/tensorflow:20.09-tf2-py3.

Issue

When running my deep-learning model on a single GPU (Tesla T4) within the NGC 20.09 container, I don't get reproducible results between consecutive runs (the predictions differ), although all the seeds are set correctly. Running the same code on CPU gives reproducible results over consecutive runs.

On a small dataset (138k samples), I managed to get reproducible results within the NGC container by setting the default float type to float64 (instead of float32).
But with a bigger dataset (2.6M samples), the predictions start to differ between runs after about 1 million samples.
The more I increase the dataset size, the bigger the difference between the predictions of consecutive runs.
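
To quantify the drift, the prediction files of two runs can be compared line by line; here is a minimal sketch (the file names are placeholders):

def max_abs_diff(path_a, path_b):
    # Compare two prediction files line by line; report the largest gap.
    with open(path_a) as file_a, open(path_b) as file_b:
        return max(abs(float(a) - float(b)) for a, b in zip(file_a, file_b))

print(max_abs_diff("predictions_run1.txt", "predictions_run2.txt"))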

My datasets are TFRecord files of hashed values, serialized in batches of 10,000 samples.
Here is the small dataset (13 MB):
small_dataset
For the big dataset (251 MB), can you send me your email so that I can share it with you via Google Drive to reproduce the issue?
You need to unzip the files before running the script below.

Command lines:

sudo docker run --gpus all --shm-size=1g --ulimit memlock=-1 -it --rm -v /home:/workspace/home nvcr.io/nvidia/tensorflow:20.09-tf2-py3

CUDA_VISIBLE_DEVICES="0" TF_DETERMINISTIC_OPS="1" TF_CPP_MIN_LOG_LEVEL="3" python code.py data_file prediction_file

Script

import tensorflow as tf
import sys
import string
import os
from tensorflow.python.keras.initializers import RandomNormal, GlorotUniform, Zeros, glorot_normal
from tensorflow.python.keras.regularizers import l2
from tensorflow.compat.v1.keras.layers import Layer, Dense, Flatten, Add
from tensorflow.python.keras import backend as K
import numpy as np
import logging
tf.get_logger().setLevel(logging.ERROR)
K.set_floatx('float64')

prediction_output_name = "pred"

SEED = 1024
os.environ['PYTHONHASHSEED'] = str(SEED)
np.random.seed(SEED)
tf.compat.v1.set_random_seed(SEED)


class Linear(Layer):

    def call(self, x):
        return tf.reduce_sum(x, axis=[1])

    def compute_output_shape(self, input_shape):
        return (None, 1)


class FM(Layer):

    def __init__(self, **kwargs):
        super(FM, self).__init__(**kwargs)

    def build(self, input_shape):
        if len(input_shape) != 3:
            raise ValueError(
                "Unexpected inputs dimensions %d, expect to be 3 dimensions"
                % (len(input_shape)))
        super(FM, self).build(input_shape)

    def call(self, inputs):
        if K.ndim(inputs) != 3:
            raise ValueError(
                "Unexpected inputs dimensions %d, expect to be 3 dimensions"
                % (K.ndim(inputs)))

        concated_embeds_value = inputs
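        # FM second-order interaction identity:
        #   sum_{i<j} <v_i, v_j> = 0.5 * ((sum_i v_i)^2 - sum_i (v_i^2)),
        # computed element-wise per embedding dimension below, then summed
        # over the embedding dimension.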
        square_of_sum = tf.square(tf.compat.v1.reduce_sum(
            concated_embeds_value, axis=1, keep_dims=True))
        sum_of_square = tf.compat.v1.reduce_sum(
            concated_embeds_value * concated_embeds_value, axis=1, keep_dims=True)
        cross_term = square_of_sum - sum_of_square
        cross_term = 0.5 * tf.compat.v1.reduce_sum(cross_term, axis=2, keep_dims=False)

        return cross_term

    def compute_output_shape(self, input_shape):
        return (None, 1)


class DNN(Layer):

    def __init__(self, hidden_units, activation='relu', l2_reg=0, dropout_rate=0, use_bn=False, seed=1024, **kwargs):
        self.hidden_units = hidden_units
        self.activation = activation
        self.dropout_rate = dropout_rate
        self.seed = seed
        self.l2_reg = l2_reg
        self.use_bn = use_bn
        super(DNN, self).__init__(**kwargs)

    def build(self, input_shape):
        input_size = input_shape[-1]
        hidden_units = [int(input_size)] + list(self.hidden_units)
        self.kernels = [self.add_weight(name='kernel' + str(i),
                                        shape=(
                                            hidden_units[i], hidden_units[i + 1]),
                                        initializer=glorot_normal(
                                            seed=self.seed),
                                        # initializer=Constant(0.4),
                                        regularizer=l2(self.l2_reg),
                                        trainable=True) for i in range(len(self.hidden_units))]
        self.bias = [self.add_weight(name='bias' + str(i),
                                     shape=(self.hidden_units[i],),
                                     initializer=Zeros(),
                                     trainable=True) for i in range(len(self.hidden_units))]
        if self.use_bn:
            self.bn_layers = [tf.keras.layers.BatchNormalization() for _ in range(len(self.hidden_units))]

        self.dropout_layers = [tf.keras.layers.Dropout(self.dropout_rate, seed=self.seed + i) for i in
                               range(len(self.hidden_units))]

        self.activation_layers = [tf.keras.layers.Activation(self.activation) for _ in range(len(self.hidden_units))]
        super(DNN, self).build(input_shape)

    def call(self, inputs, training=None):

        deep_input = inputs

        for i in range(len(self.hidden_units)):
            fc = tf.math.add(tf.tensordot(
                deep_input, self.kernels[i], axes=(-1, 0)), self.bias[i])

            if self.use_bn:
                fc = self.bn_layers[i](fc, training=training)

            fc = self.activation_layers[i](fc)

            fc = self.dropout_layers[i](fc, training=training)
            deep_input = fc

        return deep_input

    def compute_output_shape(self, input_shape):
        if len(self.hidden_units) > 0:
            shape = input_shape[:-1] + (self.hidden_units[-1],)
        else:
            shape = input_shape

        return tuple(shape)

    def get_config(self, ):
        config = {'activation': self.activation, 'hidden_units': self.hidden_units,
                  'l2_reg': self.l2_reg, 'use_bn': self.use_bn, 'dropout_rate': self.dropout_rate, 'seed': self.seed}
        base_config = super(DNN, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class PredictionLayer(Layer):

    def __init__(self, task='binary', use_bias=True, **kwargs):
        if task not in ["binary", "multiclass", "regression"]:
            raise ValueError("task must be binary,multiclass or regression")
        self.task = task
        self.use_bias = use_bias
        super(PredictionLayer, self).__init__(**kwargs)

    def build(self, input_shape):

        if self.use_bias:
            self.global_bias = self.add_weight(
                shape=(1,), initializer=Zeros(), name="global_bias")

        super(PredictionLayer, self).build(input_shape)

    def call(self, inputs):
        x = inputs
        if self.use_bias:
            x = tf.nn.bias_add(x, self.global_bias, data_format='NHWC')
        if self.task == "binary":
            x = tf.sigmoid(x)

        output = tf.cast(tf.reshape(x, (-1, 1)), tf.float32)

        return output

    def compute_output_shape(self, input_shape):
        return (None, 1)

    def get_config(self, ):
        config = {'task': self.task, 'use_bias': self.use_bias}
        base_config = super(PredictionLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


def input_pipeline(dataset_file, serialized_features, epochs):
    dataset = create_dataset(dataset_file, serialized_features, epochs)
    # Dataset.prefetch returns a new dataset; the result must be kept.
    dataset = dataset.prefetch(1)
    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    return iterator, iterator.initializer


def create_dataset(dataset_file, serialized_features, epochs):
    decode_func = make_decoder(serialized_features)
    return tf.data.TFRecordDataset(dataset_file).map(decode_func).batch(1).repeat(epochs)


def make_decoder(serialized_features):
    features_dict = {f: tf.io.VarLenFeature(dtype=tf.int64) for f in serialized_features}
    features_dict['label'] = tf.io.VarLenFeature(dtype=tf.int64)

    def decode(serialized_example):
        features = tf.io.parse_example(
            [serialized_example],
            features_dict
        )
        return features
    return decode


def select_features(features, hash_size, model_features):
    selected_features = []
    for feature_name in model_features:
        feature_tensor = features[feature_name]  # names are single characters; the old [0] index was a no-op
        dense_feature_tensor = tf.transpose(tf.sparse.to_dense(tf.cast(feature_tensor, dtype=tf.int32)))
        dense_feature_tensor = tf.math.mod(dense_feature_tensor, hash_size)
        selected_features.append(dense_feature_tensor)
    return selected_features


def build_optimizer(learning_rate, optimizer):
    if optimizer == 'RMSProp':
        opt = tf.compat.v1.train.RMSPropOptimizer(learning_rate=learning_rate)
    elif optimizer == 'adagrad':
        opt = tf.compat.v1.train.AdagradOptimizer(learning_rate=learning_rate)
    elif optimizer == 'adam':
        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    else:
        opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=learning_rate)
    return opt


def predictions_saver(prediction_path, predictions):
    with open(prediction_path, "a") as output_file:
        output_file.write('\n'.join([str(x) for x in predictions]) + '\n')


def summarize_weights(session):
    if hasattr(session, 'raw_session'):
        session = session.raw_session()
    weights = session.run(tf.compat.v1.trainable_variables())
    summary = sum(map(lambda x: x.sum(), weights))
    print("Summary of weights: %.20f" % summary)


if __name__ == '__main__':
    dataset_file = sys.argv[1]
    prediction_path = sys.argv[2]
    header = list(string.ascii_uppercase) + list(string.ascii_lowercase) + [str(i) for i in range(10)]
    epochs = 1
    hash_size = 2 ** 25
    model_features = "A-B-C-D".split('-')
    init_std = 0.01
    embedding_size = 4
    l2_reg_lr, l2_reg_nn, l2_reg_emb = 0, 0, 0
    hidden_units_num = 200
    hidden_layers_num = 4
    dnn_activation_function = "relu"
    dnn_dropout = 0
    learning_rate = 0.0004
    optimizer = "adam"

    graph = tf.Graph()
    with graph.as_default():
        # Input pipeline
        iterator, training_init_op = input_pipeline(dataset_file, header, epochs)
        features = iterator.get_next()
        selected_features = select_features(features, hash_size, model_features)
        stacked_features = tf.stack(selected_features, axis=1)
        inputs = tf.reshape(stacked_features, [-1, len(model_features)])
        label = tf.reshape(tf.sparse.to_dense(features['label']), [-1])

        # Linear part
        emb_linear = tf.compat.v1.keras.layers.Embedding(
            hash_size,
            1,
            embeddings_initializer=RandomNormal(mean=0.0, stddev=init_std, seed=SEED),
            embeddings_regularizer=l2(l2_reg_emb),
            embeddings_constraint=None,
            mask_zero=False,
            input_length=len(model_features)
        )(inputs)

        # Linear takes no constructor arguments; a positional argument here
        # would bind to Layer's trainable flag, so call it with none.
        linear_logit = Linear()(emb_linear)

        # Embeddings for FM / DNN
        emb_fm = tf.compat.v1.keras.layers.Embedding(
            hash_size,
            embedding_size,
            embeddings_initializer=RandomNormal(mean=0.0, stddev=init_std, seed=SEED),
            embeddings_regularizer=l2(l2_reg_emb),
            embeddings_constraint=None,
            mask_zero=False,
            input_length=len(model_features)
        )(inputs)

        # FM part
        fm_logit = FM()(emb_fm)

        # DNN part
        dnn_input = Flatten()(emb_fm)
        dnn_output = DNN([hidden_units_num] * hidden_layers_num, activation=dnn_activation_function, l2_reg=l2_reg_nn,
                    dropout_rate=dnn_dropout, seed=SEED)(dnn_input)
        # dnn_logit = Dense(1, use_bias=False, activation=None, kernel_initializer="ones")(dnn_output)
        dnn_logit = Dense(1, use_bias=False, activation=None, kernel_initializer=GlorotUniform(seed=SEED))(dnn_output)

        # Add together
        final_logit = Add()([linear_logit, fm_logit, dnn_logit])

        # Prediction layer
        pred = tf.reshape(PredictionLayer('binary')(final_logit), [-1], name=prediction_output_name)
        loss = tf.identity(tf.compat.v1.losses.log_loss(label, pred), name='LOSS')
        opt = build_optimizer(learning_rate, optimizer)
        train_op = opt.minimize(loss)  # renamed to avoid shadowing the optimizer-name string

        init_all_vars = tf.compat.v1.global_variables_initializer()
        saver = tf.compat.v1.train.Saver()

        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        session = tf.compat.v1.Session(graph=graph, config=config)

        session.run(init_all_vars)
        session.run(training_init_op)
        summarize_weights(session)

        while True:
            try:
                inp, predictions, _ = session.run([inputs, pred, train_op])
                summarize_weights(session)

                predictions_saver(prediction_path, predictions)
            except tf.errors.OutOfRangeError:
                break

        session.close()

Could you help me understand why I don't get exactly the same predictions across consecutive runs?

ornellamarciano changed the title from "Don't get determinist results with NGC 20.09" to "Don't get deterministic results with NGC 20.09" on Oct 13, 2020
duncanriach (Collaborator) commented Oct 14, 2020

That's a very thorough report, @ornellamarciano; thank you. Having scanned your code, what stands out to me is your use of tf.compat.v1.keras.layers.Embedding. We have isolated a source of nondeterminism to the backprop into embedding matrices (in TensorFlow) and we're actively working on a patch to address this. We plan to release soon, at which point I'll ask you to test again with the patch.

Deeper information: the forward path of selecting word embeddings for a given batch is usually implemented with tf.gather (and I've confirmed that this is true for tf.compat.v1.keras.layers.Embedding). The backprop for tf.gather produces a tf.IndexedSlices gradient tensor which is reduced to a dense tensor, in order to update the embedding matrix, using tf.convert_to_tensor. tf.convert_to_tensor uses one of the segment sum ops, which we have confirmed operates nondeterministically. The upcoming patch, released in version 0.4.0 of the framework-determinism package and applied using fwd9m.tensorflow.enable_determinism(), will dynamically patch the segment sum ops to make them operate deterministically for 16-bit floating-point and 32-bit floating-point operands. The patch will not address nondeterminism in 64-bit floating-point operands; that will be addressed by a later CUDA-level fix (I don't know when we'll get to that).
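
A minimal sketch of the mechanism described above (not from the original comment; shapes and names are arbitrary), showing that the gradient of tf.gather arrives as a tf.IndexedSlices and that densifying it amounts to an unsorted segment sum:

import tensorflow as tf

params = tf.Variable(tf.random.normal([10, 4]))  # toy embedding matrix
indices = tf.constant([1, 3, 3, 7])              # a repeated index forces accumulation

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.gather(params, indices) ** 2)

grad = tape.gradient(loss, params)
print(type(grad).__name__)  # IndexedSlices

# Densifying the sparse gradient is an unsorted segment sum -- the op that,
# before the patch, ran nondeterministically on GPU:
dense_grad = tf.math.unsorted_segment_sum(
    grad.values, grad.indices, num_segments=params.shape[0])

Once released, the patch would be applied before building the model; a sketch based on the entry point named above (the import style is an assumption):

from fwd9m import tensorflow as fwd9m_tensorflow
fwd9m_tensorflow.enable_determinism()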

Thanks to my colleague @wenscarl for helping to isolate the source of nondeterminism and develop the patch for it.

ornellamarciano (Author):

@duncanriach @wenscarl Thanks a lot for your quick and detailed answer!
Regarding float64, I tried it to escape nondeterminism from a potential floating-point issue. But if the new patch gives me deterministic results, I will happily switch back to float32 (faster and lower memory consumption).
Thanks again.

duncanriach (Collaborator):

Right, I would expect you to see a much smaller amount of noise accumulating when using float64, which seems to be what you witnessed.

ornellamarciano (Author):

Hi @duncanriach, do you know when the patch that fixes the nondeterminism of embedding layers will be released? Thanks for your help!

duncanriach (Collaborator):

I don't have an ETA for you, but I can tell you that we're actively working on it and that it's a top priority. I'm hoping for a release in the next few weeks.

ornellamarciano (Author):

Hi @duncanriach, do you have any update? Thanks a lot!

duncanriach (Collaborator) commented Jul 7, 2021

Hi @ornellamarciano,

Thanks for checking in.

From stock TensorFlow version 2.5 onwards, using TensorFlow's segment reduction ops on a GPU with the expectation of reproducibility (i.e. when TF_DETERMINISTIC_OPS is set to '1' or 'true') will cause tf.errors.UnimplementedError to be thrown. This includes back-propagating into a dense embedding matrix via the backward path through tf.gather (which uses tf.math.unsorted_segment_sum). See stock TensorFlow pull request 47772 for more information.
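
For illustration, a minimal sketch (not from the original comment) of that failure mode, assuming stock TF 2.5/2.6 and a visible GPU:

import os
os.environ["TF_DETERMINISTIC_OPS"] = "1"  # request determinism before TensorFlow initializes
import tensorflow as tf

data = tf.random.normal([4, 8])
segment_ids = tf.constant([0, 1, 1, 0])
try:
    with tf.device("/GPU:0"):
        tf.math.unsorted_segment_sum(data, segment_ids, num_segments=2)
except tf.errors.UnimplementedError as err:
    print("No deterministic GPU kernel available:", err)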

@benbarsdell has developed deterministic GPU implementations of all the TensorFlow segment reduction ops, including tf.math.unsorted_segment_sum. A pull request to add them to stock TensorFlow will be created soon (it's waiting on some dependencies). I expect these changes to be released in stock TensorFlow version 2.7.

duncanriach (Collaborator):

You might also want to try applying the unreleased patch from the cloned repo. See the discussion in issue 19.
