This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

slow, single threaded #459

Open
bbudescu opened this issue Mar 18, 2019 · 3 comments

Comments

@bbudescu

bbudescu commented Mar 18, 2019

Hi, I've been trying to use nGraph to accelerate my TensorFlow detector/testing pipeline, but so far without any success. Inference either runs at the same speed or becomes painfully slow.

I'm not quite sure whether I'm installing and using nGraph correctly.

I'm also not sure whether this is the right place to ask, since it might be something obvious I've missed rather than an actual issue, but I couldn't find any other support channel. If there is a more appropriate one, please point me to it.

For installation, I used pip inside my own Dockerfile to install ngraph-tensorflow-bridge, following the instructions in this repo. I also installed PlaidML, since I noticed nGraph looks for it during initialization. I didn't build the nGraph library myself, because the bridge ships its own .so and doesn't complain when loading it.

I've also tried turning on XLA, but it has no effect:

```python
import tensorflow as tf

# Enable XLA JIT compilation globally (TF 1.x API)
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
tf.keras.backend.set_session(tf.Session(config=config))
```

I've also tested nGraph with intel-tensorflow. With nGraph off, Intel TF is about twice as fast as vanilla TF. With the nGraph bridge imported, performance is very poor: an operation that takes a few seconds without nGraph ran so long that I gave up waiting for it to finish.
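
To make comparisons like "about twice as fast" repeatable across the vanilla, Intel, and nGraph-enabled setups, one option is to time the same call several times and report the median, discarding a warm-up run. This is a minimal stdlib sketch: `median_runtime` is a hypothetical helper, and the idea that the first call may carry one-off setup cost (such as graph construction) is an assumption, not something confirmed in this thread.

```python
import statistics
import time
from typing import Callable


def median_runtime(fn: Callable[[], object], repeats: int = 5, warmup: int = 1) -> float:
    """Return the median wall-clock time of fn() in seconds.

    Hypothetical helper: warm-up calls are run but not timed, on the
    assumption that the first run may include one-off setup work.
    """
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


if __name__ == "__main__":
    # Stand-in workload; in practice pass e.g. lambda: model.predict(batch).
    workload = lambda: sum(i * i for i in range(200_000))
    print(f"median: {median_runtime(workload):.4f} s")
```

Running the same script once with and once without `import ngraph_bridge` and comparing the two medians would give a like-for-like number instead of an impression.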

I've also tried both 'NCHW' and 'NHWC' with both the vanilla and the Intel distributions of TensorFlow.

For usage, I only added `import ngraph_bridge` after importing tensorflow. Is there anything else I'm supposed to do?

I didn't get any stdout/stderr message that would tell me whether nGraph is actually active. I looked through the output of tensorflow.python.client.device_lib.list_local_devices(), but nothing changes when I add the import. The only sign that nGraph is in use is that I get an error message unless I disable my GPU (os.environ["CUDA_VISIBLE_DEVICES"] = "").

Here is the code I used to test nGraph (it's based on the Keras example in this repo). I think the longest I waited to see any training progress was 10 minutes; without nGraph, the progress bar updates in under half a minute.

```python
import os

import numpy as np

# Environment variables are set before TensorFlow is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # hide the GPU
os.environ["KMP_BLOCKTIME"] = "0"
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["KMP_AFFINITY"] = "granularity=fine,verbose,compact,1,0"
os.environ['KERAS_BACKEND'] = 'tensorflow'

import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions
import ngraph_bridge  # the only nGraph-specific change to the script

# A simple script to run inference and training on ResNet-50

config = tf.ConfigProto()
config.intra_op_parallelism_threads = 4
config.inter_op_parallelism_threads = 4

# config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

tf.keras.backend.set_session(tf.Session(config=config))
# tf.keras.backend.set_image_data_format('channels_first')

model = ResNet50(weights=None)  # random weights; only speed is of interest

batch_size = 128
img = np.random.rand(batch_size, 224, 224, 3)  # channels_last (NHWC)
# img = np.random.rand(batch_size, 3, 224, 224)  # channels_first (NCHW)

preds = model.predict(preprocess_input(img))
print('Predicted:', decode_predictions(preds, top=3)[0])
model.compile(tf.keras.optimizers.SGD(), loss='categorical_crossentropy')
# fit() returns a History object, not predictions
history = model.fit(
    preprocess_input(img), np.zeros((batch_size, 1000), dtype='float32'))
print('Ran a train round')
```

I've also tried nGraph with different code that doesn't use Keras (it uses TensorFlow's object_detection API instead). There, speed is at least 20% lower with nGraph. For some models, the process also had a much larger memory footprint with nGraph than without it (I noticed because my laptop started paging and crashed).
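
To put a number on the memory difference before the machine starts paging, the process's peak resident set size can be read after the model runs, once with and once without the bridge imported. A minimal Unix-only sketch using the standard `resource` module; `peak_rss_mib` is a hypothetical helper name.

```python
import resource
import sys


def peak_rss_mib() -> float:
    """Peak resident set size of this process so far, in MiB (Unix only).

    ru_maxrss is reported in kilobytes on Linux but in bytes on macOS.
    """
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if sys.platform == "darwin":
        return peak / (1024 * 1024)
    return peak / 1024


if __name__ == "__main__":
    before = peak_rss_mib()
    data = b"\x01" * (64 * 1024 * 1024)  # allocate and touch ~64 MiB
    after = peak_rss_mib()
    print(f"peak RSS: {before:.1f} MiB -> {after:.1f} MiB")
```

Because the value is a high-water mark, it never decreases, so it captures a short spike during graph setup as well as steady-state usage.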

I've also noticed that while training the Keras example, only one of my CPU's 8 logical cores is used. The same happens when running inference with the detection model, although for part of the time more than one core is saturated.
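
One way to check the "only one of the 8 logical cores" observation from inside the script, rather than by watching a system monitor, is to compare process CPU time with wall-clock time around the slow call. A minimal sketch, assuming the call runs in the current process; `effective_cores` is a hypothetical helper.

```python
import time
from typing import Callable


def effective_cores(fn: Callable[[], object]) -> float:
    """Estimate how many cores fn() kept busy on average.

    time.process_time() sums user + system CPU time over all threads of
    this process, so CPU time divided by wall-clock time is roughly the
    average number of busy cores: ~1.0 for single-threaded work, closer
    to N when N cores are saturated.
    """
    wall_start = time.perf_counter()
    cpu_start = time.process_time()
    fn()
    cpu = time.process_time() - cpu_start
    wall = time.perf_counter() - wall_start
    return cpu / wall if wall > 0 else 0.0


if __name__ == "__main__":
    # Pure-Python loop: limited to one core by the GIL, so expect about 1.0.
    busy = lambda: sum(i * i for i in range(2_000_000))
    print(f"~{effective_cores(busy):.2f} cores busy")
```

Wrapping `model.fit(...)` or the detector's inference call this way, with and without `import ngraph_bridge`, would show whether the extra cores really go idle.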

Thanks

@avijit-nervana
Contributor

Thanks for reporting this issue and providing the details. We will reproduce this and get back to you.

@bbudescu
Author

Is this something that can be reproduced, i.e., a known bug, or did I just do something seriously stupid, like an erroneous config or install?

@jaebaek

jaebaek commented Apr 25, 2019

If you run the following code with the TF bridge, it does not use the GPU but does use multiple CPU threads.
I am not sure I built nGraph + TF correctly, but on my setup it is 10 times slower than plain TF on the CPU.

```python
import tensorflow as tf
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(512, activation=tf.nn.relu),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)
```

Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants