{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Clustering\n", "\n", "Cluster analysis, or clustering, is the task of grouping a set of objects so that objects in the same group (called a cluster) are more similar (in some sense) to each other than to those in other groups (clusters).\n", "\n", "Here we will consider K-means clustering, which partitions the observations into k clusters. Each cluster is represented by its centroid, and each observation belongs to the cluster whose centroid is nearest to it.\n", "\n", "For this problem we will work with the handwritten digits dataset that ships with scikit-learn (1,797 images of 8x8 pixels, so 64 features per image)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load the digits dataset; each row of digits.data is a flattened 8x8 image\n", "from sklearn.datasets import load_digits\n", "digits = load_digits()\n", "digits.data.shape\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Fit K-means with 10 clusters and predict a cluster index for every image\n", "from sklearn.cluster import KMeans\n", "\n", "kmeans = KMeans(n_clusters=10, random_state=0)\n", "clusters = kmeans.fit_predict(digits.data)\n", "kmeans.cluster_centers_.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Visualize the 10 cluster centers as 8x8 images\n", "import matplotlib.pyplot as plt\n", "fig, ax = plt.subplots(2, 5, figsize=(8, 3))\n", "centers = kmeans.cluster_centers_.reshape(10, 8, 8)\n", "for axi, center in zip(ax.flat, centers):\n", "    axi.set(xticks=[], yticks=[])\n", "    axi.imshow(center, interpolation='nearest', cmap=plt.cm.binary)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Cluster indices are arbitrary, so relabel each cluster with the\n", "# most common true digit among its members\n", "import numpy as np\n", "from scipy.stats import mode\n", "\n", "labels = np.zeros_like(clusters)\n", "for i in range(10):\n", "    mask = (clusters == i)\n", "    labels[mask] = mode(digits.target[mask])[0]" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Print the accuracy score\n", "from sklearn.metrics import accuracy_score\n", "accuracy_score(digits.target, labels)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Visualize the confusion matrix\n", "from sklearn.metrics import confusion_matrix\n", "import seaborn as sns\n", "mat = confusion_matrix(digits.target, labels)\n", "sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False,\n", " xticklabels=digits.target_names,\n", " yticklabels=digits.target_names)\n", "plt.xlabel('true label')\n", "plt.ylabel('predicted label');" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 4 }