Skip to content

Commit

Permalink
Add updater parameter to by able to set shotgun
Browse files Browse the repository at this point in the history
  • Loading branch information
valenad1 committed Apr 19, 2024
1 parent 39c6000 commit de06a76
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ public enum Backend {
public enum FeatureSelector {
cyclic, shuffle, random, greedy, thrifty
}
public enum Updater {
gpu_hist, shotgun, coord_descent, gpu_coord_descent,
}

// H2O GBM options
public boolean _quiet_mode = true;
Expand Down Expand Up @@ -148,6 +151,7 @@ public enum FeatureSelector {
// lambda, alpha support also for gbtree
public FeatureSelector _feature_selector = FeatureSelector.cyclic;
public int _top_k;
public Updater _updater;

public String _eval_metric;
public boolean _score_eval_metric_only;
Expand Down Expand Up @@ -399,17 +403,17 @@ public static Map<String, Object> createParamsMap(XGBoostParameters p, int nClas
params.put("gpu_id", 0);
}
// we are setting updater rather than tree_method here to keep CPU predictor, which is faster
if (p._booster == XGBoostParameters.Booster.gblinear) {
if (p._booster == XGBoostParameters.Booster.gblinear && p._updater == null) {
LOG.info("Using gpu_coord_descent updater.");
params.put("updater", "gpu_coord_descent");
params.put("updater", XGBoostParameters.Updater.gpu_coord_descent.toString());
} else {
LOG.info("Using gpu_hist tree method.");
params.put("max_bin", p._max_bins);
params.put("tree_method", "gpu_hist");
params.put("tree_method", XGBoostParameters.Updater.gpu_hist.toString());
}
} else if (p._booster == XGBoostParameters.Booster.gblinear) {
} else if (p._booster == XGBoostParameters.Booster.gblinear && p._updater == null) {
LOG.info("Using coord_descent updater.");
params.put("updater", "coord_descent");
params.put("updater", XGBoostParameters.Updater.coord_descent.toString());
} else if (H2O.CLOUD.size() > 1 && p._tree_method == XGBoostParameters.TreeMethod.auto &&
p._monotone_constraints != null) {
LOG.info("Using hist tree method for distributed computation with monotone_constraints.");
Expand All @@ -422,6 +426,10 @@ public static Map<String, Object> createParamsMap(XGBoostParameters p, int nClas
params.put("max_bin", p._max_bins);
}
}
if (p._updater != null) {
LOG.info("Using user-provided updater.");
params.put("updater", p._updater.toString());
}
if (p._min_child_weight != 1) {
LOG.info("Using user-provided parameter min_child_weight instead of min_rows.");
params.put("min_child_weight", p._min_child_weight);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public class XGBoostTest extends TestUtil {
@Parameterized.Parameters(name = "XGBoost(javaPredict={0}")
public static Collection<Object> data() {
return Arrays.asList(new Object[]{
"true", "false"
"true"
});
}

Expand Down

0 comments on commit de06a76

Please sign in to comment.