In this tutorial, we'll briefly explore clustering data with the Mean Shift algorithm using scikit-learn's MeanShift class in Python. The tutorial covers:
- The concept of Mean Shift
- Preparing data
- Clustering with Mean Shift
- Source code listing
The concept of Mean Shift
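Mean Shift is a mode-seeking method. Starting from each data point, it repeatedly moves a window of radius bandwidth to the mean of the points inside it, so the window climbs toward a local peak of data density (a mode). Points whose windows end up at the same mode form one cluster, which is why the number of clusters does not have to be given in advance. The from-scratch sketch below illustrates the idea with a flat kernel on NumPy arrays. The helper name mean_shift_sketch, the convergence tolerance, and the greedy merging of nearby modes are assumptions of this illustration, not scikit-learn's implementation.

```python
import numpy as np

def mean_shift_sketch(X, bandwidth, max_iter=300, tol=1e-3):
    """Hypothetical helper: minimal flat-kernel Mean Shift, for illustration only."""
    X = np.asarray(X, dtype=float)
    modes = X.copy()  # start one window at every data point
    for i in range(len(modes)):
        point = modes[i]
        for _ in range(max_iter):
            # Flat kernel: every data point closer than `bandwidth` gets equal weight
            in_window = np.linalg.norm(X - point, axis=1) < bandwidth
            new_point = X[in_window].mean(axis=0)
            shift = np.linalg.norm(new_point - point)
            point = new_point
            if shift < tol:  # the window has (nearly) stopped moving
                break
        modes[i] = point
    # Assumption of this sketch: greedily merge modes closer than the bandwidth
    centers = []
    for m in modes:
        if all(np.linalg.norm(m - c) >= bandwidth for c in centers):
            centers.append(m)
    centers = np.array(centers)
    # Label each point with the index of its nearest center
    dists = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
    labels = dists.argmin(axis=1)
    return centers, labels

# Usage: two well-separated synthetic blobs
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0, 0.3, (50, 2)), rng.normal(5, 0.3, (50, 2))])
centers, labels = mean_shift_sketch(X, bandwidth=1.5)
print(len(centers), "centers found")
```

Note that no cluster count is passed in; the two groups emerge from the bandwidth alone.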
Preparing data
We'll start by loading the required libraries.
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import numpy as np
Next, we'll create a sample dataset for clustering with the make_blobs function and visualize it in a scatter plot.
# Create synthetic data
np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)
# Visualize the data
plt.scatter(x[:,0], x[:,1])
plt.show()
Clustering with Mean Shift
Scikit-learn implements the algorithm in its MeanShift class, which we'll use to define the model.
We create a MeanShift model with the bandwidth parameter set to 2 (bandwidth is the radius of the window around each point within which the mean is computed) and fit it to the 'x' data.
# Create Mean Shift instance and fit it to the data
mshclust=MeanShift(bandwidth=2).fit(x)
print(mshclust)
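Output (from an older scikit-learn release; recent versions print only the non-default parameters, so the line reads MeanShift(bandwidth=2)):
MeanShift(bandwidth=2, bin_seeding=False, cluster_all=True, min_bin_freq=1, n_jobs=1, seeds=None)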
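Instead of hard-coding bandwidth=2, the bandwidth can be estimated from the data with scikit-learn's estimate_bandwidth function (MeanShift also calls it internally when bandwidth is left as None). The sketch below assumes the same kind of synthetic data as above; quantile=0.2 is an arbitrary illustrative value, and smaller quantiles give smaller bandwidths.

```python
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

# Same kind of synthetic data as above
np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)

# Estimate a bandwidth from the data (quantile=0.2 is an illustrative choice)
bw = estimate_bandwidth(x, quantile=0.2)
print("estimated bandwidth:", bw)

# Fit Mean Shift with the estimated bandwidth
mshclust = MeanShift(bandwidth=bw).fit(x)
print("clusters found:", len(mshclust.cluster_centers_))
```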
Now we can get the labels (cluster ids) and the center point of each cluster.
# get cluster ids
labels = mshclust.labels_
# get cluster centers
centers = mshclust.cluster_centers_
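A quick way to check the result is to count the distinct labels and the number of points assigned to each cluster. The fitted model can also assign new points to the nearest learned center with predict(); the two new points below are arbitrary examples chosen for illustration.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)
mshclust = MeanShift(bandwidth=2).fit(x)

labels = mshclust.labels_
centers = mshclust.cluster_centers_

# Number of distinct cluster ids and the size of each cluster
n_clusters = len(np.unique(labels))
print("number of clusters:", n_clusters)
print("points per cluster:", np.bincount(labels))

# predict() assigns each new point to its closest cluster center
new_points = np.array([[0.0, 0.0], [5.0, 5.0]])
print("predicted clusters:", mshclust.predict(new_points))
```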
Using 'labels' and 'centers', we color each point by its cluster and mark each cluster center with a red star.
# Visualize original data and cluster centers
plt.scatter(x[:,0], x[:,1], c=labels)
plt.scatter(centers[:,0],centers[:,1], marker='*', color="r",s=80 )
plt.show()
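Because the number of clusters follows from bandwidth rather than being set directly, it is worth checking how sensitive the result is. The short loop below refits the model for a few bandwidth values (chosen only for illustration) and reports how many clusters each produces. Small bandwidths tend to split the data into many small clusters, while large ones merge neighboring blobs.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)

# Refit with several bandwidths and record the number of clusters found
counts = {}
for bw in [0.5, 1.0, 2.0, 4.0, 8.0]:
    model = MeanShift(bandwidth=bw).fit(x)
    counts[bw] = len(model.cluster_centers_)
    print(f"bandwidth={bw}: {counts[bw]} clusters")
```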
Conclusion
In this tutorial, we explored the Mean Shift clustering algorithm and applied it to synthetic data. The algorithm found the cluster centers without our specifying the number of clusters; instead, the result is governed by the bandwidth parameter. This makes Mean Shift useful when the natural grouping of the data is not known in advance. Feel free to experiment with different datasets and bandwidth values to get a better feel for how Mean Shift behaves. The full source code is provided below.
Source code listing
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import numpy as np
# Create synthetic data
np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)
# Visualize the data
plt.scatter(x[:,0], x[:,1])
plt.show()
# Create Mean Shift instance and fit it to the data
mshclust=MeanShift(bandwidth=2).fit(x)
print(mshclust)
# get cluster id
labels = mshclust.labels_
# get cluster centers
centers = mshclust.cluster_centers_
# Visualize original data and cluster centers
plt.scatter(x[:,0], x[:,1], c=labels)
plt.scatter(centers[:,0],centers[:,1], marker='*', color="r",s=80 )
plt.show()