Update make_experiment utility
lixfz committed Feb 28, 2024
1 parent cad1d68 commit 97d75e6
Showing 1 changed file with 35 additions and 35 deletions.
70 changes: 35 additions & 35 deletions hypergbm/experiment.py
@@ -20,11 +20,12 @@
 from hypernets.tabular import get_tool_box
 from hypernets.utils import DocLens, isnotebook, logging
 
-from hypergbm.hyper_gbm import HyperGBMShapExplainer, HyperGBMEstimator
+from hypergbm.hyper_gbm import HyperGBMShapExplainer, HyperGBMEstimator, SAMPLERS
 
 try:
     import shap
     from shap import TreeExplainer, KernelExplainer, Explainer
+
     has_shap = True
 except:
     has_shap = False
@@ -114,30 +115,16 @@ def make_experiment(train_data,
         tb = get_tool_box(pd.DataFrame)
     else:
         tb = get_tool_box(train_data)
 
-    if hyper_model_cls is None:
-        if data_adaption_to_cuml or tb.__name__.lower().find('cuml') >= 0:
-            from hypergbm.cuml import CumlHyperGBM
-            hyper_model_cls = CumlHyperGBM
-        else:
-            from hypergbm.hyper_gbm import HyperGBM
-            hyper_model_cls = HyperGBM
+    if data_adaption_to_cuml or tb.__name__.lower().find('cuml') >= 0:
+        from hypergbm.cuml import CumlHyperGBM
+        hyper_model_cls = CumlHyperGBM
+    else:
+        from hypergbm.hyper_gbm import HyperGBM
+        hyper_model_cls = HyperGBM
 
     def default_search_space():
-        args = search_space_options if search_space_options is not None else {}
-        if estimator_early_stopping_rounds is not None:
-            assert isinstance(estimator_early_stopping_rounds, int), \
-                f'estimator_early_stopping_rounds should be int or None, {estimator_early_stopping_rounds} found.'
-            args['early_stopping_rounds'] = estimator_early_stopping_rounds
-
-        for key in ('n_estimators', 'class_balancing'):
-            if key in kwargs.keys():
-                args[key] = kwargs.pop(key)
-
-        for key in ('verbose',):
-            if key in kwargs.keys():
-                args[key] = kwargs.get(key)
-
         if tb.__name__.lower().find('dask') >= 0:
             from hypergbm.dask.search_space import search_space_general as dask_search_space
             result = dask_search_space
@@ -148,9 +135,6 @@ def default_search_space():
             from hypergbm.search_space import search_space_general as sk_search_space
             result = sk_search_space
 
-        if args:
-            result = copy.deepcopy(result)
-            result.options.update(args)
         return result
 
     if (searcher is None or isinstance(searcher, str)) and search_space is None:
@@ -164,7 +148,27 @@ def default_search_space():
         catboost_init_kwargs = search_space.options.get('catboost_init_kwargs', {})
         catboost_init_kwargs['max_ctr_complexity'] = 1  # reduce training memory
         search_space.options['catboost_init_kwargs'] = catboost_init_kwargs
-        logger.info(f'search space options: {search_space.options}')
+        # logger.info(f'search space options: {search_space.options}')
+
+    if search_space is not None:
+        search_space_kwargs = search_space_options or {}
+        if estimator_early_stopping_rounds is not None:
+            assert isinstance(estimator_early_stopping_rounds, int), \
+                f'estimator_early_stopping_rounds should be int or None, {estimator_early_stopping_rounds} found.'
+            search_space_kwargs['early_stopping_rounds'] = estimator_early_stopping_rounds
+
+        for key in ('n_estimators', 'class_balancing'):
+            if key in kwargs.keys():
+                search_space_kwargs[key] = kwargs.pop(key)
+
+        for key in ('verbose',):
+            if key in kwargs.keys():
+                search_space_kwargs[key] = kwargs.get(key)
+
+        if search_space_kwargs:
+            logger.info(f'update search space with options: {search_space_kwargs}')
+            search_space = copy.deepcopy(search_space)
+            search_space.options.update(search_space_kwargs)
 
     def is_notebook_widget_ready():
         try:
@@ -252,15 +256,11 @@ def default_search_callbacks():
     or hypergbm.dask.search_space.search_space_general (if Dask is enabled)."""
 
 _class_balancing_doc = """ : str, optional, (default=None)
-    Strategy for imbalanced learning (classification task only). Possible values:
-      - ClassWeight
-      - RandomOverSampler
-      - SMOTE
-      - ADASYN
-      - RandomUnderSampler
-      - NearMiss
-      - TomekLinks
-      - EditedNearestNeighbours"""
+    Strategy for imbalanced learning (classification task only), sampler name or bool.
+    Possible sampler names: {sampler_names}.
+""".rstrip().format(
+    sampler_names=', '.join(map("'{}'".format, SAMPLERS.keys()))
+)
 
 _cross_validator_doc = """ : cross-validation generator, optional
     Used to split a fit_transformed dataset into a sequence of train and test portions.
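
What this commit changes in behavior: the merging of search_space_options, estimator_early_stopping_rounds, and the n_estimators / class_balancing / verbose keyword arguments moves out of default_search_space, so it now applies to whichever search space make_experiment ends up with, including one supplied by the caller. A minimal usage sketch under that reading; train.csv, train_df, and the target column 'y' are hypothetical, while the parameter names come from the diff above:

    import pandas as pd

    from hypergbm import make_experiment
    from hypergbm.search_space import search_space_general

    train_df = pd.read_csv('train.csv')  # hypothetical dataset

    # The options below are now merged into a deep copy of the explicitly
    # supplied search space; before this commit they only reached the
    # built-in default search space.
    experiment = make_experiment(
        train_df,
        target='y',                          # hypothetical target column
        search_space=search_space_general,   # explicit search space from this repo
        estimator_early_stopping_rounds=20,  # validated as int, applied as 'early_stopping_rounds'
        search_space_options={'n_estimators': 200},
    )
    estimator = experiment.run()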
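The docstring hunk above now derives the class_balancing sampler list from SAMPLERS, newly imported at the top of the file. A short sketch of inspecting that mapping and picking a sampler; train_df and 'y' are hypothetical as before, and 'SMOTE' is one of the names the pre-change docstring listed:

    from hypergbm import make_experiment
    from hypergbm.hyper_gbm import SAMPLERS

    print(list(SAMPLERS.keys()))  # the names interpolated into _class_balancing_doc

    experiment = make_experiment(
        train_df,                 # hypothetical DataFrame prepared as above
        target='y',
        class_balancing='SMOTE',  # sampler name; the new docstring also allows bool
    )
    estimator = experiment.run()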
