Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added gradient boosting binding #3669

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from

Conversation

TirelessClock
Copy link
Contributor

I noticed that mlpack doesn't currently have a binding for Gradient Boosting.

To give a brief description of Gradient Boosting: it's an ensemble technique which uses weak learners such as shallow Decision Trees (or Decision stumps). Each learner is trained to predict the errors, or gradients, of previous learners, and their subsequent result is combined to give the final result.

Gradient Boosting is a very common and very powerful ensemble technique used for a wide variety of reasons, from classification to regression. More significantly, Gradient Boosting is a pathway to implementing more powerful algorithms, primarily XGBoost, which is just an optimized version of Gradient Boosting. For these reasons I implemented Gradient Boosting's binding.

I've implemented Gradient Boosting in mlpack/src/mlpack/methods/grad_boosting.

Files added:

  • grad_boosting_main.cpp: Main binding file.
  • grad_boosting.hpp: The base grad_boosting class. Only has the interface for the methods, not their implementation.
  • grad_boosting_impl.hpp: Implementation of the grad_boosting class.
  • grad_boosting_model.hpp: The serializable grad_boosting model class. Only has the interface for the methods, not their implementation.
  • grad_boosting_model_impl.hpp: Implementation of serializable grad_boosting model.

Please note: this PR is still incomplete and I have a few more changes before it can be finalized, but I've made the PR now so I can include it in my GSoC'24 proposal.

@shrit
Copy link
Member

shrit commented Apr 19, 2024

@TirelessClock the code looks good, I skimmed the implementation rapidly, but I will go more into detail.
Sorry for my slow review on this one. Could you add a couple more of tests ?

@TirelessClock
Copy link
Contributor Author

@shrit Yeah sure. Thanks for the review!

@rcurtin
Copy link
Member

rcurtin commented Apr 19, 2024

If we are going to add a C++ class for it too, do you think you can write some documentation for it? Take a look at doc/user/methods/decision_tree.md for an example; the structure can be adapted.

@TirelessClock
Copy link
Contributor Author

@rcurtin Yeah, I'll start work on the documentation + testing immediately.

Copy link
Member

@shrit shrit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TirelessClock please apply mlpack code style, and add the tests and documentations as discussed.
For more information about the style check the wiki in github, or check other implementation code, which I think it is going to be faster.

Comment on lines +99 to +103
void Classify(
const arma::mat& testData,
arma::Row<size_t>& predictions,
arma::mat& probabilities
);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please follow mlpack style in here, regarding the parameters.

Comment on lines +93 to +95
void Classify(
const arma::mat& testData,
arma::Row<size_t>& predictions
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TirelessClock any reason why you did not add MatType template to this function ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shrit No specific reason. I used MatType in grad_boosting, but used barebone arm::mat in the grad_boosting_model. I guess I made the mistake once and just went with it without realizing.

Comment on lines +107 to +111
void serialize(Archive& ar, const uint32_t /* version */) {
if (cereal::is_loading<Archive>()) {
delete dsBoost;
dsBoost = NULL;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same here, the { goes on a new line after if statement or a function implementation

@TirelessClock
Copy link
Contributor Author

@TirelessClock please apply mlpack code style, and add the tests and documentations as discussed. For more information about the style check the wiki in github, or check other implementation code, which I think it is going to be faster.

Hello @shrit , I'm so sorry about the delay, currently a bit overloaded with assignments and exams. I'll finish this up and push by the end of the week.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants