The Gaussian Mixture Model (GMM) is a probabilistic model that represents a data population as a mixture of several Gaussian distributions. The model is widely used in clustering problems. Unlike traditional clustering methods such as K-Means, GMM allows more flexibility in the shape and orientation of clusters.
The Scikit-learn API provides the GaussianMixture class, which fits a Gaussian Mixture model to data and can be used to cluster it.
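To make the "mixture of Gaussians" idea concrete, here is a minimal sketch that evaluates the mixture density p(x) = sum_k weight_k * N(x | mean_k, cov_k) by hand from a fitted model's weights_, means_ and covariances_, and compares it with GaussianMixture.score_samples, which returns the log of that density. The small three-blob data set, the random_state values and the use of scipy.stats.multivariate_normal are illustrative assumptions, not part of the tutorial.

```python
import numpy as np
from scipy.stats import multivariate_normal
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

# Small illustrative data set (assumed values, not the tutorial's data)
X, _ = make_blobs(n_samples=200, centers=3, random_state=0)
gm = GaussianMixture(n_components=3, random_state=0).fit(X)

# Mixture density: a weighted sum of one Gaussian density per component
density = np.zeros(len(X))
for w, mu, cov in zip(gm.weights_, gm.means_, gm.covariances_):
    density += w * multivariate_normal(mean=mu, cov=cov).pdf(X)

# score_samples returns the log of the same mixture density
print(np.allclose(np.log(density), gm.score_samples(X)))
```

With the default covariance_type='full', covariances_ holds one full covariance matrix per component, which is what lets each Gaussian stretch and rotate to fit its cluster.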
In this tutorial, you'll briefly learn how to cluster data by using the GaussianMixture class in Python. The tutorial covers:
- Preparing data
- Clustering with Gaussian Mixture
- Source code listing
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
from numpy import random
from pandas import DataFrame
Preparing data
First, we'll create simple clustering data for this tutorial and visualize it in a plot.
random.seed(234)
x, _ = make_blobs(n_samples=330, centers=5, cluster_std=1.84)
plt.figure(figsize=(8, 6))
plt.scatter(x[:,0], x[:,1])
plt.show()
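Seeding with random.seed(234) works here because make_blobs falls back on NumPy's global random state when no random_state is given. A minimal alternative sketch, assuming you prefer to pass the seed directly to make_blobs through its random_state parameter, keeps the data reproducible without touching global state; the seed value 234 is simply reused for illustration.

```python
from sklearn.datasets import make_blobs

# Passing random_state makes the data reproducible on its own,
# without relying on NumPy's global random state
x, y_true = make_blobs(n_samples=330, centers=5, cluster_std=1.84,
                       random_state=234)
x2, _ = make_blobs(n_samples=330, centers=5, cluster_std=1.84,
                   random_state=234)

print(x.shape)          # 330 points with 2 features each
print((x == x2).all())  # same seed, same data
print(sorted(set(y_true)))  # the generating blob of each point: 0..4
```

The second return value (ignored with _ in the tutorial) holds the blob each point was generated from, which can be handy for sanity-checking a clustering on synthetic data.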
Clustering with Gaussian Mixture
Next, we'll define the Gaussian Mixture model and fit it to the x data. Here, we'll divide the data into 5 clusters, so we set the target number of clusters with the n_components parameter. You can also change the other default parameters to suit your data and clustering approach.
gm = GaussianMixture(n_components=5).fit(x)
gm.get_params()
{'covariance_type': 'full',
'init_params': 'kmeans',
'max_iter': 100,
'means_init': None,
'n_components': 5,
'n_init': 1,
'precisions_init': None,
'random_state': None,
'reg_covar': 1e-06,
'tol': 0.001,
'verbose': 0,
'verbose_interval': 10,
'warm_start': False,
'weights_init': None}
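Among the parameters listed above, covariance_type is the one behind GMM's flexibility in cluster shape. The sketch below fits the same data with each of the four options and prints the shape of the fitted covariances_ array; the data seed, n_init=3 and random_state=0 are assumptions chosen only to make the run repeatable.

```python
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

x, _ = make_blobs(n_samples=330, centers=5, cluster_std=1.84, random_state=234)

# covariance_type controls how flexible each component's shape is:
#   'full'      -> every component has its own general covariance matrix
#   'tied'      -> all components share a single covariance matrix
#   'diag'      -> every component has a diagonal (axis-aligned) covariance
#   'spherical' -> every component has a single variance value
shapes = {}
for cov_type in ["full", "tied", "diag", "spherical"]:
    gm = GaussianMixture(n_components=5, covariance_type=cov_type,
                         n_init=3, random_state=0).fit(x)
    shapes[cov_type] = gm.covariances_.shape
    print(cov_type, gm.covariances_.shape)
```

'full' is the most flexible and has the most parameters; 'spherical' behaves more like K-Means-style round clusters.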
After fitting the model, we can obtain the center of each cluster from the means_ attribute.
centers = gm.means_
print(centers)
[[-5.55710852 3.87061249]
[ 8.08308692 9.17642055]
[-9.18419799 -4.47855075]
[-0.89184344 0.17602145]
[ 7.31671999 2.46693378]]
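Besides means_, the fitted model stores weights_, the mixing proportion of each component. A short sketch, assuming the same kind of data generated with random_state=234 and a model fitted with random_state=0, compares those weights with the fraction of points that predict assigns to each component; the two are usually close but not identical, because weights_ come from soft (probabilistic) assignments.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

x, _ = make_blobs(n_samples=330, centers=5, cluster_std=1.84, random_state=234)
gm = GaussianMixture(n_components=5, random_state=0).fit(x)

# means_: one center per component; weights_: estimated share of the data
print(gm.means_.shape)
print(gm.weights_.round(3))
print(gm.weights_.sum())  # the weights sum to 1 (up to floating-point rounding)

# Fraction of points hard-assigned to each component, for comparison
counts = np.bincount(gm.predict(x), minlength=5) / len(x)
print(counts.round(3))
```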
The centers can be visualized in a plot as shown below.
plt.figure(figsize=(8, 6))
plt.scatter(x[:,0], x[:,1], label="data")
plt.scatter(centers[:,0], centers[:,1],c='r', label="centers")
plt.legend()
plt.show()
Next, we use the trained model to predict the cluster label of each element in x. The code below groups the elements by label and plots each cluster in its own color.
pred = gm.predict(x)
df = DataFrame({'x':x[:,0], 'y':x[:,1], 'label':pred})
groups = df.groupby('label')
fig, ax = plt.subplots(figsize=(8, 6))
# draw each cluster separately so it gets its own color and legend entry
for name, group in groups:
    ax.scatter(group.x, group.y, label=name)
ax.legend()
plt.show()
The graph shows each cluster and the elements assigned to it.
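predict returns a single hard label per point, but a GMM also provides soft assignments through predict_proba. The sketch below, assuming data generated with random_state=234 and a model fitted with random_state=0, checks that the hard label is the most probable component, counts the elements per cluster with the same groupby used for plotting, and flags points near cluster boundaries. The 0.9 threshold is an arbitrary illustrative choice.

```python
import numpy as np
from pandas import DataFrame
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

x, _ = make_blobs(n_samples=330, centers=5, cluster_std=1.84, random_state=234)
gm = GaussianMixture(n_components=5, random_state=0).fit(x)

# Hard labels vs. per-component posterior probabilities
pred = gm.predict(x)
proba = gm.predict_proba(x)
print(proba.shape)                          # one row per point, one column per component
print(np.allclose(proba.sum(axis=1), 1.0))  # each row is a probability distribution
print((proba.argmax(axis=1) == pred).all()) # hard label = most probable component

# Number of elements per cluster, using the same grouping as the plot code
df = DataFrame({'x': x[:, 0], 'y': x[:, 1], 'label': pred})
sizes = df.groupby('label').size()
print(sizes)

# Points whose best probability is below an (assumed) 0.9 threshold
uncertain = proba.max(axis=1) < 0.9
print(uncertain.sum(), "points with max probability below 0.9")
```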
In the example below, we change the number of clusters and compare the resulting clusterings in a grid of plots.
f = plt.figure(figsize=(8, 6), dpi=80)
for i in range(2, 6):
    # fit a model with i components and label every point
    gm = GaussianMixture(n_components=i).fit(x)
    pred = gm.predict(x)
    df = DataFrame({'x':x[:,0], 'y':x[:,1], 'label':pred})
    groups = df.groupby('label')
    # draw this model in its own cell of the 2x2 grid
    f.add_subplot(2, 2, i-1)
    for name, group in groups:
        plt.scatter(group.x, group.y, label=name, s=8)
    plt.title("Number of clusters: " + str(i))
    plt.legend()
plt.tight_layout()
plt.show()
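Comparing the plots by eye works for small experiments; a common quantitative alternative is the Bayesian Information Criterion, available through GaussianMixture.bic, where a lower value indicates a better trade-off between fit and model complexity. This sketch assumes a candidate range of 1 to 8 components, n_init=3 and random_state=0; with overlapping blobs the minimum will not necessarily land exactly on 5.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

x, _ = make_blobs(n_samples=330, centers=5, cluster_std=1.84, random_state=234)

# Fit one model per candidate number of clusters and record its BIC
candidates = range(1, 9)
bics = []
for k in candidates:
    gm = GaussianMixture(n_components=k, n_init=3, random_state=0).fit(x)
    bics.append(gm.bic(x))

# Pick the candidate with the lowest BIC
best_k = candidates[int(np.argmin(bics))]
for k, b in zip(candidates, bics):
    print(k, round(b, 1))
print("lowest BIC at n_components =", best_k)
```

gm.aic(x) can be swapped in the same way; AIC penalizes extra components less strongly than BIC.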
In this tutorial, we've briefly learned how to cluster data with the Gaussian Mixture model in Python. The source code is listed below.
Source code listing
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
from numpy import random
from pandas import DataFrame
random.seed(234)
x, _ = make_blobs(n_samples=330, centers=5, cluster_std=1.84)
plt.figure(figsize=(8, 6))
plt.scatter(x[:,0], x[:,1])
plt.show()
gm = GaussianMixture(n_components=5).fit(x)
centers = gm.means_
print(centers)
plt.figure(figsize=(8, 6))
plt.scatter(x[:,0], x[:,1], label="data")
plt.scatter(centers[:,0], centers[:,1],c='r', label="centers")
plt.legend()
plt.show()
pred = gm.predict(x)
df = DataFrame({'x':x[:,0], 'y':x[:,1], 'label':pred})
groups = df.groupby('label')
fig, ax = plt.subplots(figsize=(8, 6))
for name, group in groups:
    ax.scatter(group.x, group.y, label=name)
ax.legend()
plt.show()
f = plt.figure(figsize=(8, 6), dpi=80)
for i in range(2, 6):
    gm = GaussianMixture(n_components=i).fit(x)
    pred = gm.predict(x)
    df = DataFrame({'x':x[:,0], 'y':x[:,1], 'label':pred})
    groups = df.groupby('label')
    f.add_subplot(2, 2, i-1)
    for name, group in groups:
        plt.scatter(group.x, group.y, label=name, s=8)
    plt.title("Number of clusters: " + str(i))
    plt.legend()
plt.tight_layout()
plt.show()