1 #ifndef DECISION_TREE_CLASSIFIER_HPP 
    2 #define DECISION_TREE_CLASSIFIER_HPP 
   /// Fits the model to the training data.
   ///
   /// @param X Training samples, one inner vector of feature values per sample.
   /// @param y Integer class labels (presumably y[i] belongs to X[i] — confirm at the definition).
   39     void fit(
const std::vector<std::vector<double>>& X, 
const std::vector<int>& y);
 
   /// Predicts class labels for the given input data.
   ///
   /// @param X Samples to classify, one inner vector of feature values per sample.
   /// @return Predicted integer labels (presumably one per row of X, in order — confirm at the definition).
   46     std::vector<int> 
predict(
const std::vector<std::vector<double>>& X) 
const;
 
   /// Default-constructs a node that is not a leaf, carries value 0, has no split
   /// feature (feature_index -1, threshold 0.0) and has null left/right children.
   57         Node() : is_leaf(
false), value(0), feature_index(-1), threshold(0.0), left(
nullptr), right(
nullptr) {}
 
   /// Sample-count threshold read by build_tree: when fewer labels than this reach
   /// a node, build_tree takes its majority-class branch for that node.
   62     int min_samples_split;
 
   /// Recursively builds the (sub)tree for the samples X with labels y.
   ///
   /// @param X     Samples reaching this node, one inner vector of feature values each.
   /// @param y     Labels of those samples.
   /// @param depth Depth of the node being built; compared against max_depth.
   /// @return Pointer to the node allocated for this subtree (allocated with new).
   64     Node* build_tree(
const std::vector<std::vector<double>>& X, 
const std::vector<int>& y, 
int depth);
 
   /// Computes the Gini impurity of a label set: 1 minus the sum of squared class
   /// proportions.
   ///
   /// @param y Labels to measure.
   65     double calculate_gini(
const std::vector<int>& y) 
const;
 
   /// Partitions (X, y) on one feature: samples whose value for feature_index is
   /// <= threshold are appended to the left outputs, the others to the right outputs.
   ///
   /// @param X             Samples to partition.
   /// @param y             Labels, indexed in step with X.
   /// @param feature_index Column of each sample that is compared.
   /// @param threshold     Split value; the comparison is <=.
   /// @param X_left,y_left   Receive the samples/labels at or below the threshold.
   /// @param X_right,y_right Receive the remaining samples/labels.
   66     void split_dataset(
const std::vector<std::vector<double>>& X, 
const std::vector<int>& y, 
int feature_index, 
double threshold,
 
   67                        std::vector<std::vector<double>>& X_left, std::vector<int>& y_left,
 
   68                        std::vector<std::vector<double>>& X_right, std::vector<int>& y_right) 
const;
 
   /// Classifies a single sample by walking the tree downward from node.
   ///
   /// @param x    Feature values of the sample.
   /// @param node Node to start from; predict passes the tree root.
   /// @return Predicted integer label.
   69     int predict_sample(
const std::vector<double>& x, Node* node) 
const;
 
   /// Recursively walks the subtree rooted at node to release it; a null node is accepted.
   70     void delete_tree(Node* node);
 
   74     : root(nullptr), max_depth(max_depth), min_samples_split(min_samples_split) {}
 
   81     root = build_tree(X, y, 0);
 
   85     std::vector<int> predictions;
 
   86     for (
const auto& x : X) {
 
   87         predictions.push_back(predict_sample(x, root));
 
   // Recursively builds the subtree for the samples (X, y) at the given depth.
   //
   // A node is allocated first. If a stopping condition holds, or if no usable
   // split exists, the node is labelled with the most frequent class in y.
   // Otherwise the split with the lowest size-weighted Gini impurity is stored on
   // the node and both partitions are built recursively at depth + 1.
   //
   // NOTE(review): this listing skips several original line numbers (e.g. 97-98,
   // 106-108, 136, 143, 174-175), so closing braces, the return statements and a
   // few assignments are not visible here. Comments below describe only what the
   // visible lines show; anything about the hidden lines is marked as presumed.
   92 DecisionTreeClassifier::Node* DecisionTreeClassifier::build_tree(
const std::vector<std::vector<double>>& X, 
const std::vector<int>& y, 
int depth) {
 
          // Raw allocation; delete_tree is the matching clean-up walk.
   93     Node* node = 
new Node();
 
          // Stopping test: depth limit reached, fewer labels than min_samples_split,
          // or calculate_gini(y) is exactly 0.0.
   96     if (depth >= max_depth || y.size() < 
static_cast<size_t>(min_samples_split) || calculate_gini(y) == 0.0) {
 
              // Tally how often each label occurs.
   99         std::map<int, int> class_counts;
 
  100         for (
int label : y) {
 
  101             class_counts[label]++;
 
              // Pick the entry with the highest count (the tail of this statement is not
              // shown; presumably it takes the entry's key as the node's label).
              // NOTE(review): the lambda takes std::pair<int, int>, but the map's elements
              // are std::pair<const int, int>, so each comparison converts via a temporary.
  103         node->value = std::max_element(class_counts.begin(), class_counts.end(),
 
  104                                        [](
const std::pair<int, int>& a, 
const std::pair<int, int>& b) {
 
  105                                            return a.second < b.second;
 
          // Best split seen so far; best_feature_index stays -1 until one is accepted.
  110     double best_gini = std::numeric_limits<double>::max();
 
  111     int best_feature_index = -1;
 
  112     double best_threshold = 0.0;
 
  113     std::vector<std::vector<double>> best_X_left, best_X_right;
 
  114     std::vector<int> best_y_left, best_y_right;
 
          // NOTE(review): X[0] is read with no visible emptiness check, and the size_t
          // result is narrowed to int — confirm X cannot be empty when this is reached.
  116     int num_features = X[0].size();
 
  117     for (
int feature_index = 0; feature_index < num_features; ++feature_index) {
 
              // Collect and sort this feature's values across all samples.
  119         std::vector<double> feature_values;
 
  120         for (
const auto& x : X) {
 
  121             feature_values.push_back(x[feature_index]);
 
  123         std::sort(feature_values.begin(), feature_values.end());
 
              // Candidate thresholds are the midpoints between consecutive sorted values
              // (duplicates are not removed, so equal neighbours repeat a threshold).
  124         std::vector<double> thresholds;
 
  125         for (
size_t i = 1; i < feature_values.size(); ++i) {
 
  126             thresholds.push_back((feature_values[i - 1] + feature_values[i]) / 2.0);
 
  130         for (
double threshold : thresholds) {
 
                  // Partition the samples at this threshold into fresh left/right buffers.
  131             std::vector<std::vector<double>> X_left, X_right;
 
  132             std::vector<int> y_left, y_right;
 
  133             split_dataset(X, y, feature_index, threshold, X_left, y_left, X_right, y_right);
 
                  // Guard for a split that leaves one side empty (the guarded statement,
                  // original line 136, is not shown; presumably it skips this threshold).
  135             if (y_left.empty() || y_right.empty())
 
                  // Score the split as the size-weighted average of the child impurities.
  138             double gini_left = calculate_gini(y_left);
 
  139             double gini_right = calculate_gini(y_right);
 
  140             double gini = (gini_left * y_left.size() + gini_right * y_right.size()) / y.size();
 
                  // Strictly better than the best so far: remember the split and keep copies
                  // of its partitions for the recursive calls (the update of best_gini itself,
                  // original line 143, is not shown).
                  // NOTE(review): the four partition vectors are deep-copied on every improvement.
  142             if (gini < best_gini) {
 
  144                 best_feature_index = feature_index;
 
  145                 best_threshold = threshold;
 
  146                 best_X_left = X_left;
 
  147                 best_X_right = X_right;
 
  148                 best_y_left = y_left;
 
  149                 best_y_right = y_right;
 
          // No split was accepted: turn this node into a leaf labelled with the most
          // frequent class, using the same tally as the stopping branch above.
  155     if (best_feature_index == -1) {
 
  156         node->is_leaf = 
true;
 
  158         std::map<int, int> class_counts;
 
  159         for (
int label : y) {
 
  160             class_counts[label]++;
 
  162         node->value = std::max_element(class_counts.begin(), class_counts.end(),
 
  163                                        [](
const std::pair<int, int>& a, 
const std::pair<int, int>& b) {
 
  164                                            return a.second < b.second;
 
          // Internal node: record the chosen split and build both children one level deeper.
  170     node->feature_index = best_feature_index;
 
  171     node->threshold = best_threshold;
 
  172     node->left = build_tree(best_X_left, best_y_left, depth + 1);
 
  173     node->right = build_tree(best_X_right, best_y_right, depth + 1);
 
  // Computes the Gini impurity of the labels in y: starts from 1.0 and subtracts
  // the squared proportion of every distinct label.
  //
  // For an empty y the tally is empty, so the subtraction loop never runs and the
  // accumulator stays at 1.0.
  //
  // NOTE(review): the closing lines of this function (original lines 187-189) are
  // not shown in this listing; presumably the accumulated impurity is returned.
  177 double DecisionTreeClassifier::calculate_gini(
const std::vector<int>& y)
 const {
 
          // Count occurrences of each label.
  178     std::map<int, int> class_counts;
 
  179     for (
int label : y) {
 
  180         class_counts[label]++;
 
  182     double impurity = 1.0;
 
  183     size_t total = y.size();
 
  184     for (
const auto& class_count : class_counts) {
 
              // Proportion of samples carrying this label.
  185         double prob = 
static_cast<double>(class_count.second) / total;
 
  186         impurity -= prob * prob;
 
  // Partitions (X, y) on a single feature. Sample i goes to the left outputs when
  // X[i][feature_index] <= threshold; the other push_back pair fills the right
  // outputs (the branch keyword between them, original line 199, is not shown —
  // presumably an else).
  //
  // The output vectors are only appended to; no clear() is visible, so callers
  // are expected to pass them in empty.
  // NOTE(review): y[i] and X[i][feature_index] are indexed without visible bounds
  // checks — assumes y has at least X.size() entries and every row has more than
  // feature_index columns; confirm against callers.
  191 void DecisionTreeClassifier::split_dataset(
const std::vector<std::vector<double>>& X, 
const std::vector<int>& y,
 
  192                                            int feature_index, 
double threshold,
 
  193                                            std::vector<std::vector<double>>& X_left, std::vector<int>& y_left,
 
  194                                            std::vector<std::vector<double>>& X_right, std::vector<int>& y_right)
 const {
 
  195     for (
size_t i = 0; i < X.size(); ++i) {
 
              // Values equal to the threshold fall on the left side.
  196         if (X[i][feature_index] <= threshold) {
 
  197             X_left.push_back(X[i]);
 
  198             y_left.push_back(y[i]);
 
  200             X_right.push_back(X[i]);
 
  201             y_right.push_back(y[i]);
 
  // Classifies one sample by descending the tree from node: the sample's value for
  // the node's split feature is compared with the node's threshold, and the walk
  // continues into the left child when it is <= threshold, otherwise the right.
  //
  // NOTE(review): original lines 207-209 are not shown in this listing; presumably
  // they return node->value for a leaf before the comparison below. No null check
  // on node is visible in the shown lines — confirm against the hidden lines.
  206 int DecisionTreeClassifier::predict_sample(
const std::vector<double>& x, Node* node)
 const {
 
  210     if (x[node->feature_index] <= node->threshold) {
 
  211         return predict_sample(x, node->left);
 
  213         return predict_sample(x, node->right);
 
  // Recursively walks the subtree rooted at node, visiting both children before
  // the node itself; a null node is a no-op.
  //
  // NOTE(review): original line 221 is not shown in this listing; presumably it
  // deletes node after the two recursive calls.
  217 void DecisionTreeClassifier::delete_tree(Node* node) {
 
  218     if (node != 
nullptr) {
 
  219         delete_tree(node->left);
 
  220         delete_tree(node->right);
 
Implements a Decision Tree Classifier.
Definition: DecisionTreeClassifier.hpp:20
DecisionTreeClassifier(int max_depth=5, int min_samples_split=2)
Constructs a DecisionTreeClassifier.
Definition: DecisionTreeClassifier.hpp:73
std::vector< int > predict(const std::vector< std::vector< double >> &X) const
Predicts class labels for given input data.
Definition: DecisionTreeClassifier.hpp:84
void fit(const std::vector< std::vector< double >> &X, const std::vector< int > &y)
Fits the model to the training data.
Definition: DecisionTreeClassifier.hpp:80
~DecisionTreeClassifier()
Destructor for DecisionTreeClassifier.
Definition: DecisionTreeClassifier.hpp:76