Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes thread safety in CustomObjectDetection not working #418

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 13 additions & 8 deletions imageai/Detection/Custom/__init__.py
Expand Up @@ -13,6 +13,7 @@
from imageai.Detection.Custom.utils.multi_gpu_model import multi_gpu_model
from imageai.Detection.Custom.gen_anchors import generateAnchors
import tensorflow as tf
from tensorflow.python.keras.backend import set_session
from keras.models import load_model, Input
from keras.callbacks import TensorBoard
import keras.backend as K
Expand Down Expand Up @@ -594,6 +595,8 @@ def __init__(self):
self.__input_size = 416
self.__object_threshold = 0.4
self.__nms_threshold = 0.4
self.__session = None
self.__graph = None
self.__model = None
self.__detection_utils = CustomDetectionUtils(labels=[])

Expand Down Expand Up @@ -621,7 +624,7 @@ def setJsonPath(self, configuration_json):
"""
self.__detection_config_json_path = configuration_json

def loadModel(self):
def loadModel(self, session_config=None):

"""
'loadModel' is used to load the model into the CustomObjectDetection class
Expand All @@ -636,13 +639,19 @@ def loadModel(self):

self.__detection_utils = CustomDetectionUtils(labels=self.__model_labels)

self.__session = tf.compat.v1.Session(config=session_config)

self.__graph = tf.compat.v1.get_default_graph()

self.__model = yolo_main(Input(shape=(None, None, 3)), 3, len(self.__model_labels))

set_session(self.__session)

self.__model.load_weights(self.__model_path)

def detectObjectsFromImage(self, input_image="", output_image_path="", input_type="file", output_type="file",
extract_detected_objects=False, minimum_percentage_probability=50, nms_treshold=0.4,
display_percentage_probability=True, display_object_name=True, thread_safe=False):
display_percentage_probability=True, display_object_name=True):

"""

Expand All @@ -656,7 +665,6 @@ def detectObjectsFromImage(self, input_image="", output_image_path="", input_typ
* nms_threshold (optional, o.45 by default) , option to set the Non-maximum suppression for the detection
* display_percentage_probability (optional, True by default), option to show or hide the percentage probability of each object in the saved/returned detected image
* display_display_object_name (optional, True by default), option to show or hide the name of each object in the saved/returned detected image
* thread_safe (optional, False by default), enforce the loaded detection model works across all threads if set to true, made possible by forcing all Keras inference to run on the default graph


The values returned by this function depends on the parameters parsed. The possible values returnable
Expand Down Expand Up @@ -708,7 +716,6 @@ def detectObjectsFromImage(self, input_image="", output_image_path="", input_typ
:param nms_treshold:
:param display_percentage_probability:
:param display_object_name:
:param thread_safe:
:return image_frame:
:return output_objects_array:
:return detected_objects_image_array:
Expand Down Expand Up @@ -763,10 +770,8 @@ def detectObjectsFromImage(self, input_image="", output_image_path="", input_typ
image = np.expand_dims(image, 0)

if self.__model_type == "yolov3":
if thread_safe == True:
with K.get_session().graph.as_default():
yolo_results = self.__model.predict(image)
else:
with self.__graph.as_default():
set_session(self.__session)
yolo_results = self.__model.predict(image)

boxes = list()
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,5 +1,5 @@
tensorflow
keras
tensorflow==1.14.0
keras==2.2.5
numpy
pillow
scipy
Expand Down
8 changes: 7 additions & 1 deletion setup.py
Expand Up @@ -8,7 +8,13 @@
author_email='guymodscientist@gmail.com',
license='MIT',
packages= find_packages(),
install_requires=['numpy','scipy','pillow',"matplotlib", "h5py"],
install_requires=['numpy',
'scipy',
'pillow',
'matplotlib',
'h5py',
'tensorflow==1.14.0',
'keras==2.2.5'],
zip_safe=False

)