1 #ifndef KNN_CLASSIFIER_HPP
2 #define KNN_CLASSIFIER_HPP
#include <cstddef>
#include <stdexcept>
7 #include <unordered_map>
36 void fit(
const std::vector<std::vector<double>>& X,
const std::vector<int>& y);
43 std::vector<int>
predict(
const std::vector<std::vector<double>>& X)
const;
47 std::vector<std::vector<double>> X_train;
48 std::vector<int> y_train;
56 double euclidean_distance(
const std::vector<double>& a,
const std::vector<double>& b)
const;
63 int predict_sample(
const std::vector<double>& x)
const;
76 std::vector<int> predictions;
77 predictions.reserve(X.size());
78 for (
const auto& x : X) {
79 predictions.push_back(predict_sample(x));
84 double KNNClassifier::euclidean_distance(
const std::vector<double>& a,
const std::vector<double>& b)
const {
85 double distance = 0.0;
86 for (
size_t i = 0; i < a.size(); ++i) {
87 double diff = a[i] - b[i];
88 distance += diff * diff;
90 return std::sqrt(distance);
93 int KNNClassifier::predict_sample(
const std::vector<double>& x)
const {
95 std::vector<std::pair<double, int>> distances;
96 distances.reserve(X_train.size());
99 for (
size_t i = 0; i < X_train.size(); ++i) {
100 double dist = euclidean_distance(x, X_train[i]);
101 distances.emplace_back(dist, y_train[i]);
105 std::nth_element(distances.begin(), distances.begin() + k, distances.end(),
106 [](
const std::pair<double, int>& a,
const std::pair<double, int>& b) {
107 return a.first < b.first;
111 std::unordered_map<int, int> class_counts;
112 for (
int i = 0; i < k; ++i) {
113 int label = distances[i].second;
114 class_counts[label]++;
119 int majority_class = -1;
120 for (
const auto& [label, count] : class_counts) {
121 if (count > max_count) {
123 majority_class = label;
127 return majority_class;
K-Nearest Neighbors (KNN) classifier: labels each sample by majority vote among its k nearest training samples.
Definition: KNNClassifier.hpp:18
KNNClassifier(int k=3)
Constructs a KNNClassifier that votes among the k nearest neighbours (default k = 3).
Definition: KNNClassifier.hpp:66
void fit(const std::vector< std::vector< double >> &X, const std::vector< int > &y)
Fits the classifier to the training data.
Definition: KNNClassifier.hpp:70
std::vector< int > predict(const std::vector< std::vector< double >> &X) const
Predicts class labels for the given input data.
Definition: KNNClassifier.hpp:75
~KNNClassifier()
Destructor for KNNClassifier.
Definition: KNNClassifier.hpp:68