Cpp ML Library  1.0.0
A library of Machine Learning Algorithms seen from the Udemy course Machine Learning A to Z.
DecisionTreeClassifier.hpp
Go to the documentation of this file.
1 #ifndef DECISION_TREE_CLASSIFIER_HPP
2 #define DECISION_TREE_CLASSIFIER_HPP
3 
#include <algorithm>
#include <cmath>
#include <limits>
#include <map>
#include <numeric>
#include <utility>
#include <vector>
10 
/**
 * @brief A CART-style decision tree classifier using Gini impurity.
 *
 * Binary splits on numeric features; leaves predict the majority class
 * label of the training samples that reach them.
 */
class DecisionTreeClassifier {
public:
    /**
     * @brief Constructs a DecisionTreeClassifier.
     * @param max_depth Maximum tree depth (nodes at this depth become leaves).
     * @param min_samples_split Minimum number of samples required to split a node.
     */
    DecisionTreeClassifier(int max_depth = 5, int min_samples_split = 2);

    /// Destructor; releases every node of the trained tree.
    ~DecisionTreeClassifier();

    // Non-copyable: the tree is owned through raw Node pointers, so a
    // compiler-generated copy would double-free on destruction.
    DecisionTreeClassifier(const DecisionTreeClassifier&) = delete;
    DecisionTreeClassifier& operator=(const DecisionTreeClassifier&) = delete;

    /**
     * @brief Fits the model to the training data.
     * @param X Feature matrix, one row per sample.
     * @param y Class label per sample (parallel to X).
     */
    void fit(const std::vector<std::vector<double>>& X, const std::vector<int>& y);

    /**
     * @brief Predicts class labels for given input data.
     * @param X Feature matrix, one row per sample.
     * @return Predicted label per sample, in input order.
     */
    std::vector<int> predict(const std::vector<std::vector<double>>& X) const;

private:
    struct Node {
        bool is_leaf;
        int value; // Class label for leaf nodes
        int feature_index;
        double threshold;
        Node* left;
        Node* right;

        Node() : is_leaf(false), value(0), feature_index(-1), threshold(0.0), left(nullptr), right(nullptr) {}
    };

    Node* root;            // owning pointer to the trained tree (nullptr before fit)
    int max_depth;
    int min_samples_split;

    Node* build_tree(const std::vector<std::vector<double>>& X, const std::vector<int>& y, int depth);
    double calculate_gini(const std::vector<int>& y) const;
    void split_dataset(const std::vector<std::vector<double>>& X, const std::vector<int>& y, int feature_index, double threshold,
                       std::vector<std::vector<double>>& X_left, std::vector<int>& y_left,
                       std::vector<std::vector<double>>& X_right, std::vector<int>& y_right) const;
    int predict_sample(const std::vector<double>& x, Node* node) const;
    void delete_tree(Node* node);
};
72 
73 DecisionTreeClassifier::DecisionTreeClassifier(int max_depth, int min_samples_split)
74  : root(nullptr), max_depth(max_depth), min_samples_split(min_samples_split) {}
75 
77  delete_tree(root);
78 }
79 
80 void DecisionTreeClassifier::fit(const std::vector<std::vector<double>>& X, const std::vector<int>& y) {
81  root = build_tree(X, y, 0);
82 }
83 
84 std::vector<int> DecisionTreeClassifier::predict(const std::vector<std::vector<double>>& X) const {
85  std::vector<int> predictions;
86  for (const auto& x : X) {
87  predictions.push_back(predict_sample(x, root));
88  }
89  return predictions;
90 }
91 
92 DecisionTreeClassifier::Node* DecisionTreeClassifier::build_tree(const std::vector<std::vector<double>>& X, const std::vector<int>& y, int depth) {
93  Node* node = new Node();
94 
95  // Check stopping criteria
96  if (depth >= max_depth || y.size() < static_cast<size_t>(min_samples_split) || calculate_gini(y) == 0.0) {
97  node->is_leaf = true;
98  // Majority class label
99  std::map<int, int> class_counts;
100  for (int label : y) {
101  class_counts[label]++;
102  }
103  node->value = std::max_element(class_counts.begin(), class_counts.end(),
104  [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
105  return a.second < b.second;
106  })->first;
107  return node;
108  }
109 
110  double best_gini = std::numeric_limits<double>::max();
111  int best_feature_index = -1;
112  double best_threshold = 0.0;
113  std::vector<std::vector<double>> best_X_left, best_X_right;
114  std::vector<int> best_y_left, best_y_right;
115 
116  int num_features = X[0].size();
117  for (int feature_index = 0; feature_index < num_features; ++feature_index) {
118  // Get all possible thresholds
119  std::vector<double> feature_values;
120  for (const auto& x : X) {
121  feature_values.push_back(x[feature_index]);
122  }
123  std::sort(feature_values.begin(), feature_values.end());
124  std::vector<double> thresholds;
125  for (size_t i = 1; i < feature_values.size(); ++i) {
126  thresholds.push_back((feature_values[i - 1] + feature_values[i]) / 2.0);
127  }
128 
129  // Evaluate each threshold
130  for (double threshold : thresholds) {
131  std::vector<std::vector<double>> X_left, X_right;
132  std::vector<int> y_left, y_right;
133  split_dataset(X, y, feature_index, threshold, X_left, y_left, X_right, y_right);
134 
135  if (y_left.empty() || y_right.empty())
136  continue;
137 
138  double gini_left = calculate_gini(y_left);
139  double gini_right = calculate_gini(y_right);
140  double gini = (gini_left * y_left.size() + gini_right * y_right.size()) / y.size();
141 
142  if (gini < best_gini) {
143  best_gini = gini;
144  best_feature_index = feature_index;
145  best_threshold = threshold;
146  best_X_left = X_left;
147  best_X_right = X_right;
148  best_y_left = y_left;
149  best_y_right = y_right;
150  }
151  }
152  }
153 
154  // If no split improves the Gini impurity, make this a leaf node
155  if (best_feature_index == -1) {
156  node->is_leaf = true;
157  // Majority class label
158  std::map<int, int> class_counts;
159  for (int label : y) {
160  class_counts[label]++;
161  }
162  node->value = std::max_element(class_counts.begin(), class_counts.end(),
163  [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
164  return a.second < b.second;
165  })->first;
166  return node;
167  }
168 
169  // Recursively build the left and right subtrees
170  node->feature_index = best_feature_index;
171  node->threshold = best_threshold;
172  node->left = build_tree(best_X_left, best_y_left, depth + 1);
173  node->right = build_tree(best_X_right, best_y_right, depth + 1);
174  return node;
175 }
176 
177 double DecisionTreeClassifier::calculate_gini(const std::vector<int>& y) const {
178  std::map<int, int> class_counts;
179  for (int label : y) {
180  class_counts[label]++;
181  }
182  double impurity = 1.0;
183  size_t total = y.size();
184  for (const auto& class_count : class_counts) {
185  double prob = static_cast<double>(class_count.second) / total;
186  impurity -= prob * prob;
187  }
188  return impurity;
189 }
190 
191 void DecisionTreeClassifier::split_dataset(const std::vector<std::vector<double>>& X, const std::vector<int>& y,
192  int feature_index, double threshold,
193  std::vector<std::vector<double>>& X_left, std::vector<int>& y_left,
194  std::vector<std::vector<double>>& X_right, std::vector<int>& y_right) const {
195  for (size_t i = 0; i < X.size(); ++i) {
196  if (X[i][feature_index] <= threshold) {
197  X_left.push_back(X[i]);
198  y_left.push_back(y[i]);
199  } else {
200  X_right.push_back(X[i]);
201  y_right.push_back(y[i]);
202  }
203  }
204 }
205 
206 int DecisionTreeClassifier::predict_sample(const std::vector<double>& x, Node* node) const {
207  if (node->is_leaf) {
208  return node->value;
209  }
210  if (x[node->feature_index] <= node->threshold) {
211  return predict_sample(x, node->left);
212  } else {
213  return predict_sample(x, node->right);
214  }
215 }
216 
217 void DecisionTreeClassifier::delete_tree(Node* node) {
218  if (node != nullptr) {
219  delete_tree(node->left);
220  delete_tree(node->right);
221  delete node;
222  }
223 }
224 
225 #endif // DECISION_TREE_CLASSIFIER_HPP
Implements a Decision Tree Classifier.
Definition: DecisionTreeClassifier.hpp:20
DecisionTreeClassifier(int max_depth=5, int min_samples_split=2)
Constructs a DecisionTreeClassifier.
Definition: DecisionTreeClassifier.hpp:73
std::vector< int > predict(const std::vector< std::vector< double >> &X) const
Predicts class labels for given input data.
Definition: DecisionTreeClassifier.hpp:84
void fit(const std::vector< std::vector< double >> &X, const std::vector< int > &y)
Fits the model to the training data.
Definition: DecisionTreeClassifier.hpp:80
~DecisionTreeClassifier()
Destructor for DecisionTreeClassifier.
Definition: DecisionTreeClassifier.hpp:76