diff --git a/function/python/brightics/function/classification/xgb_classification.py b/function/python/brightics/function/classification/xgb_classification.py index 2c665ad17..7c9ef3155 100644 --- a/function/python/brightics/function/classification/xgb_classification.py +++ b/function/python/brightics/function/classification/xgb_classification.py @@ -58,7 +58,9 @@ def _xgb_classification_train(table, feature_cols, label_col, max_depth=3, learn class_weight=None, eval_set=None, eval_metric=None, early_stopping_rounds=None, verbose=True, xgb_model=None, sample_weight_eval_set=None): feature_names, features = check_col_type(table, feature_cols) - features = np.array(features) + + if isinstance(features, list): + features = np.array(features) if random_state is None: random_state = randint(-2**31, 2**31-1)