Skip to content

Commit

Permalink
Update xgb_classification.py
Browse files Browse the repository at this point in the history
We fixed the bug that the input column names are mismatch when the input is some columns, not array
  • Loading branch information
jen-choi committed Jan 7, 2020
1 parent 2dc8654 commit 9965ef9
Showing 1 changed file with 3 additions and 1 deletion.
Expand Up @@ -58,7 +58,9 @@ def _xgb_classification_train(table, feature_cols, label_col, max_depth=3, learn
class_weight=None, eval_set=None, eval_metric=None, early_stopping_rounds=None, verbose=True,
xgb_model=None, sample_weight_eval_set=None):
feature_names, features = check_col_type(table, feature_cols)
features = np.array(features)

if isinstance(features, list):
features = np.array(features)

if random_state is None:
random_state = randint(-2**31, 2**31-1)
Expand Down

0 comments on commit 9965ef9

Please sign in to comment.