ordered_enqueuer_cf.py: use semaphore instead of queue and modified pool shutdown
jeanollion committed Apr 23, 2024
1 parent 30943b0 commit 47c9149
Showing 2 changed files with 44 additions and 30 deletions.
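
For reference, the core pattern introduced by this commit is a threading.BoundedSemaphore that bounds how many futures have been submitted to a ProcessPoolExecutor but not yet consumed, replacing the old blocking queue.Queue. The sketch below is a standalone illustration of that pattern under that reading of the diff; process_item, iterate, and MAX_IN_FLIGHT are hypothetical names, not identifiers from this repository.

from concurrent.futures import ProcessPoolExecutor
from threading import BoundedSemaphore

MAX_IN_FLIGHT = 10  # hypothetical bound, playing the role of max_queue_size

def process_item(i):
    # stand-in for the real task (get_item / get_item_shm in the repository)
    return i * i

def iterate(items):
    semaphore = BoundedSemaphore(MAX_IN_FLIGHT)
    pending = []  # a plain list of futures replaces queue.Queue
    with ProcessPoolExecutor(max_workers=4) as executor:
        for i in items:
            semaphore.acquire()  # blocks once MAX_IN_FLIGHT results are unconsumed
            pending.append(executor.submit(process_item, i))
            while pending and pending[0].done():
                result = pending.pop(0).result()
                semaphore.release()  # free a slot only after the result is consumed
                yield result
        while pending:  # drain the tail of the epoch
            result = pending.pop(0).result()
            semaphore.release()
            yield result

if __name__ == "__main__":
    for r in iterate(range(20)):
        print(r)

Acquiring before submit and releasing only after the result is consumed keeps at most MAX_IN_FLIGHT samples in memory, which is the behaviour the commit message describes.
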
72 changes: 43 additions & 29 deletions dataset_iterator/ordered_enqueuer_cf.py
@@ -7,6 +7,7 @@
import threading
import time
from multiprocessing import managers, shared_memory
from threading import BoundedSemaphore

# adapted from https://github.com/keras-team/keras/blob/v2.13.1/keras/utils/data_utils.py#L651-L776
# uses concurrent.futures; solves a memory leak when hard sample mining is run as a callback with the regular OrderedEnqueuer. Option to pass tensors through shared memory
@@ -49,6 +50,7 @@ def __init__(self, sequence, shuffle=False, single_epoch:bool=False, use_shm:boo
self.run_thread = None
self.stop_signal = None
self.shm_manager = None
self.semaphore = None

def is_running(self):
return self.stop_signal is not None and not self.stop_signal.is_set()
@@ -62,7 +64,10 @@ def start(self, workers=1, max_queue_size=10):
(when full, workers could block on `put()`)
"""
self.workers = workers
self.queue = queue.Queue(max_queue_size)
if max_queue_size <= 0:
max_queue_size = self.workers
self.semaphore = BoundedSemaphore(max_queue_size)
self.queue = []
self.stop_signal = threading.Event()
if self.use_shm:
self.shm_manager = managers.SharedMemoryManager()
@@ -71,12 +76,16 @@
self.run_thread.daemon = True
self.run_thread.start()

def _wait_queue(self):
def _wait_queue(self, empty:bool):
"""Wait for the queue to be empty."""
while True:
time.sleep(0.1)
if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
if (empty and len(self.queue) == 0) or (not empty and len(self.queue) > 0) or self.stop_signal.is_set():
return
time.sleep(0.1)

def _task_done(self, _):
"""Called once task is done, releases the queue if blocked."""
self.semaphore.release()

def _run(self):
"""Submits request to the executor and queue the `Future` objects."""
@@ -87,15 +96,19 @@ def _run(self):
if self.shuffle:
random.shuffle(sequence)
task = get_item_shm if self.use_shm else get_item
with ProcessPoolExecutor(max_workers=self.workers, initializer=init_pool_generator, initargs=(self.sequence, self.uid, self.shm_manager)) as executor:
for i in sequence:
if self.stop_signal.is_set():
return
future = executor.submit(task, self.uid, i)
self.queue.put(future, block=True)
# Done with the current epoch, waiting for the final batches
self._wait_queue()

executor = ProcessPoolExecutor(max_workers=self.workers, initializer=init_pool_generator, initargs=(self.sequence, self.uid, self.shm_manager))
for idx, i in enumerate(sequence):
if self.stop_signal.is_set():
return
self.semaphore.acquire()
future = executor.submit(task, self.uid, i)
self.queue.append((future, i))
# Done with the current epoch, waiting for the final batches
self._wait_queue(True) # safer to wait for pending futures to drain before shutdown than to call shutdown with wait=True directly
print("exiting from ProcessPoolExecutor...", flush=True)
time.sleep(0.1)
executor.shutdown(wait=False, cancel_futures=True)
print("exiting from ProcessPoolExecutor done", flush=True)
if self.stop_signal.is_set() or self.single_epoch:
# We're done
return
@@ -124,20 +137,23 @@ def get(self):
`(inputs, targets, sample_weights)`.
"""
while self.is_running():
try:
inputs = self.queue.get(block=True, timeout=5).result()
if self.is_running():
self.queue.task_done()
if inputs is not None:
self._wait_queue(False)
if len(self.queue) > 0:
future, i = self.queue[0]
try:
inputs = future.result()
self.queue.pop(0) # only remove after result() is called to avoid terminating pool while a process is still running
if self.use_shm:
inputs = from_shm(*inputs)
yield inputs
except queue.Empty:
pass
except Exception as e:
self.stop()
print("Exception raised while getting future", flush=True)
raise e
self.semaphore.release() # release is done here and not as a future callback to limit effective number of samples in memory
except Exception as e:
self.stop()
print(f"Exception raised while getting future result from task: {i}", flush=True)
raise e
finally:
future.cancel()
del future
yield inputs

def stop(self, timeout=None):
"""Stops running threads and wait for them to exit, if necessary.
@@ -148,14 +164,12 @@
timeout: maximum time to wait on `thread.join()`
"""
self.stop_signal.set()
with self.queue.mutex:
self.queue.queue.clear()
self.queue.unfinished_tasks = 0
self.queue.not_full.notify()
self.run_thread.join(timeout)
if self.use_shm:
self.shm_manager.shutdown()
self.shm_manager.join()
self.queue = None
self.semaphore = None
global _SHARED_SHM_MANAGER
_SHARED_SHM_MANAGER[self.uid] = None
global _SHARED_SEQUENCES
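The "modified pool shutdown" half of the change trades the implicit shutdown(wait=True) of a with-block for an explicit drain-then-shutdown sequence, so the pool is never torn down while a worker still holds an unconsumed result. A minimal sketch of that ordering, assuming Python 3.9+ for cancel_futures; run_epoch, square, and stop_event are hypothetical names, not identifiers from this repository.

import time
from concurrent.futures import ProcessPoolExecutor
from threading import Event

def square(x):
    return x * x

def run_epoch(items, stop_event):
    executor = ProcessPoolExecutor(max_workers=2)
    futures = [executor.submit(square, i) for i in items]
    try:
        # wait until every future has finished or a stop was requested,
        # the role _wait_queue(True) plays in the enqueuer
        while not stop_event.is_set() and not all(f.done() for f in futures):
            time.sleep(0.1)
    finally:
        # shut down without blocking; cancel anything that never started
        executor.shutdown(wait=False, cancel_futures=True)
    return [f.result() for f in futures if f.done() and not f.cancelled()]

if __name__ == "__main__":
    print(run_epoch(range(8), Event()))
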
2 changes: 1 addition & 1 deletion setup.py
@@ -12,7 +12,7 @@
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/jeanollion/dataset_iterator.git",
download_url='https://github.com/jeanollion/dataset_iterator/releases/download/v0.4.0/dataset_iterator-0.4.1.tar.gz',
download_url='https://github.com/jeanollion/dataset_iterator/releases/download/v0.4.1/dataset_iterator-0.4.1.tar.gz',
keywords=['Iterator', 'Dataset', 'Image', 'Numpy'],
packages=setuptools.find_packages(),
classifiers=[ #https://pypi.org/classifiers/
