
Commit

- memoryIO.py: allow loading of dataset in shared memory, compatible with multiprocessing iteration

- version 0.4.2
jeanollion committed Apr 29, 2024
1 parent 948284c commit 4263158
Showing 3 changed files with 96 additions and 9 deletions.
92 changes: 89 additions & 3 deletions dataset_iterator/datasetIO/memoryIO.py
@@ -1,25 +1,78 @@
from .datasetIO import DatasetIO
import threading
import numpy as np
from multiprocessing import managers, shared_memory, Value
from ..shared_memory import to_shm, get_idx_from_shm

_MEMORYIO_SHM_MANAGER = {}
_MEMORYIO_UID = None

class MemoryIO(DatasetIO):
def __init__(self, datasetIO:DatasetIO):
def __init__(self, datasetIO: DatasetIO, use_shm: bool = True):
super().__init__()
self.datasetIO = datasetIO
self.__lock__ = threading.Lock()
self.datasets=dict()
self.datasets = dict()
self.use_shm = use_shm
global _MEMORYIO_UID
if _MEMORYIO_UID is None:
try:
_MEMORYIO_UID = Value("i", 0)
except OSError: # In this case the OS does not allow us to use multiprocessing. We resort to an int for indexing.
_MEMORYIO_UID = 0

if isinstance(_MEMORYIO_UID, int):
self.uid = _MEMORYIO_UID
_MEMORYIO_UID += 1
else:
# Incrementing a multiprocessing.Value with += is not process-safe.
with _MEMORYIO_UID.get_lock():
self.uid = _MEMORYIO_UID.value
_MEMORYIO_UID.value += 1
if use_shm:
self._start_shm_manager()

def _start_shm_manager(self):
global _MEMORYIO_SHM_MANAGER
_MEMORYIO_SHM_MANAGER[self.uid] = managers.SharedMemoryManager()
_MEMORYIO_SHM_MANAGER[self.uid].start()
self.shm_manager_on = True

def _stop_shm_manager(self):
global _MEMORYIO_SHM_MANAGER
if _MEMORYIO_SHM_MANAGER[self.uid] is not None:
_MEMORYIO_SHM_MANAGER[self.uid].shutdown()
_MEMORYIO_SHM_MANAGER[self.uid].join()
_MEMORYIO_SHM_MANAGER[self.uid] = None
self.shm_manager_on = False

def _to_shm(self, array):
global _MEMORYIO_SHM_MANAGER
shapes, dtypes, shm_name, _ = to_shm(_MEMORYIO_SHM_MANAGER[self.uid], array)
return shapes[0], dtypes[0], shm_name

def close(self):
if self.use_shm:
for shma in self.datasets.values():
shma.unlink()
self.datasets.clear()
self.datasetIO.close()
if self.use_shm and self.shm_manager_on:
self._stop_shm_manager()

def get_dataset_paths(self, channel_keyword, group_keyword):
return self.datasetIO.get_dataset_paths(channel_keyword, group_keyword)

def get_dataset(self, path):
if path not in self.datasets:
with self.__lock__:
if self.use_shm and not self.shm_manager_on:
self._start_shm_manager()
if path not in self.datasets:
self.datasets[path] = self.datasetIO.get_dataset(path)[:] # load into memory
if self.use_shm:
self.datasets[path] = ShmArrayWrapper(*self._to_shm(self.datasetIO.get_dataset(path)[:]))
else:
self.datasets[path] = ArrayWrapper(self.datasetIO.get_dataset(path)[:]) # load into memory
return self.datasets[path]

def get_attribute(self, path, attribute_name):
@@ -36,3 +89,36 @@ def __contains__(self, key):

def get_parent_path(self, path):
return self.datasetIO.get_parent_path(path)


class ArrayWrapper:
def __init__(self, array):
self.array = array
self.shape = array.shape

def __getitem__(self, item):
return np.copy(self.array[item])

def __len__(self):
return self.shape[0]


class ShmArrayWrapper:
def __init__(self, shape, dtype, shm_name):
self.shape = shape
self.dtype = dtype
self.shm_name = shm_name

def __getitem__(self, item):
assert isinstance(item, (int, np.integer)), f"only integer indexing is supported; received: {item} of type: {type(item)}"
return get_idx_from_shm(item, (self.shape,), (self.dtype,), self.shm_name, array_idx=0)

def __len__(self):
return self.shape[0]

def unlink(self):
try:
existing_shm = shared_memory.SharedMemory(self.shm_name)
existing_shm.unlink()
except Exception:
pass
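
For context, a minimal, self-contained sketch of the shared-memory round trip that the new use_shm path builds on. The actual helpers (to_shm and get_idx_from_shm in dataset_iterator/shared_memory.py) are not part of this diff, so the snippet below only illustrates the underlying stdlib mechanism with made-up data; names such as read_row are illustrative, not library API.

import numpy as np
from multiprocessing import Process, Queue, shared_memory
from multiprocessing.managers import SharedMemoryManager

def read_row(shm_name, shape, dtype, idx, out):
    # Worker: re-attach to the named segment and copy out a single sample.
    shm = shared_memory.SharedMemory(name=shm_name)
    try:
        arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
        out.put(np.copy(arr[idx]))  # copy so the buffer can be released
    finally:
        shm.close()

if __name__ == "__main__":
    data = np.arange(12, dtype=np.float32).reshape(4, 3)  # stand-in dataset
    with SharedMemoryManager() as smm:                     # owns the segment lifetime
        shm = smm.SharedMemory(size=data.nbytes)
        np.ndarray(data.shape, dtype=data.dtype, buffer=shm.buf)[:] = data
        out = Queue()
        p = Process(target=read_row, args=(shm.name, data.shape, data.dtype, 1, out))
        p.start()
        print(out.get())  # [3. 4. 5.]
        p.join()

In MemoryIO, the manager-owned segment is wrapped by ShmArrayWrapper, so worker processes only need the segment name, shape and dtype to fetch individual samples instead of receiving a pickled copy of the whole dataset.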
9 changes: 5 additions & 4 deletions dataset_iterator/hard_sample_mining.py
@@ -40,11 +40,12 @@ def close(self):
self.iterator.close()

def on_epoch_begin(self, epoch, logs=None):
self.wait_for_me.clear() # will block
if self.proba_per_metric is not None:
self.wait_for_me.clear() # will block

def on_epoch_end(self, epoch, logs=None):
if self.period==1 or (epoch + 1 + self.start_epoch) % self.period == 0:
if (epoch > 0 or not self.skip_first) and epoch + self.start_epoch >= self.start_from_epoch:
if self.period == 1 or (epoch + 1 + self.start_epoch) % self.period == 0:
if (epoch > 0 or not self.skip_first) and epoch + 1 + self.start_epoch >= self.start_from_epoch:
self.target_iterator.close()
self.iterator.open()
metrics = self.compute_metrics()
@@ -63,7 +64,7 @@ def on_epoch_end(self, epoch, logs=None):
proba = self.proba_per_metric
# set the probability on the iterator; with multiprocessing (e.g. OrderedEnqueuer) this is only taken into account at the next epoch, as the iterator has already been dispatched to the worker processes at this stage
self.target_iterator.set_index_probability(proba)
self.wait_for_me.set() # release block
self.wait_for_me.set() # release block

def on_train_end(self, logs=None):
self.close()
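
The condition change above makes the start test use the same 1-based epoch count as the period test (epoch + 1 + start_epoch instead of epoch + start_epoch), and the wait_for_me barrier is now only armed when mining probabilities actually exist. Below is a small illustrative re-statement of the trigger with assumed values; only the attribute names come from the diff, the helper function itself is hypothetical.

def mining_triggered(epoch, period, start_epoch, start_from_epoch, skip_first):
    # epoch is the 0-based index Keras passes to callbacks; start_epoch offsets it on resume.
    on_period = period == 1 or (epoch + 1 + start_epoch) % period == 0
    past_start = (epoch > 0 or not skip_first) and epoch + 1 + start_epoch >= start_from_epoch
    return on_period and past_start

# With period=1, start_epoch=0, start_from_epoch=3, skip_first=False, mining now first
# fires at the end of epoch index 2 (the 3rd epoch); before this fix it waited one epoch more.
assert [e for e in range(6) if mining_triggered(e, 1, 0, 3, False)] == [2, 3, 4, 5]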
4 changes: 2 additions & 2 deletions setup.py
@@ -5,14 +5,14 @@

setuptools.setup(
name="dataset_iterator",
version="0.4.1",
version="0.4.2",
author="Jean Ollion",
author_email="jean.ollion@polytechnique.org",
description="Keras-style data iterator for images contained in dataset files such as hdf5 or PIL readable files. Images can be contained in several files.",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/jeanollion/dataset_iterator.git",
download_url='https://github.com/jeanollion/dataset_iterator/releases/download/v0.4.1/dataset_iterator-0.4.1.tar.gz',
download_url='https://github.com/jeanollion/dataset_iterator/releases/download/v0.4.2/dataset_iterator-0.4.2.tar.gz',
keywords=['Iterator', 'Dataset', 'Image', 'Numpy'],
packages=setuptools.find_packages(),
classifiers=[ #https://pypi.org/classifiers/
