Skip to content

Commit

Permalink
Add updater parameter to by able to set shotgun
Browse files Browse the repository at this point in the history
  • Loading branch information
valenad1 committed Apr 19, 2024
1 parent 39c6000 commit efc644a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ public enum Backend {
public enum FeatureSelector {
cyclic, shuffle, random, greedy, thrifty
}
public enum Updater {
gpu_hist, shotgun, coord_descent, gpu_coord_descent,
}

// H2O GBM options
public boolean _quiet_mode = true;
Expand Down Expand Up @@ -148,6 +151,7 @@ public enum FeatureSelector {
// lambda, alpha support also for gbtree
public FeatureSelector _feature_selector = FeatureSelector.cyclic;
public int _top_k;
public Updater _updater;

public String _eval_metric;
public boolean _score_eval_metric_only;
Expand Down Expand Up @@ -399,17 +403,17 @@ public static Map<String, Object> createParamsMap(XGBoostParameters p, int nClas
params.put("gpu_id", 0);
}
// we are setting updater rather than tree_method here to keep CPU predictor, which is faster
if (p._booster == XGBoostParameters.Booster.gblinear) {
if (p._booster == XGBoostParameters.Booster.gblinear && p._updater == null) {
LOG.info("Using gpu_coord_descent updater.");
params.put("updater", "gpu_coord_descent");
params.put("updater", XGBoostParameters.Updater.gpu_coord_descent.toString());
} else {
LOG.info("Using gpu_hist tree method.");
params.put("max_bin", p._max_bins);
params.put("tree_method", "gpu_hist");
params.put("tree_method", XGBoostParameters.Updater.gpu_hist.toString());
}
} else if (p._booster == XGBoostParameters.Booster.gblinear) {
} else if (p._booster == XGBoostParameters.Booster.gblinear && p._updater == null) {
LOG.info("Using coord_descent updater.");
params.put("updater", "coord_descent");
params.put("updater", XGBoostParameters.Updater.coord_descent.toString());
} else if (H2O.CLOUD.size() > 1 && p._tree_method == XGBoostParameters.TreeMethod.auto &&
p._monotone_constraints != null) {
LOG.info("Using hist tree method for distributed computation with monotone_constraints.");
Expand All @@ -422,6 +426,10 @@ public static Map<String, Object> createParamsMap(XGBoostParameters p, int nClas
params.put("max_bin", p._max_bins);
}
}
if (p._updater != null) {
LOG.info("Using user-provided updater.");
params.put("updater", p._updater.toString());
}
if (p._min_child_weight != 1) {
LOG.info("Using user-provided parameter min_child_weight instead of min_rows.");
params.put("min_child_weight", p._min_child_weight);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3320,5 +3320,32 @@ public void testGBLinearTopKAndFeatureSelector() {
Scope.exit();
}
}



@Test
public void testGBLinearShotgun() {
Scope.enter();
try {
String response = "CAPSULE";
Frame train = parseAndTrackTestFile("./smalldata/logreg/prostate_train.csv");
train.toCategoricalCol(response);

XGBoostModel.XGBoostParameters parms = new XGBoostModel.XGBoostParameters();
parms._ntrees = 1;
parms._train = train._key;
parms._response_column = response;
parms._booster = XGBoostModel.XGBoostParameters.Booster.gblinear;
parms._updater = XGBoostModel.XGBoostParameters.Updater.shotgun;
parms._feature_selector = XGBoostModel.XGBoostParameters.FeatureSelector.shuffle;

ModelBuilder job = new hex.tree.xgboost.XGBoost(parms);
XGBoostModel xgboost = (XGBoostModel) job.trainModel().get();
assertNotNull(xgboost);
Scope.track_generic(xgboost);
assertEquals("updater should be changed", xgboost._output._native_parameters.get(1,1), XGBoostModel.XGBoostParameters.Updater.shotgun.toString());
}
finally {
Scope.exit();
}
}
}

0 comments on commit efc644a

Please sign in to comment.