bugfix hardsample mining
jeanollion committed May 3, 2024
1 parent 4263158 commit c4adaee
Showing 2 changed files with 19 additions and 16 deletions.
31 changes: 17 additions & 14 deletions dataset_iterator/hard_sample_mining.py
@@ -7,12 +7,11 @@
 import threading
 
 class HardSampleMiningCallback(tf.keras.callbacks.Callback):
-    def __init__(self, iterator, target_iterator, predict_fun, metrics_fun, period:int=10, start_epoch:int=0, skip_first:bool=False, start_from_epoch:int=0, enrich_factor:float=10., quantile_max:float=0.99, quantile_min:float=None, disable_channel_postprocessing:bool=False, workers=None, verbose:int=1):
+    def __init__(self, iterator, target_iterator, predict_fun, metrics_fun, period:int=10, start_epoch:int=0, start_from_epoch:int=0, enrich_factor:float=10., quantile_max:float=0.99, quantile_min:float=None, disable_channel_postprocessing:bool=False, workers=None, verbose:int=1):
         super().__init__()
         self.period = period
         self.start_epoch = start_epoch
         self.start_from_epoch = start_from_epoch
-        self.skip_first = skip_first
         self.iterator = iterator
         self.target_iterator = target_iterator
         self.predict_fun = predict_fun
@@ -32,39 +31,43 @@ def __init__(self, iterator, target_iterator, predict_fun, metrics_fun, period:i
         self.n_batches = len(simple_iterator)
         self.enq = OrderedEnqueuerCF(simple_iterator, single_epoch=True, shuffle=False)
         self.wait_for_me = threading.Event()
+        self.wait_for_me.set()
 
     def close(self):
         self.enq.stop()
         if self.data_aug_param is not None:
             self.iterator.enable_random_transforms(self.data_aug_param)
         self.iterator.close()
 
+    def need_compute(self, epoch):
+        return epoch + 1 + self.start_epoch >= self.start_from_epoch and (self.period == 1 or (epoch + 1 + self.start_epoch - self.start_from_epoch) % self.period == 0)
+
     def on_epoch_begin(self, epoch, logs=None):
-        if self.proba_per_metric is not None:
-            self.wait_for_me.clear() # will block
+        if self.proba_per_metric is not None or self.need_compute(epoch):
+            self.wait_for_me.clear() # will lock
 
     def on_epoch_end(self, epoch, logs=None):
-        if self.period == 1 or (epoch + 1 + self.start_epoch) % self.period == 0:
-            if (epoch > 0 or not self.skip_first) and epoch + 1 + self.start_epoch >= self.start_from_epoch:
+        if self.need_compute(epoch):
             if self.target_iterator is not self.iterator:
                 self.target_iterator.close()
                 self.iterator.open()
-                metrics = self.compute_metrics()
+            metrics = self.compute_metrics()
+            if self.target_iterator is not self.iterator:
                 self.iterator.close()
                 self.target_iterator.open()
-                first = self.proba_per_metric is None
-                self.proba_per_metric = get_index_probability(metrics, enrich_factor=self.enrich_factor, quantile_max=self.quantile_max, quantile_min=self.quantile_min, verbose=self.verbose)
-                self.n_metrics = self.proba_per_metric.shape[0] if len(self.proba_per_metric.shape) == 2 else 1
-                if first and self.n_metrics > self.period:
-                    warnings.warn(f"Hard sample mining period = {self.period} should be greater than metric number = {self.n_metrics}")
+            first = self.proba_per_metric is None
+            self.proba_per_metric = get_index_probability(metrics, enrich_factor=self.enrich_factor, quantile_max=self.quantile_max, quantile_min=self.quantile_min, verbose=self.verbose)
+            self.n_metrics = self.proba_per_metric.shape[0] if len(self.proba_per_metric.shape) == 2 else 1
+            if first and self.n_metrics > self.period:
+                warnings.warn(f"Hard sample mining period = {self.period} should be greater than metric number = {self.n_metrics}")
         if self.proba_per_metric is not None:
             if len(self.proba_per_metric.shape) == 2:
                 self.metric_idx = (self.metric_idx + 1) % self.n_metrics
                 proba = self.proba_per_metric[self.metric_idx]
             else:
                 proba = self.proba_per_metric
             # set probability on the iterator; with multiprocessing and OrderedEnqueuer this is only taken into account at the next epoch, as the iterator has already been sent to the processes at this stage
             self.target_iterator.set_index_probability(proba)
-        self.wait_for_me.set() # release block
+        self.wait_for_me.set() # release lock
 
     def on_train_end(self, logs=None):
         self.close()
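For context, the scheduling rule that this commit centralizes in need_compute can be illustrated in isolation. The sketch below simply restates the predicate from the diff as a standalone function, so that on_epoch_begin and on_epoch_end visibly agree on when a mining pass runs; the parameter values in the demo are illustrative and not taken from the commit.

def need_compute(epoch: int, period: int = 10, start_epoch: int = 0, start_from_epoch: int = 0) -> bool:
    # Mirrors HardSampleMiningCallback.need_compute from the diff above:
    # mining starts once start_from_epoch is reached, then repeats every `period` epochs.
    return epoch + 1 + start_epoch >= start_from_epoch and (
        period == 1 or (epoch + 1 + start_epoch - start_from_epoch) % period == 0
    )

# Example: with period=10 and start_from_epoch=5, mining is triggered at epochs 4, 14, 24, ...
print([e for e in range(30) if need_compute(e, period=10, start_from_epoch=5)])  # [4, 14, 24]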
4 changes: 2 additions & 2 deletions setup.py
@@ -5,14 +5,14 @@
 
 setuptools.setup(
     name="dataset_iterator",
-    version="0.4.2",
+    version="0.4.3",
     author="Jean Ollion",
     author_email="jean.ollion@polytechnique.org",
     description="Keras-style data iterator for images contained in dataset files such as hdf5 or PIL readable files. Images can be contained in several files.",
     long_description=long_description,
     long_description_content_type="text/markdown",
     url="https://github.com/jeanollion/dataset_iterator.git",
-    download_url='https://github.com/jeanollion/dataset_iterator/releases/download/v0.4.2/dataset_iterator-0.4.2.tar.gz',
+    download_url='https://github.com/jeanollion/dataset_iterator/releases/download/v0.4.3/dataset_iterator-0.4.3.tar.gz',
     keywords=['Iterator', 'Dataset', 'Image', 'Numpy'],
     packages=setuptools.find_packages(),
     classifiers=[ #https://pypi.org/classifiers/
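The setup.py change is only the version bump for the 0.4.3 release. As a quick sanity check (assuming the package has been installed from this release), the installed distribution can be compared against the bumped version:

from importlib.metadata import version  # Python 3.8+
assert version("dataset_iterator") == "0.4.3", version("dataset_iterator")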
