Skip to content

Commit

Permalink
Make imbalanced-learn optional
Browse files Browse the repository at this point in the history
  • Loading branch information
lixfz committed Feb 18, 2024
1 parent e5777a7 commit aa44d23
Showing 1 changed file with 32 additions and 20 deletions.
52 changes: 32 additions & 20 deletions hypergbm/hyper_gbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,26 @@
"""
import copy
import hashlib
import numpy as np
import pandas as pd
import pickle
import re
import sys
import time
from imblearn.over_sampling import RandomOverSampler, SMOTE, ADASYN
from imblearn.under_sampling import RandomUnderSampler, NearMiss, TomekLinks, EditedNearestNeighbours
from sklearn import pipeline as sk_pipeline
from sklearn.inspection import permutation_importance as sk_pi
from sklearn.utils import Bunch
from tqdm.auto import tqdm

from hypergbm.gbm_callbacks import FileMonitorCallback
import numpy as np
import pandas as pd
from hypernets.core import Callback, ProgressiveCallback
from hypernets.model.estimator import Estimator
from hypernets.model.hyper_model import HyperModel
from hypernets.pipeline.base import ComposeTransformer
from hypernets.tabular import get_tool_box
from hypernets.tabular.cache import cache
from hypernets.utils import logging, fs, const
from sklearn import pipeline as sk_pipeline
from sklearn.inspection import permutation_importance as sk_pi
from sklearn.utils import Bunch
from tqdm.auto import tqdm

from hypergbm.gbm_callbacks import FileMonitorCallback
from .cfg import HyperGBMCfg as cfg
from .estimators import HyperEstimator

Expand All @@ -39,19 +39,31 @@

GB = 1024 ** 3

try:
from imblearn.over_sampling import RandomOverSampler, SMOTE, ADASYN
from imblearn.under_sampling import RandomUnderSampler, NearMiss, TomekLinks, EditedNearestNeighbours

imblearn_istalled = True
except:
logger.warning('Failed to load imbalanced-learn', exc_info=sys.exc_info())
imblearn_istalled = False


def get_sampler(sampler):
samplers = {'RandomOverSampler': RandomOverSampler,
'SMOTE': SMOTE,
'ADASYN': ADASYN,
'RandomUnderSampler': RandomUnderSampler,
'NearMiss': NearMiss,
'TomekLinks': TomekLinks,
'EditedNearestNeighbours': EditedNearestNeighbours
}
sampler_cls = samplers.get(sampler)
if sampler_cls is not None:
return sampler_cls()
if imblearn_istalled:
samplers = {'RandomOverSampler': RandomOverSampler,
'SMOTE': SMOTE,
'ADASYN': ADASYN,
'RandomUnderSampler': RandomUnderSampler,
'NearMiss': NearMiss,
'TomekLinks': TomekLinks,
'EditedNearestNeighbours': EditedNearestNeighbours
}
sampler_cls = samplers.get(sampler)
if sampler_cls is not None:
return sampler_cls()
else:
return None
else:
return None

Expand Down

0 comments on commit aa44d23

Please sign in to comment.