Skip to content

Commit

Permalink
Added BestBinaryCategoricalSplit splitter for DecisionTree.
Browse files Browse the repository at this point in the history
  • Loading branch information
nikolay-apanasov committed Mar 4, 2024
1 parent c2433cb commit 3a496d8
Show file tree
Hide file tree
Showing 16 changed files with 18,417 additions and 118 deletions.
12 changes: 7 additions & 5 deletions src/mlpack/methods/decision_tree/all_categorical_split.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,21 +103,23 @@ class AllCategoricalSplit
const WeightVecType& weights,
const size_t minimumLeafSize,
const double minimumGainSplit,
double& splitInfo,
arma::vec& splitInfo,
AuxiliarySplitInfo& aux,
FitnessFunction& fitnessFunction);

/**
* Return the number of children in the split.
* If a split was found, returns the number of children of the split.
* Otherwise if there was no split, returns zero.
*
* @param splitInfo Auxiliary information for the split.
* @param * (aux) Auxiliary information for the split (Unused).
*/
static size_t NumChildren(const double& splitInfo,
static size_t NumChildren(const arma::vec& splitInfo,
const AuxiliarySplitInfo& /* aux */);

/**
* Calculate the direction a point should percolate to.
* If a split was found, given a point, calculates the index of the child
* it should go to. Otherwise if there was no split, returns SIZE_MAX.
*
* @param point the Point to use.
* @param splitInfo Auxiliary information for the split.
Expand All @@ -126,7 +128,7 @@ class AllCategoricalSplit
template<typename ElemType>
static size_t CalculateDirection(
const ElemType& point,
const double& splitInfo,
const arma::vec& splitInfo,
const AuxiliarySplitInfo& /* aux */);
};

Expand Down
13 changes: 7 additions & 6 deletions src/mlpack/methods/decision_tree/all_categorical_split_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ double AllCategoricalSplit<FitnessFunction>::SplitIfBetter(
const WeightVecType& weights,
const size_t minimumLeafSize,
const double minimumGainSplit,
double& splitInfo,
arma::vec& splitInfo,
AuxiliarySplitInfo& /* aux */,
FitnessFunction& fitnessFunction)
{
Expand Down Expand Up @@ -199,7 +199,8 @@ double AllCategoricalSplit<FitnessFunction>::SplitIfBetter(
if (overallGain > bestGain + minimumGainSplit + epsilon)
{
// This is better, so store it in splitInfo and return.
splitInfo = numCategories;
splitInfo.set_size(1);
splitInfo[0] = numCategories;
return overallGain;
}

Expand All @@ -209,20 +210,20 @@ double AllCategoricalSplit<FitnessFunction>::SplitIfBetter(

template<typename FitnessFunction>
size_t AllCategoricalSplit<FitnessFunction>::NumChildren(
const double& splitInfo,
const arma::vec& splitInfo,
const AuxiliarySplitInfo& /* aux */)
{
return (size_t) splitInfo;
return splitInfo.n_elem == 0 ? 0 : (size_t) splitInfo[0];
}

template<typename FitnessFunction>
template<typename ElemType>
size_t AllCategoricalSplit<FitnessFunction>::CalculateDirection(
const ElemType& point,
const double& /* splitInfo */,
const arma::vec& splitInfo,
const AuxiliarySplitInfo& /* aux */)
{
return (size_t) point;
return splitInfo.n_elem == 0 ? SIZE_MAX : (size_t) point;
}

} // namespace mlpack
Expand Down
299 changes: 299 additions & 0 deletions src/mlpack/methods/decision_tree/best_binary_categorical_split.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,299 @@
/**
* @file methods/decision_tree/best_binary_categorical_split.hpp
* @author Nikolay Apanasov (nikolay@apanasov.org)
*
* A tree splitter that finds the best binary categorical split.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_DECISION_TREE_BEST_BINARY_CATEGORICAL_SPLIT_HPP
#define MLPACK_METHODS_DECISION_TREE_BEST_BINARY_CATEGORICAL_SPLIT_HPP
#define LEFT 0
#define RIGHT 1

#include <mlpack/prereqs.hpp>
using namespace arma;

/**
* The BestBinaryCategoricalSplit is a splitting function for decision trees
* that will exhaustively search a categorical dimension for the best binary
* split of a variable vₖ. This is a generic splitting strategy and can be
* used for both regression and classification.
*
* In the case of binary outcomes, it shown in CART[4.2] by Breiman et al.
* that if we order the categories by the proportion that fall in class C₁,
* and then split vₖ as if it was a numeric type, then the result is optimal.
* Surprising, but true. In the case of multiple classes, there is no such
* simplification. This method will search through all the 2ʲ possible
* partitions (Gₗ, Gᵣ) of the categories C₀, ..., Cⱼ₋₁, every time assigning
* samples with vₖ ∈ Gₗ to left tree Tₗ and those with vₖ ∈ Gᵣ to right
* tree Tᵣ.
*
* Warning: in the classification setting with multiple outcomes, this
* algorithm is exponential in the number of categories. Therefore
* BestBinaryCategoricalSplit should not be chosen when there are multiple
* classes and many categories.
*
* @book{CART,
* author = {Breiman, L. and Friedman, J. and Olshen, R. and Stone, C.},
* year = {1984},
* title = {Classification and Regression Trees},
* publisher = {Chapman \& Hall}
* }
*
* In the regression setting, the algorithm is similar to the preceding linear-
* time split for the case of binary outcomes. The correctness of the algorithm
* for a quantitative response under l₂ loss is due to Fisher.
*
* @article{Fisher58,
* author = {Fisher, W.},
* year = {1958},
* title = {On Grouping for Maximum Homogeniety},
* journal = {Journal of the American Statistical Association},
* volume = {53},
* pages = {789–798}
* }
*
* @tparam FitnessFunction Fitness function to use to calculate gain.
* categorical variable in the case of binary outcomes or regression.
*/
namespace mlpack {

template<typename FitnessFunction>
class BestBinaryCategoricalSplit
{

public:
// No extra info needed for split.
class AuxiliarySplitInfo { };
// Allow access to the numeric split type.
typedef BestBinaryNumericSplit<FitnessFunction> NumericSplit;
// For calls to the numeric splitter.
typedef typename BestBinaryNumericSplit<FitnessFunction>
::AuxiliarySplitInfo NumericAux;

/**
* Check if we can split a node. If we can split a node in a way that
* improves on bestGain, then we return the improved gain. Otherwise we
* return the value DBL_MAX.
*
* This overload is used only for classification.
*
* @param bestGain Best gain seen so far (we'll only split if we find gain
* better than this).
* @param data The dimension of data points to check for a split in.
* @param numCategories Number of categories in the categorical data.
* @param labels Labels for each point.
* @param numClasses Number of classes in the dataset.
* @param weights Weights associated with labels.
* @param minLeafSize min number of points in a leaf node for
* splitting.
* @param minGainSplit min gain split.
* @param splitInfo Stores split information on a succesful split. A
* vector of size J, where J is the number of categories. splitInfo[k]
* is zero if category k is assigned to the left child, and otherwise
* it is one if assigned to the right.
* @param aux (ignored)
*/
template<bool UseWeights, typename VecType, typename LabelsType,
typename WeightVecType>
static double SplitIfBetter(
const double bestGain,
const VecType& data,
const size_t numCategories,
const LabelsType& labels,
const size_t numClasses,
const WeightVecType& weights,
const size_t minLeafSize,
const double minGainSplit,
arma::vec& splitInfo,
AuxiliarySplitInfo& aux);

/**
* Check if we can split a node. If we can split a node in a way that
* improves on bestGain, then we return the improved gain. Otherwise we
* return the value DBL_MAX.
*
* Overload for regression. As mentioned above, the result of Fisher only
* applies under l₂ loss, and thus this overload is used only for regression
* with MSEGain.
*
* @param bestGain Best gain seen so far (we'll only split if we find gain
* better than this).
* @param data The dimension of data points to check for a split in.
* @param numCategories Number of categories in the categorical data.
* @param responses Responses for each point.
* @param weights Weights associated with responses.
* @param minLeafSize min number of points in a leaf node for
* splitting.
* @param minGainSplit min gain split.
* @param splitInfo Stores split information on a successful split.
*
* @param splitInfo Stores split information on a succesful split. A
* vector of size J, where J is the number of categories. splitInfo[k]
* is zero if category k is assigned to the left child, and otherwise
* it is one if assigned to the right.
* @param aux (ignored)
* @param fitnessFunction The FitnessFunction object instance. It it used to
* evaluate the gain for the split.
*/
template<bool UseWeights, typename VecType, typename ResponsesType,
typename WeightVecType>
static double SplitIfBetter(
const double bestGain,
const VecType& data,
const size_t numCategories,
const ResponsesType& responses,
const WeightVecType& weights,
const size_t minLeafSize,
const double minGainSplit,
arma::vec& splitInfo,
AuxiliarySplitInfo& aux,
FitnessFunction& fitnessFunction);

/**
* In the case that a split was found, returns the number of children
* of the split. Otherwise if there was not split, returns zero. A binary
* split always has two children.
*
* @param splitInfo Auxiliary information for the split. A vector
* of size J, where J is the number of categories. splitInfo[k]
* is zero if category k is assigned to the left child, and otherwise
* it is one if assigned to the right.
* @param * (aux) Auxiliary information for the split (Unused).
*/
static size_t NumChildren(const arma::vec& splitInfo,
const AuxiliarySplitInfo& /* aux */)
{
return splitInfo.n_elem == 0 ? 0 : 2;
}

/**
*
* In the case that a split was found, given a point, calculates
* the index of the child it should go to. Otherwise if there was
* no split, returns SIZE_MAX.
*
* @param point the Point to use.
* @param splitInfo Auxiliary information for the split. A vector
* of size J, where J is the number of categories. splitInfo[k] is
* zero if category Cₖ is assigned to the left child, and otherwise
* it is one if Cₖ is assigned to the right.
* @param * (aux) Auxiliary information for the split (Unused).
*/
template<typename ElemType>
static size_t CalculateDirection(
const ElemType& point,
const arma::vec& splitInfo,
const AuxiliarySplitInfo& /* aux */)
{
return splitInfo.n_elem == 0 ? SIZE_MAX : (size_t) splitInfo[point];
}

private:
/**
* Auxiliary for SplitIfBetter in the multi-class setting. Recursively
* enumerates all partitions (Gₗ, Gᵣ) of categories C₀, ..., Cⱼ₋₁, and
* computes the gain for each one, where samples with vₖ ∈ Gₗ are assigned
* to the left tree Tₗ and those with vₖ ∈ Gᵣ to the right tree Tᵣ.
*
* In the case that a better split is found, bestFoundGain is updated with
* the gain value and splitInfo is updated with the corresponding partition.
*
* @param labels -- Labels for each point.
* @param numClasses -- Number of classes in the dataset.
* @param numCategories Number of categories in the categorical data.
* @param bestFoundGain -- The best gain found thus far. Updated if
* and when a better split is found.
* @param categorySamples -- Map from category Cⱼ to the samples whose
* categorical value for variable vₖ is Cⱼ. Column j is for Cⱼ.
* @param categories -- J dimensional vector used to maintain the
* current partition of the categories.
* @param splitInfo -- Stores split information on a succesful split. A
* vector of size J, where J is the number of categories. splitInfo[k]
* is zero if category k is assigned to the left child, and otherwise
* it is one if assigned to the right.
* @param classCounts -- mx2 matrix, where m = numClasses, used to compute
* the gain with the FitnessFunction. All zero initially.
* @param totalLeft, totalRight -- Number of samples assigned
* to the left and right subtrees respectively. Initialized to zero.
* @param k -- Index of the current category being assigned.
* Initialized value is zero.
*/
template<typename VecType, typename LabelsType>
static bool PartitionSplit(
const VecType& data,
const LabelsType& labels,
const size_t numCategories,
const size_t numClasses,
double& bestFoundGain,
arma::SpMat<short>& categorySamples,
arma::uvec& categories,
arma::vec& splitInfo,
arma::Mat<size_t>& classCounts,
size_t totalLeft = 0,
size_t totalRight = 0,
size_t k = 0);

/**
* Auxiliary for SplitIfBetter in the multi-class setting. Recursively
* enumerates all partitions (Gₗ, Gᵣ) of categories C₀, ..., Cⱼ₋₁, and
* computes the gain for each one, where samples with vₖ ∈ Gₗ are assigned
* to the left tree Tₗ and those with vₖ ∈ Gᵣ to the right tree Tᵣ.
*
* In the case that a better split is found, bestFoundGain is updated with
* the gain value and splitInfo is updated with the corresponding partition.
*
* This overload is used to compute the partition using weights, that is
* when the template variable UseWeights is true.
*
* @param labels -- Labels for each point.
* @param numCategories Number of categories in the categorical data.
* @param numClasses -- Number of classes in the dataset.
* @param weights -- Weights associated with labels.
* @param totalWeight -- Sum of weights.
* @param bestFoundGain -- The best gain found thus far. Updated if
* and when a better split is found.
* @param categorySamples -- Map from category Cⱼ to the samples whose
* categorical value for variable vₖ is Cⱼ. Column j is for Cⱼ.
* @param categories -- J dimensional vector used to maintain the
* current partition of the categories.
* @param splitInfo -- Stores split information on a succesful split. A
* vector of size J, where J is the number of categories. splitInfo[k]
* is zero if category k is assigned to the left child, and otherwise
* it is one if assigned to the right.
* @param classWeightSums -- mx2 matrix, where m = numClasses, used to
* compute the gain with the FitnessFunction. All zero initially.
* @param totalLeftWeight, totalRightWeight -- Weight assigned to the left
* and right subtree respectively. Initialized to zero.
* @param k -- Index of the current category being assigned.
* Initialized value is zero.
*/
template<typename VecType, typename LabelsType, typename WeightVecType>
static bool PartitionSplit(
const VecType& data,
const LabelsType& labels,
const size_t numCategories,
const size_t numClasses,
const WeightVecType& weights,
const double totalWeight,
double& bestFoundGain,
arma::SpMat<short>& categorySamples,
arma::uvec& categories,
arma::vec& splitInfo,
arma::mat& classWeightSums,
double totalLeftWeight = 0.0,
double totalRightWeight = 0.0,
size_t k = 0);

};

} // namespace mlpack

// Include implementation.
#include "best_binary_categorical_split_impl.hpp"

#endif

0 comments on commit 3a496d8

Please sign in to comment.