Skip to content

Commit

Permalink
Update get_sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
lixfz committed Feb 27, 2024
1 parent 82aa824 commit cad1d68
Showing 1 changed file with 13 additions and 16 deletions.
29 changes: 13 additions & 16 deletions hypergbm/hyper_gbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,32 +38,29 @@
logger = logging.get_logger(__name__)

GB = 1024 ** 3
SAMPLERS = {}

try:
from imblearn.over_sampling import RandomOverSampler, SMOTE, ADASYN
from imblearn.under_sampling import RandomUnderSampler, NearMiss, TomekLinks, EditedNearestNeighbours

imblearn_installed = True
im_samplers = {'RandomOverSampler': RandomOverSampler,
'SMOTE': SMOTE,
'ADASYN': ADASYN,
'RandomUnderSampler': RandomUnderSampler,
'NearMiss': NearMiss,
'TomekLinks': TomekLinks,
'EditedNearestNeighbours': EditedNearestNeighbours
}
SAMPLERS.update(im_samplers)
except:
logger.warning('Failed to load imbalanced-learn', exc_info=sys.exc_info())
imblearn_installed = False


def get_sampler(sampler):
if imblearn_installed:
samplers = {'RandomOverSampler': RandomOverSampler,
'SMOTE': SMOTE,
'ADASYN': ADASYN,
'RandomUnderSampler': RandomUnderSampler,
'NearMiss': NearMiss,
'TomekLinks': TomekLinks,
'EditedNearestNeighbours': EditedNearestNeighbours
}
sampler_cls = samplers.get(sampler)
if sampler_cls is not None:
return sampler_cls()
else:
return None
sampler_cls = SAMPLERS.get(sampler)
if sampler_cls is not None:
return sampler_cls()
else:
return None

Expand Down

0 comments on commit cad1d68

Please sign in to comment.