1 #ifndef DECISION_TREE_CLASSIFIER_HPP
2 #define DECISION_TREE_CLASSIFIER_HPP
39 void fit(
const std::vector<std::vector<double>>& X,
const std::vector<int>& y);
46 std::vector<int>
predict(
const std::vector<std::vector<double>>& X)
const;
57 Node() : is_leaf(
false), value(0), feature_index(-1), threshold(0.0), left(
nullptr), right(
nullptr) {}
62 int min_samples_split;
64 Node* build_tree(
const std::vector<std::vector<double>>& X,
const std::vector<int>& y,
int depth);
65 double calculate_gini(
const std::vector<int>& y)
const;
66 void split_dataset(
const std::vector<std::vector<double>>& X,
const std::vector<int>& y,
int feature_index,
double threshold,
67 std::vector<std::vector<double>>& X_left, std::vector<int>& y_left,
68 std::vector<std::vector<double>>& X_right, std::vector<int>& y_right)
const;
69 int predict_sample(
const std::vector<double>& x, Node* node)
const;
70 void delete_tree(Node* node);
74 : root(nullptr), max_depth(max_depth), min_samples_split(min_samples_split) {}
81 root = build_tree(X, y, 0);
85 std::vector<int> predictions;
86 for (
const auto& x : X) {
87 predictions.push_back(predict_sample(x, root));
92 DecisionTreeClassifier::Node* DecisionTreeClassifier::build_tree(
const std::vector<std::vector<double>>& X,
const std::vector<int>& y,
int depth) {
93 Node* node =
new Node();
96 if (depth >= max_depth || y.size() <
static_cast<size_t>(min_samples_split) || calculate_gini(y) == 0.0) {
99 std::map<int, int> class_counts;
100 for (
int label : y) {
101 class_counts[label]++;
103 node->value = std::max_element(class_counts.begin(), class_counts.end(),
104 [](
const std::pair<int, int>& a,
const std::pair<int, int>& b) {
105 return a.second < b.second;
110 double best_gini = std::numeric_limits<double>::max();
111 int best_feature_index = -1;
112 double best_threshold = 0.0;
113 std::vector<std::vector<double>> best_X_left, best_X_right;
114 std::vector<int> best_y_left, best_y_right;
116 int num_features = X[0].size();
117 for (
int feature_index = 0; feature_index < num_features; ++feature_index) {
119 std::vector<double> feature_values;
120 for (
const auto& x : X) {
121 feature_values.push_back(x[feature_index]);
123 std::sort(feature_values.begin(), feature_values.end());
124 std::vector<double> thresholds;
125 for (
size_t i = 1; i < feature_values.size(); ++i) {
126 thresholds.push_back((feature_values[i - 1] + feature_values[i]) / 2.0);
130 for (
double threshold : thresholds) {
131 std::vector<std::vector<double>> X_left, X_right;
132 std::vector<int> y_left, y_right;
133 split_dataset(X, y, feature_index, threshold, X_left, y_left, X_right, y_right);
135 if (y_left.empty() || y_right.empty())
138 double gini_left = calculate_gini(y_left);
139 double gini_right = calculate_gini(y_right);
140 double gini = (gini_left * y_left.size() + gini_right * y_right.size()) / y.size();
142 if (gini < best_gini) {
144 best_feature_index = feature_index;
145 best_threshold = threshold;
146 best_X_left = X_left;
147 best_X_right = X_right;
148 best_y_left = y_left;
149 best_y_right = y_right;
155 if (best_feature_index == -1) {
156 node->is_leaf =
true;
158 std::map<int, int> class_counts;
159 for (
int label : y) {
160 class_counts[label]++;
162 node->value = std::max_element(class_counts.begin(), class_counts.end(),
163 [](
const std::pair<int, int>& a,
const std::pair<int, int>& b) {
164 return a.second < b.second;
170 node->feature_index = best_feature_index;
171 node->threshold = best_threshold;
172 node->left = build_tree(best_X_left, best_y_left, depth + 1);
173 node->right = build_tree(best_X_right, best_y_right, depth + 1);
177 double DecisionTreeClassifier::calculate_gini(
const std::vector<int>& y)
const {
178 std::map<int, int> class_counts;
179 for (
int label : y) {
180 class_counts[label]++;
182 double impurity = 1.0;
183 size_t total = y.size();
184 for (
const auto& class_count : class_counts) {
185 double prob =
static_cast<double>(class_count.second) / total;
186 impurity -= prob * prob;
191 void DecisionTreeClassifier::split_dataset(
const std::vector<std::vector<double>>& X,
const std::vector<int>& y,
192 int feature_index,
double threshold,
193 std::vector<std::vector<double>>& X_left, std::vector<int>& y_left,
194 std::vector<std::vector<double>>& X_right, std::vector<int>& y_right)
const {
195 for (
size_t i = 0; i < X.size(); ++i) {
196 if (X[i][feature_index] <= threshold) {
197 X_left.push_back(X[i]);
198 y_left.push_back(y[i]);
200 X_right.push_back(X[i]);
201 y_right.push_back(y[i]);
206 int DecisionTreeClassifier::predict_sample(
const std::vector<double>& x, Node* node)
const {
210 if (x[node->feature_index] <= node->threshold) {
211 return predict_sample(x, node->left);
213 return predict_sample(x, node->right);
217 void DecisionTreeClassifier::delete_tree(Node* node) {
218 if (node !=
nullptr) {
219 delete_tree(node->left);
220 delete_tree(node->right);
Implements a Decision Tree Classifier.
Definition: DecisionTreeClassifier.hpp:20
DecisionTreeClassifier(int max_depth=5, int min_samples_split=2)
Constructs a DecisionTreeClassifier.
Definition: DecisionTreeClassifier.hpp:73
std::vector< int > predict(const std::vector< std::vector< double >> &X) const
Predicts class labels for given input data.
Definition: DecisionTreeClassifier.hpp:84
void fit(const std::vector< std::vector< double >> &X, const std::vector< int > &y)
Fits the model to the training data.
Definition: DecisionTreeClassifier.hpp:80
~DecisionTreeClassifier()
Destructor for DecisionTreeClassifier.
Definition: DecisionTreeClassifier.hpp:76