K-Means Clustering

Introduction

K-means clustering is a simple unsupervised machine learning algorithm that partitions the points in a dataset into k clusters. Each cluster is defined by its mean (also called a centroid), and each point belongs to the cluster whose centroid is closest to it: the distance between the point and every centroid is computed, and the point is assigned to the nearest one. A good k-means clustering therefore keeps each data point close to its cluster’s centroid. This is usually measured with inertia, the sum of the squared distances between each point and the centroid of the cluster it belongs to.
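The assignment rule and the inertia measure described above can be sketched in a few lines of Python with NumPy. This is a minimal illustration, assuming Euclidean distance and points stored as the rows of a NumPy array; the function names `assign_to_nearest` and `inertia` and the four toy points are hypothetical, not part of any library.

```python
import numpy as np


def assign_to_nearest(points, centroids):
    """Return, for each point, the index of its closest centroid."""
    # Squared Euclidean distance from every point to every centroid: shape (n, k).
    # The squared distance picks the same nearest centroid as the plain distance.
    d2 = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return d2.argmin(axis=1)


def inertia(points, centroids, labels):
    """Sum of squared distances from each point to its own cluster's centroid."""
    return float(((points - centroids[labels]) ** 2).sum())


# Hypothetical toy data: two tight pairs of points and one centroid per pair
points = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [12.0, 10.0]])
centroids = np.array([[1.0, 0.0], [11.0, 10.0]])
labels = assign_to_nearest(points, centroids)
print(labels)                              # [0 0 1 1]
print(inertia(points, centroids, labels))  # 4.0
```

Each point is exactly 1 unit from its centroid, so each contributes 1 to the sum and the inertia is 4.0.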

Let’s see an example of how k-means clustering works by examining the data points in two-dimensional space below:

The points are all blue, indicating that they have not yet been assigned to any cluster. Our goal is to create a scatterplot in which each data point is color-coded by the cluster it belongs to.

Procedure

1. Find k, the number of clusters in the dataset. Choosing the number of clusters is an important decision that requires careful consideration. One option is the elbow method: plot inertia on the y-axis against the number of clusters (k) on the x-axis. Inertia generally decreases as k grows from 1 to N (the number of points in the dataset), reaching 0 when every point is its own cluster, so simply picking the k with the lowest inertia would not work. Instead, we look for the elbow, the value of k after which adding more clusters barely reduces inertia. For the 2D dataset above, the elbow plot I used is shown below:

From the elbow plot, we can see that inertia drops as k increases, but it changes very little from k = 5 onwards. Consequently, k = 5 is the most appropriate choice for this dataset. Another option is to inspect the data visually and count the naturally occurring clusters. For instance, the scatterplot above clearly shows 5 groups. However, this is often difficult with real-world data, where groups overlap and there are usually more than two or three dimensions to plot. In such cases, an elbow plot is the better option.

2. Randomly select k points to be the initial cluster centroids. In our case, we randomly select 5 points as the initial cluster centroids. Ideally, the initial centroids should be spread out so that each one lies in a different natural cluster, since a poor initialization can lead k-means to a worse final clustering.

3. Calculate the distance between each centroid and every data point, and assign each data point to the cluster whose centroid is closest to it. In our example, we compute the distance between each data point and the 5 initial centroids, and each point joins the cluster of its nearest centroid. After all the points have been assigned, we end up with 5 new clusters. Color-coding the clusters is a great way of distinguishing them visually, as we’ll see shortly.

4. Next, we update the centroids by computing the mean of the points in each new cluster; these means become the new centroids. In our example, the centroids shift from the 5 initial data points we chose to the means of the 5 clusters.

5. Our goal is to find centroids that no longer move. We therefore repeat steps 3 and 4 until no data point changes cluster, at which point the centroids also stop moving. In our example, the color-coded clusters after running k-means look like the image below:
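The five steps above can be sketched in Python as well. The following is a minimal, from-scratch version of the standard k-means loop (often called Lloyd’s algorithm), assuming Euclidean distance and NumPy; `kmeans`, `nearest_centroid`, and the five synthetic blobs are hypothetical illustrations, not the dataset or code used for the plots above. Because the result depends on the random initial centroids, the elbow loop at the end keeps the lowest inertia over several seeds for each k.

```python
import numpy as np


def nearest_centroid(points, centroids):
    """Index of the closest centroid for every point (Euclidean distance)."""
    d2 = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return d2.argmin(axis=1)


def kmeans(points, k, max_iter=100, seed=0):
    """Minimal k-means sketch. Returns (labels, centroids, inertia)."""
    rng = np.random.default_rng(seed)
    # Step 2: use k distinct, randomly chosen data points as initial centroids
    centroids = points[rng.choice(len(points), size=k, replace=False)].astype(float)
    # Step 3: assign every point to its nearest centroid
    labels = nearest_centroid(points, centroids)
    for _ in range(max_iter):
        # Step 4: move each centroid to the mean of its cluster
        for j in range(k):
            members = points[labels == j]
            if len(members) > 0:  # keep the old centroid if a cluster is empty
                centroids[j] = members.mean(axis=0)
        # Step 5: reassign, and stop once no point changes cluster
        new_labels = nearest_centroid(points, centroids)
        if np.array_equal(new_labels, labels):
            break
        labels = new_labels
    inertia = float(((points - centroids[labels]) ** 2).sum())
    return labels, centroids, inertia


# Hypothetical demo data: five Gaussian blobs in 2-D (not the article's dataset)
rng = np.random.default_rng(42)
centers = np.array([[0, 0], [8, 0], [0, 8], [8, 8], [4, 4]], dtype=float)
X = np.vstack([c + rng.normal(scale=0.5, size=(50, 2)) for c in centers])

# Step 1 (elbow method): best inertia over 10 random initializations per k
for k in range(1, 9):
    best = min(kmeans(X, k, seed=s)[2] for s in range(10))
    print(f"k={k}: inertia={best:.1f}")
```

Plotting the printed inertia values against k gives an elbow plot like the one in step 1. In practice, scikit-learn’s `KMeans` implements the same idea, with k-means++ initialization by default and an `n_init` parameter for multiple restarts.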

Application: Image Segmentation

K-means clustering can be used to reduce the number of colors in an image, which allows the image to be stored in a smaller file.

We can use clustering to find groups of similar colors in an image. If we then replace each pixel’s color with the mean color of its group, we have segmented the image: the colors are quantized into k segments, one per cluster.
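The color quantization described above can be sketched with the same k-means loop, treating every pixel as a point in RGB space. This is a minimal sketch assuming an `(H, W, 3)` uint8 NumPy array as input; `quantize_colors` and the synthetic two-color test image are hypothetical, and this is not the code from the GitHub project linked below.

```python
import numpy as np


def quantize_colors(image, k, n_iter=20, seed=0):
    """Reduce an RGB image to at most k colors with a minimal k-means loop.

    image: (H, W, 3) uint8 array. Returns (quantized_image, palette, labels).
    """
    pixels = image.reshape(-1, 3).astype(float)  # one row per pixel
    rng = np.random.default_rng(seed)
    # Initial palette: k randomly chosen pixels
    palette = pixels[rng.choice(len(pixels), size=k, replace=False)]

    def nearest(p):
        # Squared distance to one palette color at a time, which avoids
        # building an (n_pixels, k, 3) array for large images
        d2 = np.stack([((pixels - c) ** 2).sum(axis=1) for c in p], axis=1)
        return d2.argmin(axis=1)

    for _ in range(n_iter):
        labels = nearest(palette)
        for j in range(k):
            members = pixels[labels == j]
            if len(members) > 0:  # leave an empty cluster's color unchanged
                palette[j] = members.mean(axis=0)
    labels = nearest(palette)
    # Replace each pixel with the mean color of its cluster
    quantized = palette[labels].round().astype(np.uint8).reshape(image.shape)
    return quantized, palette, labels


# Hypothetical test image: reddish left half, bluish right half, plus noise
rng = np.random.default_rng(1)
img = np.zeros((64, 64, 3))
img[:, :32] = [200, 30, 30]
img[:, 32:] = [30, 30, 200]
img = np.clip(img + rng.normal(scale=10, size=img.shape), 0, 255).astype(np.uint8)

quantized, palette, labels = quantize_colors(img, k=2)
print(len(np.unique(quantized.reshape(-1, 3), axis=0)), "distinct colors")
```

Reshaping `labels` to `(H, W)` gives the segment map itself. The storage saving follows from simple arithmetic on the raw pixel data, before any file-format compression: with k = 16 colors, each pixel needs only a 4-bit palette index instead of 24 bits of RGB, so the pixel data shrinks to one sixth, plus a 16-entry palette.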

Here’s a link to my code on GitHub for a project I did on image segmentation using k-means clustering:

To learn more about image segmentation using the k-means clustering algorithm, check out this article by Shubhang Agrawal:

Conclusion

K-means clustering is an easy-to-implement machine learning algorithm with many real-world applications, including customer segmentation, identifying crime localities, and document classification. It scales well to large datasets and easily adapts to new examples. However, it has some drawbacks. K-means depends heavily on the choice of k (the number of clusters) and on the initial centroids, and it doesn’t perform well when clusters vary in size and density or are not roughly spherical. It is also highly sensitive to outliers, since a single extreme point can pull a cluster’s mean far away from the rest of its points.
