A collection of LightGBM callbacks. (DART early stopping, tqdm progress bar)

34j/lightgbm-callbacks

LightGBM Callbacks

A collection of LightGBM callbacks. It provides implementations of ProgressBarCallback (#5867) and DartEarlyStoppingCallback (#4805), as well as LGBMDartEarlyStoppingEstimator, an estimator wrapper that passes these callbacks automatically (#3313, #5808).
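
For context, a LightGBM callback is an ordinary callable that training invokes once per boosting iteration with a lightgbm.callback.CallbackEnv named tuple (fields model, params, iteration, begin_iteration, end_iteration, evaluation_result_list). The minimal sketch below assumes a recent LightGBM version and is not this package's code: SimpleProgressCallback is a hypothetical name, it records (done, total) pairs instead of drawing a tqdm bar, and it is driven by a stand-in CallbackEnv so it runs without LightGBM installed.

```python
from collections import namedtuple

# Stand-in with the same fields as lightgbm.callback.CallbackEnv,
# so this sketch runs without LightGBM installed.
CallbackEnv = namedtuple(
    "CallbackEnv",
    [
        "model",
        "params",
        "iteration",
        "begin_iteration",
        "end_iteration",
        "evaluation_result_list",
    ],
)


class SimpleProgressCallback:
    """Hypothetical callback: records (iterations done, total iterations)."""

    def __init__(self):
        self.order = 100  # LightGBM sorts callbacks by this attribute (higher runs later)
        self.before_iteration = False  # call after each boosting iteration
        self.history = []

    def __call__(self, env):
        done = env.iteration - env.begin_iteration + 1
        total = env.end_iteration - env.begin_iteration
        self.history.append((done, total))


# Simulate LightGBM calling the callback for 3 iterations.
cb = SimpleProgressCallback()
for i in range(3):
    cb(CallbackEnv(None, {}, i, 0, 3, []))
print(cb.history)  # [(1, 3), (2, 3), (3, 3)]
```

A real callback would be passed as callbacks=[SimpleProgressCallback()] to fit() or lgb.train(); early stopping, by contrast, is signalled by raising lightgbm.callback.EarlyStopException from inside a callback.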

Installation

Install this via pip (or your favourite package manager):

pip install lightgbm-callbacks

Usage

Scikit-learn API, simple

from lightgbm import LGBMRegressor
from lightgbm_callbacks import LGBMDartEarlyStoppingEstimator
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)
LGBMDartEarlyStoppingEstimator(
    LGBMRegressor(boosting_type="dart"),  # or "gbdt", ...
    stopping_rounds=10,  # or n_iter_no_change=10
    test_size=0.2,  # or validation_fraction=0.2
    shuffle=False,
    tqdm_cls="rich",  # or "auto", "autonotebook", ...
).fit(X_train, y_train)
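
The test_size and shuffle arguments above read like those of sklearn.model_selection.train_test_split. Assuming the estimator holds out its early-stopping data with the same semantics (an assumption, not confirmed here), shuffle=False means the last ceil(test_size * n) rows become the validation set. The hypothetical helper holdout_split below reproduces that arithmetic in plain Python.

```python
import math


def holdout_split(X, y, test_size=0.2):
    """Hypothetical helper: unshuffled split like train_test_split(shuffle=False).

    The validation part is the last ceil(test_size * n) rows; order is kept.
    """
    n = len(X)
    n_val = math.ceil(test_size * n)
    n_train = n - n_val
    return X[:n_train], X[n_train:], y[:n_train], y[n_train:]


X = list(range(10))
y = [2 * v for v in X]
X_tr, X_val, y_tr, y_val = holdout_split(X, y, test_size=0.2)
print(X_tr)   # [0, 1, 2, 3, 4, 5, 6, 7]
print(X_val)  # [8, 9]
```

Without shuffling, the validation rows are simply the tail of the data, which matters if the rows have a meaningful order (for example, time).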

Scikit-learn API, passing callbacks manually

from lightgbm import LGBMRegressor
from lightgbm_callbacks import ProgressBarCallback, DartEarlyStoppingCallback
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)
# Hold out part of the training data as a validation set for early stopping.
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train)
early_stopping_callback = DartEarlyStoppingCallback(stopping_rounds=10)
LGBMRegressor().fit(
    X_train,
    y_train,
    eval_set=[(X_train, y_train), (X_val, y_val)],
    callbacks=[
        early_stopping_callback,
        ProgressBarCallback(early_stopping_callback=early_stopping_callback),
    ],
)
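
DartEarlyStoppingCallback(stopping_rounds=10) is configured like LightGBM's built-in lgb.early_stopping, which stops once the validation metric has not improved for stopping_rounds consecutive iterations. Assuming the same patience rule applies here, the sketch below shows it on a precomputed list of validation losses (lower is better). early_stopping_iteration is a hypothetical helper and a simplification: it ignores multiple metrics and eval sets, and any DART-specific handling.

```python
def early_stopping_iteration(val_losses, stopping_rounds):
    """Hypothetical helper: return (best_iteration, stop_iteration), both 1-based.

    Training stops after stopping_rounds consecutive iterations
    without improving on the best loss seen so far.
    """
    best_loss = float("inf")
    best_iter = 0
    for i, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_iter = loss, i
        elif i - best_iter >= stopping_rounds:
            return best_iter, i  # patience exhausted
    return best_iter, len(val_losses)  # ran out of iterations first


losses = [5.0, 4.0, 3.5, 3.6, 3.7, 3.8, 3.9]
print(early_stopping_iteration(losses, stopping_rounds=3))  # (3, 6)
```

Under this rule, training halts at best_iteration + stopping_rounds, which corresponds to the "none" row of the method table in the next section (before capping at n_estimators).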

Details on DartEarlyStoppingCallback

The table below describes each value of the method parameter of DartEarlyStoppingCallback. The last column gives the number of iterations actually obtained, as shown by lgb.plot_metric, for lgb.LGBMRegressor(boosting_type="dart", n_estimators=1000) trained on the entire sklearn.datasets.load_diabetes() dataset.

| Method | Description | Iterations (formula) | Actual iterations |
|---|---|---|---|
| (baseline) | Early stopping is not used. | n_estimators | 1000 |
| "none" | Do nothing and return the original estimator. | min(best_iteration + early_stopping_rounds, n_estimators) | 50 |
| "save" | Save the best model by deep-copying the estimator (using pickle) and return the best model. | min(best_iteration + 1, n_estimators) | 21 |
| "refit" | Refit the estimator with the best iteration and return the refitted estimator. | min(best_iteration, n_estimators) | 20 |

Contributors ✨

Thanks go to these wonderful people (emoji key):

34j

💻 🤔 📖

This project follows the all-contributors specification. Contributions of any kind welcome!