
How to early stop in GradientBoostingClassifier? #25859

Open
Nitinsiwach opened this issue Mar 15, 2023 · 1 comment
Nitinsiwach commented Mar 15, 2023

How can I track model performance on an externally provided eval set and stop tree building early based on the result?

Currently, the implementation offers the validation_fraction option together with n_iter_no_change.
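For reference, the built-in mechanism holds out validation_fraction of the training data internally. A minimal sketch of enabling it (the parameter values and synthetic data here are illustrative):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier

X, y = make_classification(n_samples=500, random_state=0)

# Built-in early stopping: 10% of the training data is held out
# internally, and fitting stops once the validation score has not
# improved by at least `tol` for `n_iter_no_change` iterations.
gbc = GradientBoostingClassifier(
    n_estimators=500,
    validation_fraction=0.1,
    n_iter_no_change=20,
    tol=1e-4,
    random_state=0,
)
gbc.fit(X, y)
print(gbc.n_estimators_)  # boosting stages fitted before stopping
```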
The issues with that approach:

  1. Cannot use k-fold cross-validation.
  2. Cannot use custom metrics.

Currently I solve it with the following, which is kind of hacky:
from sklearn.ensemble import GradientBoostingClassifier

# Create a gradient booster.
gbc = GradientBoostingClassifier()

# Define the metric function you want to use for early stopping.
def accuracy(y_true, y_preds):
    return ...  # return the metric value here

# Passing an instance of this class as the `monitor` argument of fit()
# enables early stopping: fitting stops once __call__ returns True.
class EarlyStoppingGBC:
    def __init__(self, metric, eval_set, early_stopping_rounds=20):
        self.metric = metric
        self.x_val, self.y_val = eval_set
        self.best_perf = 0.0
        self.counter = 0
        self.early_stopping_rounds = early_stopping_rounds

    def __call__(self, i, model, local_vars):
        # staged_predict_proba yields the ensemble's predictions one
        # stage at a time; stop at stage i, the latest stage fitted.
        for stage, preds in enumerate(model.staged_predict_proba(self.x_val)):
            if stage == i:
                break
        perf = self.metric(self.y_val, preds[:, 1])
        if perf > self.best_perf:
            self.best_perf = perf
            self.counter = 0
        else:
            self.counter += 1
        return self.counter > self.early_stopping_rounds

# Run the gradient booster, stopping after 20 rounds without improvement.
gbc.fit(X_train, y_train,
        monitor=EarlyStoppingGBC(accuracy, (X_val, y_val), early_stopping_rounds=20))
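For completeness, here is a self-contained, runnable sketch of the same monitor idea, wired up end to end with roc_auc_score as the custom metric on a synthetic dataset (the class and variable names are illustrative, not part of scikit-learn):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

class EarlyStoppingMonitor:
    """Stop boosting when the validation score has not improved for
    `early_stopping_rounds` consecutive iterations."""

    def __init__(self, metric, eval_set, early_stopping_rounds=20):
        self.metric = metric
        self.x_val, self.y_val = eval_set
        self.best_score = -np.inf
        self.stale_rounds = 0
        self.early_stopping_rounds = early_stopping_rounds

    def __call__(self, i, model, local_vars):
        # Recompute staged predictions up to the latest fitted stage i.
        for stage, proba in enumerate(model.staged_predict_proba(self.x_val)):
            if stage == i:
                break
        score = self.metric(self.y_val, proba[:, 1])
        if score > self.best_score:
            self.best_score = score
            self.stale_rounds = 0
        else:
            self.stale_rounds += 1
        # Returning True makes fit() stop adding trees.
        return self.stale_rounds > self.early_stopping_rounds

X, y = make_classification(n_samples=600, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.25, random_state=0)

monitor = EarlyStoppingMonitor(roc_auc_score, (X_val, y_val), early_stopping_rounds=10)
gbc = GradientBoostingClassifier(n_estimators=300, random_state=0)
gbc.fit(X_train, y_train, monitor=monitor)
print(gbc.n_estimators_)  # boosting stages actually fitted
```

Note that rerunning staged_predict_proba from scratch at every iteration makes each check cost grow with the number of stages fitted so far, which is part of what makes this approach hacky.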
github-actions bot added the Needs Triage label Mar 15, 2023
ogrisel (Member) commented Mar 24, 2023

This seems related to #25460 where we need a way to pass a custom validation set.

We need to design a consistent API across different scikit-learn estimators to handle this need.
