1 #ifndef HIERARCHICAL_CLUSTERING_HPP
2 #define HIERARCHICAL_CLUSTERING_HPP
/// Fits the clustering algorithm to the data.
/// @param X Input data as a vector of double vectors.
46 void fit(
const std::vector<std::vector<double>>& X);
/// Predicts the cluster labels for the data.
/// @return One integer label per entry of `data`; the label is the position
///         of the owning cluster within `clusters`.
52 std::vector<int>
predict()
const;
/// Stored data points; euclidean_distance() indexes this by point index.
63 std::vector<std::vector<double>> data;
/// Point indices. NOTE(review): presumably this is the `points` member of
/// Cluster (indices into `data`) — the enclosing declaration is not visible
/// here, verify.
67 std::vector<int> points;
/// The current set of clusters, each held through a shared_ptr.
70 std::vector<std::shared_ptr<Cluster>> clusters;
/// Euclidean distance between the stored points data[a] and data[b].
78 double euclidean_distance(
int a,
int b)
const;
/// Distance between two clusters, computed from pairwise point distances
/// according to the configured linkage (SINGLE, COMPLETE or AVERAGE).
86 double cluster_distance(
const Cluster& cluster_a,
const Cluster& cluster_b)
const;
/// Merges the pair of clusters reported by find_closest_clusters().
91 void merge_clusters();
/// Searches every unordered pair of clusters and returns a pair of indices
/// into `clusters`.
97 std::pair<int, int> find_closest_clusters()
const;
// Constructor member-initializer list: stores the requested number of clusters
// and the linkage criterion in the members of the same names; the body is empty.
101 : n_clusters(n_clusters), linkage(linkage) {}
// Seed step: create one singleton cluster per entry of `data`. The cluster id
// and its only member are both the entry's index.
110 for (
size_t i = 0; i < data.size(); ++i) {
111 auto cluster = std::make_shared<Cluster>();
112 cluster->id =
static_cast<int>(i);
113 cluster->points.push_back(
static_cast<int>(i));
114 clusters.push_back(cluster);
// Keep looping while there are more clusters than the requested n_clusters.
// NOTE(review): the loop body is not visible in this listing; presumably it
// calls merge_clusters() once per iteration — confirm against the full file.
118 while (
static_cast<int>(clusters.size()) > n_clusters) {
// One label per entry of `data`, initialised to -1. Any point that is not
// listed in a cluster keeps that value.
124 std::vector<int> labels(data.size(), -1);
// A point's label is the position of its cluster inside `clusters`
// (not the cluster's `id` field).
125 for (
size_t i = 0; i < clusters.size(); ++i) {
126 for (
int point_idx : clusters[i]->points) {
127 labels[point_idx] =
static_cast<int>(i);
// Result: one centroid per current cluster, in the order of `clusters`.
134 std::vector<std::vector<double>> centers;
135 centers.reserve(clusters.size());
137 for (
const auto& cluster : clusters) {
// Accumulator sized to the length of data[0].
// NOTE(review): data[0] is read whenever `clusters` is non-empty, so `data`
// must be non-empty at that point.
138 std::vector<double> centroid(data[0].size(), 0.0);
// Sum the coordinates of every point that belongs to this cluster.
139 for (
int idx : cluster->points) {
140 const auto& point = data[idx];
141 for (
size_t i = 0; i < point.size(); ++i) {
142 centroid[i] += point[i];
// Divide each coordinate sum by the cluster's point count to obtain the mean.
146 for (
double& val : centroid) {
147 val /= cluster->points.size();
149 centers.push_back(centroid);
155 double HierarchicalClustering::euclidean_distance(
int a,
int b)
const {
156 const auto& point_a = data[a];
157 const auto& point_b = data[b];
158 double distance = 0.0;
159 for (
size_t i = 0; i < point_a.size(); ++i) {
160 double diff = point_a[i] - point_b[i];
161 distance += diff * diff;
163 return std::sqrt(distance);
// Computes the distance between two clusters from the pairwise Euclidean
// distances of their member points, according to the configured `linkage`.
// cluster_a and cluster_b are the clusters to compare; their `points` hold
// indices that are passed to euclidean_distance().
166 double HierarchicalClustering::cluster_distance(
const Cluster& cluster_a,
const Cluster& cluster_b)
const {
167 double distance = 0.0;
// SINGLE linkage: start from the largest representable double and compare
// every cross-cluster pair against it.
169 if (linkage == Linkage::SINGLE) {
171 distance = std::numeric_limits<double>::max();
172 for (
int idx_a : cluster_a.points) {
173 for (
int idx_b : cluster_b.points) {
174 double dist = euclidean_distance(idx_a, idx_b);
// NOTE(review): the statement guarded by this test is not visible in this
// listing; presumably it records `dist` as the new minimum — confirm.
175 if (dist < distance) {
180 }
// COMPLETE linkage: `distance` still holds its initial 0.0 here, and every
// cross-cluster pair is compared against it.
else if (linkage == Linkage::COMPLETE) {
183 for (
int idx_a : cluster_a.points) {
184 for (
int idx_b : cluster_b.points) {
185 double dist = euclidean_distance(idx_a, idx_b);
// NOTE(review): the guarded statement is not visible; presumably it records
// `dist` as the new maximum — confirm.
186 if (dist > distance) {
191 }
// AVERAGE linkage: sum the distances of all cross-cluster pairs.
// NOTE(review): the normalisation by the number of pairs and the final return
// are not visible in this listing — confirm against the full file.
else if (linkage == Linkage::AVERAGE) {
195 for (
int idx_a : cluster_a.points) {
196 for (
int idx_b : cluster_b.points) {
197 distance += euclidean_distance(idx_a, idx_b);
// Merges the pair of clusters selected by find_closest_clusters(): the points
// of cluster idx_b are appended to cluster idx_a, then cluster idx_b is
// removed from `clusters`.
207 void HierarchicalClustering::merge_clusters() {
// NOTE(review): the returned indices are used without validation. What
// find_closest_clusters() returns when no pair exists is not visible here —
// confirm callers guarantee at least two clusters.
208 auto [idx_a, idx_b] = find_closest_clusters();
// Append every point index of cluster b to the end of cluster a's list.
211 clusters[idx_a]->points.insert(clusters[idx_a]->points.end(),
212 clusters[idx_b]->points.begin(),
213 clusters[idx_b]->points.end());
// Drop cluster b. The pair search only assigns j > i, so for a found pair
// idx_b > idx_a and this erase does not shift the position of cluster a.
216 clusters.erase(clusters.begin() + idx_b);
// Scans every unordered pair (i, j) with i < j of the current clusters,
// evaluates cluster_distance() for each, and returns a pair of positions
// within `clusters`.
219 std::pair<int, int> HierarchicalClustering::find_closest_clusters()
const {
// Start from the largest representable double so any smaller pair distance
// passes the comparison below.
220 double min_distance = std::numeric_limits<double>::max();
// NOTE(review): the declarations and initial values of idx_a / idx_b are not
// visible in this listing, so the result when no pair qualifies cannot be
// stated from here — confirm.
224 for (
size_t i = 0; i < clusters.size(); ++i) {
225 for (
size_t j = i + 1; j < clusters.size(); ++j) {
226 double dist = cluster_distance(*clusters[i], *clusters[j]);
// Strict '<': a later pair at an equal distance does not pass this test.
// NOTE(review): the update of min_distance is not visible here; presumably
// it is set to `dist` inside this branch — confirm.
227 if (dist < min_distance) {
229 idx_a =
static_cast<int>(i);
230 idx_b =
static_cast<int>(j);
235 return {idx_a, idx_b};
Agglomerative Hierarchical Clustering for clustering tasks.
Definition: HierarchicalClustering.hpp:19
~HierarchicalClustering()
Destructor for HierarchicalClustering.
Definition: HierarchicalClustering.hpp:103
std::vector< std::vector< double > > get_cluster_centers() const
Retrieves the cluster centers (centroids) after fitting.
Definition: HierarchicalClustering.hpp:133
Linkage
Linkage criteria for clustering.
Definition: HierarchicalClustering.hpp:24
void fit(const std::vector< std::vector< double >> &X)
Fits the clustering algorithm to the data.
Definition: HierarchicalClustering.hpp:105
HierarchicalClustering(int n_clusters=2, Linkage linkage=Linkage::AVERAGE)
Constructs a HierarchicalClustering instance.
Definition: HierarchicalClustering.hpp:100
std::vector< int > predict() const
Predicts the cluster labels for the data.
Definition: HierarchicalClustering.hpp:123