Skip to content

Commit

Permalink
Applied suggestions from code review.
Browse files Browse the repository at this point in the history
Co-authored-by: Ryan Curtin <ryan@ratml.org>
  • Loading branch information
nikolay-apanasov and rcurtin committed Apr 22, 2024
1 parent 43913ea commit cceb6f3
Show file tree
Hide file tree
Showing 10 changed files with 292 additions and 298 deletions.
29 changes: 21 additions & 8 deletions doc/user/methods/decision_tree.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ that is used should be the same type that was used for training.
[`data::Save()` and `data::Load()`](../load_save.md#mlpack-objects).
* `tree.NumChildren()` will return a `size_t` indicating the number of children
in the node `tree`.
in the node `tree`. If there was no split, zero is returned.
* `tree.Child(i)` will return a `DecisionTree` object representing the `i`th
child of the node `tree`.
Expand Down Expand Up @@ -426,10 +426,22 @@ class CustomNumericSplit
const double minGainSplit,
arma::vec& splitInfo,
AuxiliarySplitInfo& aux);

// Return the number of children for a given split (stored as the single
// element from `splitInfo` and auxiliary data `aux` in `SplitIfBetter()`).
size_t NumChildren(const double& splitInfo,
/**
* In the case that a split was found, returns the number of children
* of the split. Otherwise if there was not split, returns zero. A binary
* split always has two children.
*
* @param splitInfo Auxiliary information for the split. A vector
* of size J, where J is the number of categories. splitInfo[k]
* is zero if category k is assigned to the left child, and otherwise
* it is one if assigned to the right.
* @param * (aux) Auxiliary information for the split (Unused).
*/

// Return the number of children for a given split. If there was no split,
// return zero. `splitInfo` and `aux` contain the split information, as set
// in `SplitIfBetter`.
size_t NumChildren(const arma::vec& splitInfo,
const AuxiliarySplitInfo& aux);

// Given a point with value `point`, and split information `splitInfo` and
Expand Down Expand Up @@ -504,9 +516,10 @@ class CustomCategoricalSplit
arma::vec& splitInfo,
AuxiliarySplitInfo& aux);

// Return the number of children for a given split (stored as the single
// element from `splitInfo` and auxiliary data `aux` in `SplitIfBetter()`).
size_t NumChildren(const double& splitInfo,
// Return the number of children for a given split. If there was no split,
// return zero. `splitInfo` and `aux` contain the split information, as set
// in `SplitIfBetter`.
size_t NumChildren(const arma::vec& splitInfo,
const AuxiliarySplitInfo& aux);

// Given a point with (categorical) value `point`, and split information
Expand Down
14 changes: 8 additions & 6 deletions doc/user/methods/decision_tree_regressor.md
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,10 @@ class CustomNumericSplit
AuxiliarySplitInfo& aux,
FitnessFunction& function);

// Return the number of children for a given split (stored as the single
// element from `splitInfo` and auxiliary data `aux` in `SplitIfBetter()`).
size_t NumChildren(const double& splitInfo,
// Return the number of children for a given split. If there was no split,
// return zero. `splitInfo` and `aux` contain the split information, as set
// in `SplitIfBetter`.
size_t NumChildren(const arma::vec& splitInfo,
const AuxiliarySplitInfo& aux);

// Given a point with value `point`, and split information `splitInfo` and
Expand Down Expand Up @@ -503,9 +504,10 @@ class CustomCategoricalSplit
AuxiliarySplitInfo& aux,
FitnessFunction& fitnessFunction);

// Return the number of children for a given split (stored as the single
// element from `splitInfo` and auxiliary data `aux` in `SplitIfBetter()`).
size_t NumChildren(const double& splitInfo,
// Return the number of children for a given split. If there was no split,
// return zero. `splitInfo` and `aux` contain the split information, as set
// in `SplitIfBetter`.
size_t NumChildren(const arma::vec& splitInfo,
const AuxiliarySplitInfo& aux);

// Given a point with (categorical) value `point`, and split information
Expand Down
20 changes: 8 additions & 12 deletions src/mlpack/methods/decision_tree/best_binary_categorical_split.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@
*/
#ifndef MLPACK_METHODS_DECISION_TREE_BEST_BINARY_CATEGORICAL_SPLIT_HPP
#define MLPACK_METHODS_DECISION_TREE_BEST_BINARY_CATEGORICAL_SPLIT_HPP
#define LEFT 0
#define RIGHT 1

#include <mlpack/prereqs.hpp>
using namespace arma;

namespace mlpack {

/**
* The BestBinaryCategoricalSplit is a splitting function for decision trees
Expand Down Expand Up @@ -58,10 +57,7 @@ using namespace arma;
* }
*
* @tparam FitnessFunction Fitness function to use to calculate gain.
* categorical variable in the case of binary outcomes or regression.
*/
namespace mlpack {

template<typename FitnessFunction>
class BestBinaryCategoricalSplit
{
Expand All @@ -73,7 +69,7 @@ class BestBinaryCategoricalSplit
typedef BestBinaryNumericSplit<FitnessFunction> NumericSplit;
// For calls to the numeric splitter.
typedef typename BestBinaryNumericSplit<FitnessFunction>
::AuxiliarySplitInfo NumericAux;
::AuxiliarySplitInfo NumericAux;

/**
* Check if we can split a node. If we can split a node in a way that
Expand All @@ -92,7 +88,7 @@ class BestBinaryCategoricalSplit
* @param minLeafSize min number of points in a leaf node for
* splitting.
* @param minGainSplit min gain split.
* @param splitInfo Stores split information on a succesful split. A
* @param splitInfo Stores split information on a successful split. A
* vector of size J, where J is the number of categories. splitInfo[k]
* is zero if category k is assigned to the left child, and otherwise
* it is one if assigned to the right.
Expand Down Expand Up @@ -132,7 +128,7 @@ class BestBinaryCategoricalSplit
* @param minGainSplit min gain split.
* @param splitInfo Stores split information on a successful split.
*
* @param splitInfo Stores split information on a succesful split. A
* @param splitInfo Stores split information on a successful split. A
* vector of size J, where J is the number of categories. splitInfo[k]
* is zero if category k is assigned to the left child, and otherwise
* it is one if assigned to the right.
Expand All @@ -156,7 +152,7 @@ class BestBinaryCategoricalSplit

/**
* In the case that a split was found, returns the number of children
* of the split. Otherwise if there was not split, returns zero. A binary
* of the split. Otherwise if there was no split, returns zero. A binary
* split always has two children.
*
* @param splitInfo Auxiliary information for the split. A vector
Expand Down Expand Up @@ -212,7 +208,7 @@ class BestBinaryCategoricalSplit
* categorical value for variable vₖ is Cⱼ. Column j is for Cⱼ.
* @param categories -- J dimensional vector used to maintain the
* current partition of the categories.
* @param splitInfo -- Stores split information on a succesful split. A
* @param splitInfo -- Stores split information on a successful split. A
* vector of size J, where J is the number of categories. splitInfo[k]
* is zero if category k is assigned to the left child, and otherwise
* it is one if assigned to the right.
Expand Down Expand Up @@ -261,7 +257,7 @@ class BestBinaryCategoricalSplit
* categorical value for variable vₖ is Cⱼ. Column j is for Cⱼ.
* @param categories -- J dimensional vector used to maintain the
* current partition of the categories.
* @param splitInfo -- Stores split information on a succesful split. A
* @param splitInfo -- Stores split information on a successful split. A
* vector of size J, where J is the number of categories. splitInfo[k]
* is zero if category k is assigned to the left child, and otherwise
* it is one if assigned to the right.
Expand Down

0 comments on commit cceb6f3

Please sign in to comment.