8  Clustering

Tip: Learning Objectives
  • By the end of this lesson, students will be able to:
  1. Explain the goal and applications of clustering
    • Define unsupervised learning and distinguish clustering from classification.
    • Cite real‑world uses (e.g., customer segmentation, anomaly detection).
  2. Describe and execute hierarchical clustering methods
    • Compute and interpret a dendrogram.
    • Apply various linkage criteria (single, complete, average, Ward’s, etc.) and distance metrics (Euclidean, Manhattan, cosine).
  3. Implement and analyze K‑Means clustering
    • Articulate the iterative assignment–update steps and convergence conditions.
    • Write or follow pseudocode for K‑Means, including centroid initialization strategies.
    • Compute and interpret the within‑cluster variation objective.
  4. Compare clustering techniques and select appropriately
    • Identify strengths and weaknesses of hierarchical vs. K‑Means approaches.
    • Choose between methods based on data characteristics (number of clusters, scalability, hierarchy needs).

8.1 Warm-up Puzzle

  • Is the picture below fake or real?

  • How can a computer determine if the picture below is fake or real?

Taj Mahal bathed in the Northern Lights. Generated using the DALL-E tool.

8.2 Clustering Overview

Clustering is an unsupervised learning technique used to group similar data points together. Unlike classification, there are no pre-defined labels. Instead, the algorithm tries to discover structure in the data by maximizing intra-cluster similarity and minimizing inter-cluster similarity.

Key points:

  • Objective: Identify natural groupings in the data.

  • Applications: Customer segmentation, image compression, anomaly detection, document clustering.

8.3 Hierarchical Clustering

Hierarchical clustering builds a tree (dendrogram) of clusters using either a bottom‑up (agglomerative) or top‑down (divisive) approach.

8.3.1 Agglomerative (Bottom‑Up)

  1. Initialization: Start with each data point as its own cluster.

  2. Merge Steps:

    • Compute distance between every pair of clusters.
    • Merge the two closest clusters.
    • Update the distance matrix.
  3. Termination: Repeat until all points are in a single cluster or a stopping criterion (e.g., desired number of clusters) is met.
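
To make the merge loop concrete, here is a minimal from-scratch sketch (illustrative only: it uses single linkage with Euclidean distances, and the function name is my own; in practice you would use the scikit-learn and scipy implementations shown later in this chapter):

import numpy as np

def agglomerative_single_linkage(X, n_clusters=2):
    """Toy agglomerative clustering: single linkage, Euclidean distance."""
    # Initialization: start with each point as its own cluster (list of index lists)
    clusters = [[i] for i in range(len(X))]

    # Pairwise Euclidean distances between all points
    dist = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)

    while len(clusters) > n_clusters:
        # Find the two closest clusters (single linkage: minimum pairwise distance)
        best = (None, None, np.inf)
        for a in range(len(clusters)):
            for b in range(a + 1, len(clusters)):
                d = dist[np.ix_(clusters[a], clusters[b])].min()
                if d < best[2]:
                    best = (a, b, d)
        a, b, _ = best
        # Merge the two closest clusters and continue
        clusters[a] = clusters[a] + clusters[b]
        del clusters[b]
    return clusters

# Tiny example: two obvious groups
X_toy = np.array([[0.0, 0.0], [0.1, 0.2], [5.0, 5.0], [5.2, 4.9]])
print(agglomerative_single_linkage(X_toy, n_clusters=2))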

8.3.2 Dendrogram

        [ALL POINTS]
         /      \
    Cluster A   Cluster B
     /    \       /    \
    …      …     …      …
  • Cutting the tree at different levels yields different numbers of clusters (see the sketch after this list).
  • Linkage methods determine how distance between clusters is computed:
    • Single linkage: Minimum pairwise distance
    • Complete linkage: Maximum pairwise distance
    • Average linkage: Average pairwise distance
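
For example, scipy's fcluster function can cut a linkage matrix at a chosen number of clusters. A minimal sketch (using average linkage on the iris data, which also appears in the practical below):

from scipy.cluster.hierarchy import linkage, fcluster
from sklearn.datasets import load_iris

X = load_iris().data
Z = linkage(X, method='average')   # build the tree once

# Cut the same tree at different levels to get 2, 3 and 4 clusters
for k in [2, 3, 4]:
    labels = fcluster(Z, t=k, criterion='maxclust')
    print(k, 'clusters ->', sorted(set(labels)))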

8.3.3 Important Concepts

Tip
  • Metric
    The metric (or distance function or dissimilarity function) defines how you measure the distance between individual data points. Common choices include Euclidean, Manhattan (cityblock), or cosine distance. This metric determines the raw pairwise distances.

Manhattan distance. Image created using DALL-E.

Euclidean distance: how would a crow navigate in Manhattan? (I have never been to Manhattan, but the internet says there are crows in Manhattan, so it must be true.)

A crow in Manhattan. Image created using DreamUp.
  • Linkage
    The linkage method defines how to compute the distance between two clusters based on the pairwise distances of their members. Examples:
    • Single: the distance between the closest pair of points (one from each cluster).
    • Complete: the distance between the farthest pair of points.
    • Average: the average of all pairwise distances.
    • Ward: the merge that minimizes the increase in total within‑cluster variance.

Linkage function

Linkage Method | How It Works | Intuition
Single | Distance = minimum pairwise distance between points in the two clusters | “Friends‑of‑friends” – clusters join if any two points are close, yielding chain‑like clusters
Complete | Distance = maximum pairwise distance between points in the two clusters | “Everyone must be close” – only merge when all points are relatively near, producing compact clusters
Average (UPGMA) | Distance = average of all pairwise distances between points in the two clusters | Balances single and complete by averaging close and far pairs
Weighted (WPGMA) | Distance = average of the previous cluster’s distance to the new cluster (equal weight per cluster) | Prevents large clusters from dominating, giving equal say to each cluster
Centroid | Distance = distance between the centroids (mean vectors) of the two clusters | Merges based on “centers of mass,” but centroids can shift non‑monotonically
Median (WPGMC) | Distance = distance between the medians of the two clusters | More robust to outliers than centroid linkage, but can also invert dendrogram order
Ward’s | Merge that minimizes the increase in total within‑cluster sum of squares (variance) | Keeps clusters as tight and homogeneous as possible, often resulting in evenly sized groups

8.3.4 Single Linkage

  • How it works: Measures the distance between two clusters as the smallest distance between any single point in one cluster and any single point in the other.
  • Intuition: “Friends‑of‑friends” clustering—if any two points (one from each cluster) are close, the clusters join. Can produce long, straggly chains of points.

8.3.5 Complete Linkage

  • How it works: Measures the distance between two clusters as the largest distance between any point in one cluster and any point in the other.
  • Intuition: “Everyone must be close”—clusters merge only when all their points are relatively near each other, leading to tight, compact groups.

8.3.6 Average Linkage (UPGMA)

  • How it works: Takes the average of all pairwise distances between points in the two clusters.
  • Intuition: A middle‑ground between single and complete linkage—balances the effect of very close and very far pairs by averaging them.

8.3.7 Weighted Linkage (WPGMA)

  • How it works: Similar to average linkage, but treats each cluster as a single entity by averaging the distance from each original cluster to the target cluster, regardless of cluster size.
  • Intuition: Prevents larger clusters from dominating the average—gives each cluster equal say in how far apart they are.

8.3.8 Centroid Linkage

  • How it works: Computes the distance between the centroids (mean vectors) of the two clusters.
  • Intuition: Clusters merge based on whether their “centers of mass” are close. Can sometimes lead to non‑monotonic merges if centroids shift oddly.

8.3.9 Median Linkage (WPGMC)

  • How it works: Uses the median point of each cluster instead of the mean when computing distance between clusters.
  • Intuition: Like centroid linkage but more robust to outliers, since the median isn’t pulled by extreme values—though can also cause inversion issues.

8.3.10 Ward’s Method

  • How it works: At each step, merges the two clusters whose union leads to the smallest possible increase in total within‑cluster variance (sum of squared deviations).
  • Intuition: Always chooses the merge that keeps clusters as tight and homogeneous as possible, often yielding groups of similar size and shape.
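
To see how the choice of linkage changes the tree, you can pass a different method= argument to scipy's linkage and compare the resulting dendrograms. A small sketch (the side-by-side subplot layout is just one convenient way to show them):

import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import linkage, dendrogram
from sklearn.datasets import load_iris

X = load_iris().data
methods = ['single', 'complete', 'average', 'ward']

fig, axes = plt.subplots(1, len(methods), figsize=(16, 4))
for ax, method in zip(axes, methods):
    Z = linkage(X, method=method)         # same data, different linkage
    dendrogram(Z, ax=ax, no_labels=True)  # suppress sample labels for readability
    ax.set_title(method)
plt.tight_layout()
plt.show()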
Tip: Concept about distances

There is no single “best” distance metric for clustering—what works well for one dataset or problem may not work for another. The choice of distance metric (such as Euclidean or Manhattan) depends on the nature of your data and what you want to capture about similarity.

For example, Euclidean distance works well when the scale of features is meaningful and differences are linear, while cosine distance is better for text data or situations where the direction of the data matters more than its magnitude.

It is important to experiment with different distance metrics and see which one produces clusters that make sense for your specific problem. Always check the results and, if possible, use domain knowledge to guide your choice.
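
As a quick illustration of how the metric changes the raw pairwise distances, scipy's pdist can compare the same pair of points under different metrics:

import numpy as np
from scipy.spatial.distance import pdist

points = np.array([[1.0, 1.0],
                   [4.0, 5.0]])

print(pdist(points, metric='euclidean'))   # [5.0]  straight-line distance
print(pdist(points, metric='cityblock'))   # [7.0]  Manhattan distance: |4-1| + |5-1|
print(pdist(points, metric='cosine'))      # small value: both vectors point in a similar direction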

8.4 Practical

  • n_clusters in AgglomerativeClustering specifies the number of clusters you want the algorithm to find. After building the hierarchical tree, the algorithm will cut the tree so that exactly n_clusters groups are formed. For example, n_clusters=3 will result in 3 clusters in your data.

  • The default value for n_clusters in AgglomerativeClustering is 2. The default value for linkage is ward. So if you do not specify these parameters, the algorithm will produce 2 clusters using Ward linkage.

import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from scipy.cluster.hierarchy import linkage, dendrogram
from sklearn.cluster import AgglomerativeClustering
from sklearn.preprocessing import StandardScaler

# Load data
iris = load_iris()
X = iris.data

# Fit Agglomerative Clustering
agg = AgglomerativeClustering(n_clusters=3, linkage='ward')
labels = agg.fit_predict(X)

print(labels)  # Cluster assignments for each sample

from scipy.cluster.hierarchy import dendrogram, linkage
import matplotlib.pyplot as plt

# Compute linkage matrix
Z = linkage(X, method='ward')

# Plot dendrogram
plt.figure(figsize=(10, 5))
dendrogram(Z)
plt.title('Hierarchical Clustering Dendrogram (Iris Data)')
plt.xlabel('Sample Index')
plt.ylabel('Distance')
plt.show()
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 2 2 2 2 0 2 2 2 2
 2 2 0 0 2 2 2 2 0 2 0 2 0 2 2 0 0 2 2 2 2 2 0 0 2 2 2 0 2 2 2 0 2 2 2 0 2
 2 0]

8.4.1 Scale the data

from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_iris
import pandas as pd

# Load data
iris = load_iris()
X = iris.data

# let us see what is in the data
df_iris = load_iris(as_frame=True)
data_frame_iris = df_iris.frame
print(data_frame_iris)
#print(df_iris.frame.head())

# scale the data
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
     sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  \
0                  5.1               3.5                1.4               0.2   
1                  4.9               3.0                1.4               0.2   
2                  4.7               3.2                1.3               0.2   
3                  4.6               3.1                1.5               0.2   
4                  5.0               3.6                1.4               0.2   
..                 ...               ...                ...               ...   
145                6.7               3.0                5.2               2.3   
146                6.3               2.5                5.0               1.9   
147                6.5               3.0                5.2               2.0   
148                6.2               3.4                5.4               2.3   
149                5.9               3.0                5.1               1.8   

     target  
0         0  
1         0  
2         0  
3         0  
4         0  
..      ...  
145       2  
146       2  
147       2  
148       2  
149       2  

[150 rows x 5 columns]

8.4.2 Perform hierarchical clustering

from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import linkage, dendrogram
import matplotlib.pyplot as plt

# Fit Agglomerative Clustering
agg = AgglomerativeClustering(n_clusters=3)  # linkage defaults to 'ward'
labels = agg.fit_predict(X_scaled)

print(labels)  # Cluster assignments for each sample

# Compute linkage matrix
Z = linkage(X_scaled)  # note: without method='ward', scipy's linkage defaults to single linkage

# Plot dendrogram
plt.figure()
dendrogram(Z)
plt.title('Hierarchical Clustering Dendrogram (Iris Data)')
plt.xlabel('Sample Index')
plt.ylabel('Distance')
plt.show()
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 2 1 1 1 1 1 1 1 1 0 0 0 2 0 2 0 2 0 2 2 0 2 0 2 0 2 2 2 2 0 0 0 0
 0 0 0 0 0 2 2 2 2 0 2 0 0 2 2 2 2 0 2 2 2 2 2 0 2 2 0 0 0 0 0 0 2 0 0 0 0
 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0]

8.5 Alternative code using sns.clustermap

There are many ways to perform hierarchical clustering. An alternative is sns.clustermap from the seaborn package (see the seaborn documentation). The interface is similar and it can produce professional-quality plots.


import seaborn as sns

# Basic clustermap with scaling
sns.clustermap(X, cmap='RdBu_r', z_score=0, center=0)

# Different linkage methods
sns.clustermap(X, method='average', z_score=0, center = 0)

# Different distance metrics  
sns.clustermap(X, metric='correlation', method='average', z_score=0, center = 0)

# Comprehensive example
sns.clustermap(X, method='average', metric='correlation', 
               cmap='RdBu_r', z_score=0, center=0)

# or if you prefer just the default options
sns.clustermap(X, z_score=0, center=0)

z_score=0 (Scaling Direction)

  • What it does: Standardizes (z-scores) the data before clustering
  • Options:
    • 0: Scale rows (genes); each gene’s expression is standardized across samples
    • 1: Scale columns (samples); each sample’s expression is standardized across genes
    • None: No scaling (use raw data)
  • Why use 0: For gene expression, you want to compare expression patterns, not absolute levels

cmap='RdBu_r' (Color Map)

  • What it does: Defines the color scheme for the heatmap
  • 'RdBu_r': Red-Blue reversed (red = high, blue = low, white = middle)
  • Other options: 'viridis', 'coolwarm', 'seismic', 'plasma', etc.
  • Why use it: Intuitive for biologists (red = high expression, blue = low expression)

center=0 (Color Center)

  • What it does: Centers the color map at this value
  • 0: White represents zero (after scaling, this is the mean)
  • Other values: Could center at 1 (for fold-change), or other biologically meaningful values
  • Why use it: Makes it easy to see above/below average expression

Additional Common Parameters

row_cluster=True/False

  • What it does: Whether to cluster rows (genes)
  • Default: True

col_cluster=True/False

  • What it does: Whether to cluster columns (samples)
  • Default: True

cbar_kws (Color Bar Keywords)

  • What it does: Customize the color bar

  • Example: cbar_kws={'label': 'Expression Level', 'shrink': 0.8}

  • Here is some code to use sns.clustermap

from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import linkage, dendrogram
import matplotlib.pyplot as plt
import seaborn as sns

sns.clustermap(data_frame_iris, z_score=0, center=0)
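
As a small sketch of the additional parameters described above (the color bar label is just an example string):

import seaborn as sns

# Cluster the rows only, keep the column order fixed, and label the color bar
sns.clustermap(data_frame_iris,
               z_score=0, center=0,
               row_cluster=True, col_cluster=False,
               cbar_kws={'label': 'Standardized value'})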

8.6 Exercise (changing the linkage function)

  • Work in a group for this exercise

  • Let us try another linkage function

  • Change the linkage function in Z = linkage(X_scaled, method='ward')

  • How does the clustering change as you change this to another function?

  • How does this change if you do not scale the data?

# Fit Agglomerative Clustering on unscaled data with 'average' linkage
agg = AgglomerativeClustering(linkage='average')
labels = agg.fit_predict(X)

# Compute linkage matrix on unscaled data with 'average' linkage
Z = linkage(X, method='average')

# Plot dendrogram
plt.figure()
dendrogram(Z)
plt.title('Hierarchical Clustering Dendrogram (Iris Data, Unscaled, Average Linkage)')
plt.xlabel('Sample Index')
plt.ylabel('Distance')
plt.show()

8.7 Exercise (trying a different dissimilarity metric)

# Fit Agglomerative Clustering on unscaled data with 'average' linkage and 'manhattan' distance
agg = AgglomerativeClustering(
    linkage='average',
    metric='manhattan'      # use metric= instead of the deprecated affinity= parameter
)
labels = agg.fit_predict(X)


# Compute linkage matrix on unscaled data with 'average' linkage and 'cityblock' (manhattan) distance
Z = linkage(X, method='average', metric='cityblock')

# Plot dendrogram
plt.figure()
dendrogram(Z)
plt.title('Hierarchical Clustering Dendrogram (Iris Data, Unscaled, Average Linkage, Manhattan Distance)')
plt.xlabel('Sample Index')
plt.ylabel('Distance')
plt.show()

8.8 Exercise with missing data

Real-world data frequently has missing values. PCA and clustering techniques can struggle with missing data.

In this exercise, you will work in a group and apply hierarchical clustering, PCA and tSNE on data which has missing values.

Run the code below. All the missing data will be available in the variable missing_data.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.impute import SimpleImputer, KNNImputer
from sklearn.preprocessing import StandardScaler
from scipy.cluster.hierarchy import dendrogram, linkage
import warnings
warnings.filterwarnings('ignore')

# Set random seed for reproducibility
np.random.seed(42)

def create_synthetic_biological_data():
    """
    Create synthetic gene expression data with biological structure.

    Returns:
    - complete_data: Original data without missing values
    - true_labels: True cluster labels for evaluation
    """
    #print("Creating synthetic biological dataset...")
    
    # Parameters
    n_samples = 100
    n_genes = 50
    n_clusters = 4
    
    # Create base data structure
    data = np.random.normal(0, 1, (n_samples, n_genes))
    
    # Add biological structure (clusters)
    cluster_size = n_samples // n_clusters
    
    # Cluster 1: High expression in genes 0-12, samples 0-24
    data[0:25, 0:13] += 2.5
    # Cluster 2: High expression in genes 13-25, samples 25-49  
    data[25:50, 13:26] += 2.0
    # Cluster 3: High expression in genes 26-37, samples 50-74
    data[50:75, 26:38] += 1.8
    # Cluster 4: Low expression in genes 38-49, samples 75-99
    data[75:100, 38:50] -= 2.2
    
    # Add some noise
    data += np.random.normal(0, 0.5, data.shape)
    
    # Create sample and gene names
    sample_names = [f'Sample_{i:03d}' for i in range(n_samples)]
    gene_names = [f'Gene_{chr(65+i//26)}{chr(65+i%26)}' for i in range(n_genes)]
    
    # Create DataFrame
    complete_data = pd.DataFrame(data, index=sample_names, columns=gene_names)
    
    # Create true cluster labels
    true_labels = np.repeat(range(n_clusters), cluster_size)
    if len(true_labels) < n_samples:
        true_labels = np.append(true_labels, [n_clusters-1] * (n_samples - len(true_labels)))
    
    #print(f"Created dataset: {complete_data.shape[0]} samples × {complete_data.shape[1]} genes")
    #print(f"True clusters: {n_clusters}")
    
    return complete_data, true_labels

def introduce_missing_data_patterns(complete_data, true_labels):
    """
    Introduce different types of missing data patterns.
    
    Parameters:
    - complete_data: Original complete dataset
    - true_labels: True cluster labels
    
    Returns:
    - missing_data: Dataset with missing values
    - missing_info: Information about missingness patterns
    """
    #print("\nIntroducing missing data patterns...")
    
    missing_data = complete_data.copy()
    missing_info = {}
    
    # Pattern 1: Missing Completely At Random (MCAR) - 5% random missing
    #print("1. Adding MCAR missingness (5% random)...")
    mcar_mask = np.random.random(missing_data.shape) < 0.05
    missing_data[mcar_mask] = np.nan
    missing_info['MCAR'] = mcar_mask.sum()
    
    # Pattern 2: Missing At Random (MAR) - correlated with expression level
    #print("2. Adding MAR missingness (correlated with high expression)...")
    # Higher chance of missing for high expression values
    high_expr_mask = missing_data > missing_data.quantile(0.8)
    mar_probability = np.where(high_expr_mask, 0.15, 0.02)  # 15% for high, 2% for low
    mar_mask = np.random.random(missing_data.shape) < mar_probability
    missing_data[mar_mask] = np.nan
    missing_info['MAR'] = mar_mask.sum()
    
    # Pattern 3: Missing Not At Random (MNAR) - systematic missing
    #print("3. Adding MNAR missingness (systematic missing)...")
    # Missing entire samples (simulating failed experiments)
    failed_samples = np.random.choice(missing_data.index, size=8, replace=False)
    missing_data.loc[failed_samples, :] = np.nan
    missing_info['MNAR_samples'] = len(failed_samples)
    
    # Missing entire genes (simulating detection failures)
    failed_genes = np.random.choice(missing_data.columns, size=5, replace=False)
    missing_data.loc[:, failed_genes] = np.nan
    missing_info['MNAR_genes'] = len(failed_genes)
    
    # Pattern 4: Block missingness (simulating batch effects)
    #print("4. Adding block missingness (batch effects)...")
    # Missing blocks of data (simulating different experimental conditions)
    block_start_row = 20
    block_end_row = 35
    block_start_col = 10
    block_end_col = 20
    missing_data.iloc[block_start_row:block_end_row, block_start_col:block_end_col] = np.nan
    missing_info['Block'] = (block_end_row - block_start_row) * (block_end_col - block_start_col)
    
    # Calculate total missingness
    total_missing = missing_data.isnull().sum().sum()
    total_values = missing_data.size
    missing_percentage = (total_missing / total_values) * 100
    
    print(f"\nMissing data summary:")
    print(f"Total missing values: {total_missing}")
    print(f"Missing percentage: {missing_percentage:.1f}%")
    #print(f"MCAR: {missing_info['MCAR']} values")
    #print(f"MAR: {missing_info['MAR']} values") 
    #print(f"MNAR samples: {missing_info['MNAR_samples']} samples")
    #print(f"MNAR genes: {missing_info['MNAR_genes']} genes")
    #print(f"Block missing: {missing_info['Block']} values")
    
    return missing_data, missing_info

def visualize_missing_patterns(complete_data, missing_data, true_labels):
    """
    Visualize the missing data patterns.
    """
    print("\nCreating missing data visualizations...")
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    # Plot 1: Complete data heatmap
    plt.subplot(2, 2, 1)
    sns.heatmap(complete_data.iloc[:50, :30], cmap='RdBu_r', center=0, 
                cbar_kws={'label': 'Expression Level'})
    plt.title('Complete Data (First 50 samples, 30 genes)')
    plt.xlabel('Genes')
    plt.ylabel('Samples')
    
    # Plot 2: Missing data pattern
    plt.subplot(2, 2, 2)
    missing_mask = missing_data.iloc[:50, :30].isnull()
    sns.heatmap(missing_mask, cmap='Reds', cbar_kws={'label': 'Missing (1=Yes, 0=No)'})
    plt.title('Missing Data Pattern')
    plt.xlabel('Genes')
    plt.ylabel('Samples')
    
    # Plot 3: Missing data by sample
    plt.subplot(2, 2, 3)
    missing_by_sample = missing_data.isnull().sum(axis=1)
    plt.bar(range(len(missing_by_sample)), missing_by_sample)
    plt.title('Missing Values per Sample')
    plt.xlabel('Sample Index')
    plt.ylabel('Number of Missing Values')
    
    # Plot 4: Missing data by gene
    plt.subplot(2, 2, 4)
    missing_by_gene = missing_data.isnull().sum(axis=0)
    plt.bar(range(len(missing_by_gene)), missing_by_gene)
    plt.title('Missing Values per Gene')
    plt.xlabel('Gene Index')
    plt.ylabel('Number of Missing Values')
    
    plt.tight_layout()
    plt.show()


# Create synthetic data
complete_data, true_labels = create_synthetic_biological_data()
    
# Introduce missing data patterns
missing_data, missing_info = introduce_missing_data_patterns(complete_data, true_labels)
    
# Visualize missing patterns
#visualize_missing_patterns(complete_data, missing_data, true_labels)

# All the missing data is available in the variable missing_data
# fill in your code here ...

Missing data summary:
Total missing values: 1346
Missing percentage: 26.9%
  • All the missing data is available in the variable missing_data. Now perform hierarchical clustering, PCA and tSNE on this data.
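
One possible starting point is sketched below. It assumes you first impute the missing values (here with the KNNImputer imported above); other imputation strategies are equally valid and worth comparing.

from sklearn.impute import KNNImputer
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy.cluster.hierarchy import linkage, dendrogram
import matplotlib.pyplot as plt

# Drop rows/columns that are entirely missing, then impute the rest
data = missing_data.dropna(axis=0, how='all').dropna(axis=1, how='all')
X_imputed = KNNImputer(n_neighbors=5).fit_transform(data)
X_scaled = StandardScaler().fit_transform(X_imputed)

# Hierarchical clustering on the imputed, scaled data
Z = linkage(X_scaled, method='ward')
plt.figure(); dendrogram(Z, no_labels=True); plt.show()

# PCA and tSNE embeddings of the same data
X_pca = PCA(n_components=2).fit_transform(X_scaled)
X_tsne = TSNE(n_components=2, random_state=42).fit_transform(X_scaled)
plt.figure(); plt.scatter(X_pca[:, 0], X_pca[:, 1]); plt.title('PCA'); plt.show()
plt.figure(); plt.scatter(X_tsne[:, 0], X_tsne[:, 1]); plt.title('tSNE'); plt.show()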

8.9 K‑Means Clustering

K‑Means is a partitional clustering algorithm that aims to partition the data into K disjoint clusters.

8.9.1 Algorithm Steps

  1. Choose K, the number of clusters.

  2. Initialization: Randomly select K initial centroids (or use k‑means++ for better seeding).

  3. Assignment Step:

    for each data point x_i:
        assign x_i to cluster j whose centroid μ_j is nearest (minimize ||x_i - μ_j||²)
  4. Update Step:

    for each cluster j:
        μ_j = (1 / |C_j|) * sum_{x_i in C_j} x_i
  5. Convergence Check:

    • Stop when assignments no longer change, OR
    • The change in centroids is below a threshold, OR
    • A maximum number of iterations is reached.
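
The assignment and update steps translate directly into a few lines of NumPy. Below is a minimal from-scratch sketch for illustration (random initialization rather than k-means++, and empty clusters are not handled); in practice you would use sklearn.cluster.KMeans as in the exercises that follow.

import numpy as np
from sklearn.datasets import load_iris

def kmeans(X, k, max_iter=100, seed=0):
    """Minimal K-Means: random initialization, assignment step, update step."""
    rng = np.random.default_rng(seed)
    # Initialization: pick k distinct data points as the starting centroids
    centroids = X[rng.choice(len(X), size=k, replace=False)]

    for _ in range(max_iter):
        # Assignment step: assign each point to its nearest centroid
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=-1)
        labels = dists.argmin(axis=1)

        # Update step: move each centroid to the mean of its assigned points
        # (note: empty clusters are not handled in this sketch)
        new_centroids = np.array([X[labels == j].mean(axis=0) for j in range(k)])

        # Convergence check: stop when the centroids no longer move
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return labels, centroids

# Quick check on the iris data
labels, centroids = kmeans(load_iris().data, k=3)
print(np.bincount(labels))   # number of points assigned to each cluster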

8.9.2 Animation

Animation of k-means

8.9.3 Within–Cluster Variation

In \(K\)‑means clustering, we partition our \(n\) observations into \(K\) disjoint clusters \(\{C_1, C_2, \dots, C_K\}\). A “good” clustering is one for which the within‑cluster variation is minimized.
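
One common way to define the within-cluster variation uses the average pairwise squared Euclidean distance within each cluster:

\[
W(C_k) = \frac{1}{|C_k|} \sum_{i, i' \in C_k} \lVert x_i - x_{i'} \rVert^2,
\qquad
\min_{C_1,\dots,C_K} \sum_{k=1}^{K} W(C_k).
\]

Up to a factor of two, this is the same as summing the squared distance of each observation to its cluster centroid, which is exactly the within-cluster sum of squares (WCSS) used in the elbow method below.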

8.9.4 Elbow point

When using k‑means clustering, one of the key questions is: how many clusters (k) should I choose? The elbow method is a simple, visual way to pick a reasonable k by looking at how the “within‑cluster” variation decreases as k increases.


1. The Within‑Cluster Sum of Squares (WCSS)

For each choice of k, you run k‑means and compute the within‑cluster sum of squares (WCSS), also called inertia or distortion. This is the sum of squared Euclidean distances between each point and the centroid of its cluster:

\[
\mathrm{WCSS} = \sum_{i=1}^{K} \sum_{x \in C_i} \lVert x - \mu_i \rVert^2
\]

  • \(C_i\) is cluster \(i\)
  • \(\mu_i\) is the centroid of cluster \(i\)

As k increases, WCSS will always decrease (or stay the same), because more centroids can only reduce distances.
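
To connect the formula to code, here is a small sketch that computes the WCSS by hand for a fitted model and compares it with the inertia_ attribute reported by scikit-learn:

import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

X = load_iris().data
kmeans = KMeans(n_clusters=3, random_state=2, n_init=10).fit(X)

# WCSS by hand: squared distance from each point to the centroid of its assigned cluster
assigned_centroids = kmeans.cluster_centers_[kmeans.labels_]
wcss_by_hand = np.sum((X - assigned_centroids) ** 2)

print(wcss_by_hand, kmeans.inertia_)   # the two values should agree (up to rounding)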

2. Plotting WCSS versus k

  1. Choose a range for k (e.g. 1 to 10).
  2. For each k, fit k‑means and record WCSS(k).
  3. Plot WCSS(k) on the y-axis against k on the x-axis.

You will get a curve that starts high at k = 1 and steadily goes down as k increases.


3. Identifying the “Elbow”

  • At first, adding clusters dramatically reduces WCSS, because you are splitting large, heterogeneous clusters into more homogeneous groups.
  • After some point, adding more clusters yields diminishing returns—each new cluster only slightly reduces WCSS.

The elbow point is the value of k at which the decrease in WCSS “bends” most sharply: like an elbow in your arm. It balances model complexity (more clusters) against improved fit (lower WCSS).

An elbow point
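
If you want a rough numerical check on where the curve bends, one simple heuristic (a sketch, not a standard routine) is to look at the second differences of the WCSS values; the k with the largest second difference is often close to the visual elbow:

import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

X = load_iris().data
k_values = range(1, 11)
wcss = [KMeans(n_clusters=k, random_state=2, n_init=10).fit(X).inertia_ for k in k_values]

# Second differences measure how sharply the rate of decrease changes at each k
second_diff = np.diff(wcss, n=2)
elbow_k = k_values[int(np.argmax(second_diff)) + 1]   # +1 because each diff shortens the array
print('Suggested elbow at k =', elbow_k)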

8.9.5 Exercise (k-means)

  • NOTE: The c parameter in plt.scatter() is used to specify the color of the scatter plot points.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

# Load data
iris = load_iris()
X = iris.data

# Perform k means with k = 2
kmeans = KMeans(n_clusters=2, random_state=2, n_init=20)
kmeans.fit(X)

# The cluster assignments of the observations are contained in kmeans.labels_
kmeans.labels_

# Plot the data, with each observation colored according to its cluster assignment.
plt.figure()
plt.scatter(X[:,0], X[:,1], c=kmeans.labels_)
plt.title("K-means on Iris data")
plt.show()

8.9.6 Exercise (Evaluate k-means clusters using the within-cluster similarity)

  • Find the optimal value of the number of clusters \(K\)
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

# Load data
iris = load_iris()
X = iris.data

# Plot WCSS (inertia) as a function of the number of clusters
wcss = []

# try a few different values of k
k_range = range(1,11)

for k_var in k_range:

   # fit kmeans
   kmeans = KMeans(n_clusters=k_var, random_state=2)
   kmeans.fit(X)

   # append WCSS to a list
   wcss.append(kmeans.inertia_)


# plot
plt.figure()
plt.scatter( k_range , wcss )
plt.xlabel('Number of clusters')
plt.ylabel('Within-Cluster Sum of Squares (WCSS)')
plt.title('Elbow Method for Optimal k')
plt.show()

8.10 Choosing Between Methods

8.10.1 Hierarchical Clustering

  • No need to pre-specify the number of clusters (you can decide by cutting the dendrogram).
  • Produces a full hierarchy of clusters.
  • Scales poorly to very large datasets, since pairwise distances must be computed and stored.

8.10.2 K‑Means

  • Requires pre-specifying \(K\).
  • Scales well to large datasets and is fast in practice.
  • Results depend on centroid initialization and are best suited to roughly spherical, similarly sized clusters.

8.11 Summary

Tip: Key Points
  • Hierarchical clustering builds a tree-like structure (dendrogram) to group similar data points
    • Distance Metrics: How we measure similarity between points:
    • Euclidean distance (straight-line distance)
    • Manhattan distance (city-block distance)
    • Cosine distance (for directional data)
  • Linkage Methods
    • How we measure distance between clusters:
    • Single: Uses closest pair of points (can create chains)
    • Complete: Uses farthest pair of points (creates compact clusters)
    • Average: Uses average of all pairwise distances (balanced approach)
    • Ward’s: Minimizes increase in variance (creates similar-sized clusters)
  • Dendrogram
    • Visual representation showing how clusters merge:
    • Height shows distance when clusters merged
    • Cutting at different heights gives different numbers of clusters
  • Key Code Patterns:
# Clustered heatmap with seaborn
import seaborn as sns
sns.clustermap(X, method='average', metric='correlation', z_score=0, center=0)

# Basic hierarchical clustering
from sklearn.cluster import AgglomerativeClustering
agg = AgglomerativeClustering(n_clusters=3, linkage='ward')
labels = agg.fit_predict(X)

# Create dendrogram
from scipy.cluster.hierarchy import dendrogram, linkage
Z = linkage(X, method='ward')
dendrogram(Z)
  • Important Takeaways

    • No “Correct” Answer: unsupervised learning requires interpretation and domain knowledge
    • Multiple Methods: try different linkage methods and distance metrics
    • Evaluation is Key: use both internal (silhouette) and external (ARI) metrics. See the next chapter for more on evaluation.

8.12 Further Reading