Nikhil Adithyan

Customer Segmentation with K-Means in Python

Updated: Apr 19, 2023

Learn to build and visualize K-means models to solve clustering problems



K-Means Clustering


K-Means clustering aims at partitioning ‘n’ observations into a specified number of ‘k’ clusters (it tends to produce compact, sphere-like clusters). K-Means is an unsupervised learning algorithm and one of the simplest algorithms used for clustering tasks. It divides the data into non-overlapping subsets without any cluster-internal structure: the observations within a cluster are very similar to each other, while the observations in different clusters vary considerably. K-Means also works well with medium and large-sized datasets.
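
As a minimal sketch of the idea (toy two-dimensional points and k = 2, not the customer data used later in this article), here is what partitioning observations into ‘k’ clusters looks like in code:


# A MINIMAL K-MEANS SKETCH ON TOY DATA (illustration only)

import numpy as np
from sklearn.cluster import KMeans

points = np.array([[1, 1], [1, 2], [2, 1],    # one tight group of points
                   [8, 8], [8, 9], [9, 8]])   # another tight group of points

toy_model = KMeans(n_clusters = 2, n_init = 10, random_state = 0).fit(points)
print(toy_model.labels_)           # each point gets exactly one cluster label, e.g. [1 1 1 0 0 0]
print(toy_model.cluster_centers_)  # one centroid per cluster

Each observation ends up in exactly one cluster, and each cluster is summarized by its centroid.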


Despite the algorithm’s simplicity, K-Means is still powerful for many clustering tasks in data science. In this article, we are going to tackle a clustering problem, customer segmentation (dividing customers into groups based on similar characteristics), using the K-Means algorithm. Now let’s take a brief look at the case we are going to solve.


Case


Imagine that you have a customer dataset, and you need to apply customer segmentation to this historical data. Customer segmentation is the practice of partitioning a customer base into groups of individuals that have similar characteristics. It is a significant strategy, as a business can target these specific groups of customers and allocate marketing resources effectively. For example, one group might contain customers who are high-profit and low-risk, that is, more likely to purchase products or subscribe to a service; a business task is to retain those customers. Another group might include customers from non-profit organizations. Now let’s use the K-Means algorithm to segment customers based on the characteristics provided in the data with Python.


Steps Involved

  1. Importing the required packages

  2. Importing the customer data into the python environment

  3. Analyzing the data to find some useful information

  4. Processing the data to our needs

  5. Building the model using the K-Means algorithm

  6. Analyzing and visualizing the built K-Means model

Without further ado, let’s dive into the coding part!


Importing the Packages


Every task must begin with importing the required packages into the respective environment (Python in our case). Our primary packages include pandas for working with the data, NumPy for working with arrays, matplotlib and seaborn for visualization, mplot3d for three-dimensional visualization, and finally scikit-learn for building the K-Means model. Let’s import all of them into our Python environment.


Python Implementation:


# IMPORTING PACKAGES

import pandas as pd # working with data
import numpy as np # working with arrays
import matplotlib.pyplot as plt # visualization
import seaborn as sb # visualization
from mpl_toolkits.mplot3d import Axes3D # 3d plot
from termcolor import colored as cl # text customization

from sklearn.preprocessing import StandardScaler # data normalization
from sklearn.cluster import KMeans # K-means algorithm

plt.rcParams['figure.figsize'] = (20, 10)
sb.set_style('whitegrid')

Now that we have imported all the required packages into our Python environment, let’s proceed to import the customer characteristics data.


Importing Data


About the data: The data we are going to use contains various characteristics and information about customers (find the data here). The attributes include ‘Customer Id’, ‘Age’, ‘Edu’ (education level), ‘Years Employed’, ‘Income’, ‘Card Debt’, ‘Other Debt’, ‘Defaulted’, and ‘DebtIncomeRatio’. Of these, ‘Edu’ and ‘Defaulted’ are categorical variables, and the rest are numeric.


We will use the ‘read_csv’ method provided by the pandas package to read and import the data into our Python environment. We are using ‘read_csv’ because the data is in ‘.csv’ format; if it were an Excel sheet, the ‘read_excel’ method would be the right choice. Now let’s import our data in Python!


Python Implementation:


# IMPORTING DATA

df = pd.read_csv('cust_seg.csv')
df.drop('Unnamed: 0', axis = 1, inplace = True)
df.set_index('Customer Id', inplace = True)

df.head()

Output:



Now that we have successfully imported our customer segmentation data into our Python environment, let’s explore the data and gain some information about it.


Data Analysis


Using the customer segmentation data, we are going to do some analysis and produce some visuals to extract useful information about the data.
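
Before plotting anything, a quick structural check (a small sketch, assuming the columns described in the previous section) can confirm the column types and spot missing values:


# QUICK STRUCTURAL OVERVIEW OF THE DATA (a sketch)

df.info()                                             # column names, dtypes, non-null counts
print(df['Defaulted'].value_counts(dropna = False))   # categorical 0/1 values plus any missing entries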


We will start by analyzing the age distribution among the customers. Since this is a distribution, a distribution plot lets us analyze it effectively. Let’s do it in Python!


Python Implementation:


# Age distribution

print(cl(df['Age'].describe(), attrs = ['bold']))

sb.histplot(df['Age'], 
            kde = True, 
            color = 'orange')
plt.title('AGE DISTRIBUTION', 
          fontsize = 18)
plt.xlabel('Age', 
           fontsize = 16)
plt.ylabel('Frequency', 
           fontsize = 16)
plt.xticks(fontsize = 14)
plt.yticks(fontsize = 14)

plt.savefig('age_distribution.png')
plt.show()

Output:


From the plot, we can see that the largest concentration of customers is around ages 35 to 40, while the smallest counts fall between ages 50 and 60.
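
If we want to back that visual impression with numbers, a quick binned count (a sketch using 5-year age bins) can be run:


# Counting customers in 5-year age bins (a sketch)

age_bins = pd.cut(df['Age'], bins = range(20, 61, 5))
print(age_bins.value_counts().sort_index())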


Next, using the ‘Defaulted’ attribute in the dataset, we can see how many customers have defaulted (1) on their credit card and how many have not (0), as well as the percentage of default cases. To analyze the counts, we can produce a count plot using the ‘countplot’ method in seaborn.


Python Implementation:


# Credit card default cases

default = df[df['Defaulted'] == 1.0]
non_default = df[df['Defaulted'] == 0.0]

print(cl('.......................................', attrs = ['bold']))
print(cl('Number of Default cases are {}'.format(len(default)), attrs = ['bold']))
print(cl('.......................................', attrs = ['bold']))
print(cl('Number of Non-Default cases are {}'.format(len(non_default)), attrs = ['bold']))
print(cl('.......................................', attrs = ['bold']))
print(cl('Percentage of Default cases is {:.0%}'.format(len(default) / (len(default) + len(non_default))), attrs = ['bold']))
print(cl('.......................................', attrs = ['bold']))

sb.countplot(x = 'Defaulted', 
             data = df, 
             palette = ['coral', 'deepskyblue'], 
             edgecolor = 'darkgrey')
plt.title('Credit card default cases(1) and non-default cases(0)', 
          fontsize = 18)
plt.xlabel('Default value', 
           fontsize = 16)
plt.ylabel('Number of People', 
           fontsize = 16)
plt.xticks(fontsize = 14)
plt.yticks(fontsize = 14)

plt.savefig('default_cases.png')
plt.show()

Output:



It is understood that most of the customers have not defaulted on their credit cards. To be more precise, 183 customers have defaulted and 517 have not, which means roughly 26% of the customers with a recorded default status have defaulted on their credit card.
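
As a quick sanity check on that figure (a sketch that simply reuses the counts printed above):


# Share of default cases among customers with a recorded default status (a sketch)

n_default, n_non_default = 183, 517
print('{:.0%}'.format(n_default / (n_default + n_non_default)))   # roughly 26%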


Now, using a scatter plot, let’s examine the relationship between the ‘Age’ and ‘Income’ attributes. A scatter plot can easily be produced using the ‘scatterplot’ method provided by the seaborn package. Let’s do it in Python!


Python Implementation:


# Age vs Income

sb.scatterplot(x = 'Age', y = 'Income', 
               data = df, 
               color = 'deepskyblue', 
               s = 150, 
               alpha = 0.6, 
               edgecolor = 'b')
plt.title('AGE / INCOME', 
          fontsize = 18)
plt.xlabel('Age', 
           fontsize = 16)
plt.ylabel('Income', 
           fontsize = 16)
plt.xticks(fontsize = 14)
plt.yticks(fontsize = 14)

plt.savefig('age_income.png')
plt.show()

Output:


As we can see, income tends to increase as age increases, so we can say that the ‘Age’ and ‘Income’ attributes have a roughly linear, positive relationship.
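
To put a number on that visual impression, a quick correlation check (a sketch; Pearson correlation via pandas) can be run:


# Correlation between 'Age' and 'Income' (a sketch)

print(df['Age'].corr(df['Income']))   # a positive value supports the upward trend seen in the plot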


Next, we are going to produce a bubble plot that represents the data points of the ‘Income’, ‘Years Employed’, and ‘DebtIncomeRatio’ attributes. The only difference between a scatter plot and a bubble plot is that a bubble plot encodes a third attribute in the marker size, whereas a scatter plot visualizes only two. To produce a bubble plot, we can use the ‘scatterplot’ method but set the marker size from the ‘DebtIncomeRatio’ attribute. Let’s do it in Python!


Python Implementation:


# Years Employed vs Income

area = df.DebtIncomeRatio ** 2

sb.scatterplot(x = 'Years Employed', y = 'Income', 
               data = df, 
               s = area, 
               alpha = 0.6, 
               edgecolor = 'white', 
               hue = 'Defaulted', 
               palette = 'spring')
plt.title('YEARS EMPLOYED / INCOME', 
          fontsize = 18)
plt.xlabel('Years Employed', 
           fontsize = 16)
plt.ylabel('Income', 
           fontsize = 16)
plt.xticks(fontsize = 14)
plt.yticks(fontsize = 14)
plt.legend(loc = 'upper left', fontsize = 14)

plt.savefig('y_income.png')
plt.show()

Output:


The above chart shows the data points of three attributes, with the scatter points colored based on the ‘Defaulted’ attribute. We can also observe a roughly linear relationship between the ‘Years Employed’ and ‘Income’ attributes.


With that, we have explored the data from every angle we could. Let’s proceed to the next step.


Data Processing


In this step, we are going to standardize the dataset, which is an important step before building the model. But what is standardization (often loosely called normalization)? It is a statistical method that rescales each feature so that features with different magnitudes and distributions are weighted equally by distance-based algorithms such as K-Means. Using the ‘StandardScaler’ class provided by the scikit-learn package, we can easily perform this transformation on the dataset in Python.
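
To make the transformation concrete, here is a minimal sketch of what ‘StandardScaler’ does to a single toy feature: each value becomes (value minus mean) divided by the standard deviation, so the rescaled feature has mean 0 and unit variance.


# A minimal standardization sketch on a toy feature (illustration only)

toy_feature = np.array([[10.0], [20.0], [30.0]])
print(StandardScaler().fit_transform(toy_feature))   # roughly [[-1.22], [0.], [1.22]]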


Python Implementation:


# DATA PROCESSING

X = df.values
X = np.nan_to_num(X)   # replace any missing values (NaN) with 0

sc = StandardScaler()

cluster_data = sc.fit_transform(X)
print(cl('Cluster data samples : ', attrs = ['bold']), cluster_data[:5])

Output:



Now we have all the required components to build the K-Means model. So, we can proceed to build the model.


Modeling


We can build the K-Means model in Python using the ‘KMeans’ class provided by the scikit-learn package.


The KMeans class has many parameters that can be used, but we will be using these three:

  • init - Initialization method for the centroids. The value will be ‘k-means++’, which selects initial cluster centers for K-Means clustering in a smart way to speed up convergence.

  • n_clusters - The number of clusters to form as well as the number of centroids to generate. The value will be 3.

  • n_init - Number of times the k-means algorithm will be run with different centroid seeds. The final results will be the best output of n_init consecutive runs in terms of inertia. The value will be 12

After building the model, we will fit it to the data and define a variable ‘labels’ to store the cluster labels produced by the model. Let’s do it in Python!


Python Implementation:


# MODELING

clusters = 3
model = KMeans(init = 'k-means++', 
               n_clusters = clusters, 
               n_init = 12)
model.fit(cluster_data)   # fit on the standardized data from the previous step

labels = model.labels_
print(cl(labels[:100], attrs = ['bold']))

Output:



Now we have successfully built and fitted our K-Means model and stored the cluster labels in the ‘labels’ variable. Using the labels produced by the model, we can extract some useful insights and draw conclusions.


Model Insights


To begin finding some useful insights, we have to add a column to the customer data that shows the cluster value for each row. Let’s do it in Python!


Python Implementation:


df['cluster_num'] = labels
df.head()

Output:



As you can see, we have created a new attribute called ‘cluster_num’ in the customer data that indicates which cluster each row belongs to.


Now let’s use the ‘groupby’ method to group the data by cluster value and look at the mean of each attribute using the ‘mean’ method.


Python Implementation:


df.groupby('cluster_num').mean()

Output:



Next, let’s look at the distribution of customers by age and income using a bubble plot, where the bubble size is based on the education level and the color represents the cluster. Let’s do it in Python!


Python Implementation:


area = np.pi * (df.Edu) ** 4

sb.scatterplot(x = 'Age', y = 'Income', 
               data = df, 
               s = area, 
               hue = 'cluster_num', 
               palette = 'spring', 
               alpha = 0.6, 
               edgecolor = 'darkgrey')
plt.title('AGE / INCOME (CLUSTERED)', 
          fontsize = 18)
plt.xlabel('Age', 
           fontsize = 16)
plt.ylabel('Income', 
           fontsize = 16)
plt.xticks(fontsize = 14)
plt.yticks(fontsize = 14)
plt.legend(loc = 'upper left', fontsize = 14)

plt.savefig('c_age_income.png')
plt.show()

Output:



Instead of analyzing a two-dimensional bubble plot, it is more effective to analyze a three-dimensional scatter plot, which also makes it easier to draw conclusions. We can produce a three-dimensional scatter plot using the ‘mplot3d’ toolkit in Python. Follow the code below to produce one.


Python Implementation:


fig = plt.figure(1)
plt.clf()
ax = fig.add_axes([0, 0, .95, 1],    # position of the 3d axes within the figure
                  projection = '3d')
ax.view_init(elev = 48, 
             azim = 134)

plt.cla()
ax.scatter(df['Edu'], df['Age'], df['Income'], 
           c = df['cluster_num'], 
           s = 200, 
           cmap = 'spring', 
           alpha = 0.5, 
           edgecolor = 'darkgrey')
ax.set_xlabel('Education', 
              fontsize = 16)
ax.set_ylabel('Age', 
              fontsize = 16)
ax.set_zlabel('Income', 
              fontsize = 16)

plt.savefig('3d_plot.png')
plt.show()

Output:



Our K-Means model has partitioned the customers into mutually exclusive groups, three clusters in our case. The customers in each cluster are demographically similar to each other. Now we can create a profile for each group based on the common characteristics of each cluster. For example, the three clusters could be:

  • Affluent, Educated & Old Aged

  • Middle-Aged & Middle Income

  • Young & Low Income
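
One way to attach such profile names to the cluster numbers is a simple mapping (a sketch; the assignment below is hypothetical, since K-Means numbers its clusters arbitrarily, and it should be checked against the group means shown above):


# Attaching illustrative profile names to the clusters (a sketch; the labels are hypothetical)

profile_names = {0: 'Affluent, Educated & Old Aged',
                 1: 'Middle-Aged & Middle Income',
                 2: 'Young & Low Income'}

df['cluster_profile'] = df['cluster_num'].map(profile_names)
print(df['cluster_profile'].value_counts())   # number of customers in each profile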

That’s it! We have successfully built our K-Means model and used it to segment the customers. I hope you found this article useful. Thank you very much, and note that the full code is provided at the end of this article.


Happy Machine Learning!


Full code:


# IMPORTING PACKAGES

import pandas as pd # working with data
import numpy as np # working with arrays
import matplotlib.pyplot as plt # visualization
import seaborn as sb # visualization
from mpl_toolkits.mplot3d import Axes3D # 3d plot
from termcolor import colored as cl # text customization

from sklearn.preprocessing import StandardScaler # data normalization
from sklearn.cluster import KMeans # K-means algorithm

plt.rcParams['figure.figsize'] = (20, 10)
sb.set_style('whitegrid')

# IMPORTING DATA

df = pd.read_csv('cust_seg.csv')
df.drop('Unnamed: 0', axis = 1, inplace = True)
df.set_index('Customer Id', inplace = True)

print(df.head())

# Age distribution

print(cl(df['Age'].describe(), attrs = ['bold']))

sb.histplot(df['Age'], 
            kde = True, 
            color = 'orange')
plt.title('AGE DISTRIBUTION', 
          fontsize = 18)
plt.xlabel('Age', 
           fontsize = 16)
plt.ylabel('Frequency', 
           fontsize = 16)
plt.xticks(fontsize = 14)
plt.yticks(fontsize = 14)

plt.savefig('age_distribution.png')
plt.show()

# Credit card default cases

default = df[df['Defaulted'] == 1.0]
non_default = df[df['Defaulted'] == 0.0]

print(cl('.......................................', attrs = ['bold']))
print(cl('Number of Default cases are {}'.format(len(default)), attrs = ['bold']))
print(cl('.......................................', attrs = ['bold']))
print(cl('Number of Non-Default cases are {}'.format(len(non_default)), attrs = ['bold']))
print(cl('.......................................', attrs = ['bold']))
print(cl('Percentage of Default cases is {:.0%}'.format(len(default) / (len(default) + len(non_default))), attrs = ['bold']))
print(cl('.......................................', attrs = ['bold']))

sb.countplot(x = 'Defaulted', 
             data = df, 
             palette = ['coral', 'deepskyblue'], 
             edgecolor = 'darkgrey')
plt.title('Credit card default cases(1) and non-default cases(0)', 
          fontsize = 18)
plt.xlabel('Default value', 
           fontsize = 16)
plt.ylabel('Number of People', 
           fontsize = 16)
plt.xticks(fontsize = 14)
plt.yticks(fontsize = 14)

plt.savefig('default_cases.png')
plt.show()

# Age vs Income

sb.scatterplot(x = 'Age', y = 'Income', 
               data = df, 
               color = 'deepskyblue', 
               s = 150, 
               alpha = 0.6, 
               edgecolor = 'b')
plt.title('AGE / INCOME', 
          fontsize = 18)
plt.xlabel('Age', 
           fontsize = 16)
plt.ylabel('Income', 
           fontsize = 16)
plt.xticks(fontsize = 14)
plt.yticks(fontsize = 14)

plt.savefig('age_income.png')
plt.show()

# Years Employed vs Income

area = df.DebtIncomeRatio ** 2

sb.scatterplot(x = 'Years Employed', y = 'Income', 
               data = df, 
               s = area, 
               alpha = 0.6, 
               edgecolor = 'white', 
               hue = 'Defaulted', 
               palette = 'spring')
plt.title('YEARS EMPLOYED / INCOME', 
          fontsize = 18)
plt.xlabel('Years Employed', 
           fontsize = 16)
plt.ylabel('Income', 
           fontsize = 16)
plt.xticks(fontsize = 14)
plt.yticks(fontsize = 14)
plt.legend(loc = 'upper left', fontsize = 14)

plt.savefig('y_income.png')
plt.show()

# DATA PROCESSING

X = df.values
X = np.nan_to_num(X)

sc = StandardScaler()

cluster_data = sc.fit_transform(X)
print(cl('Cluster data samples : ', attrs = ['bold']), cluster_data[:5])

# MODELING

clusters = 3
model = KMeans(init = 'k-means++', 
               n_clusters = clusters, 
               n_init = 12)
model.fit(cluster_data)   # fit on the standardized data prepared above

labels = model.labels_
print(cl(labels[:100], attrs = ['bold']))

df['cluster_num'] = labels
print(df.head())

print(df.groupby('cluster_num').mean())

area = np.pi * (df.Edu) ** 4

sb.scatterplot(x = 'Age', y = 'Income', 
               data = df, 
               s = area, 
               hue = 'cluster_num', 
               palette = 'spring', 
               alpha = 0.6, 
               edgecolor = 'darkgrey')
plt.title('AGE / INCOME (CLUSTERED)', 
          fontsize = 18)
plt.xlabel('Age', 
           fontsize = 16)
plt.ylabel('Income', 
           fontsize = 16)
plt.xticks(fontsize = 14)
plt.yticks(fontsize = 14)
plt.legend(loc = 'upper left', fontsize = 14)

plt.savefig('c_age_income.png')
plt.show()

fig = plt.figure(1)
plt.clf()
ax = fig.add_axes([0, 0, .95, 1],    # position of the 3d axes within the figure
                  projection = '3d')
ax.view_init(elev = 48, 
             azim = 134)

plt.cla()
ax.scatter(df['Edu'], df['Age'], df['Income'], 
           c = df['cluster_num'], 
           s = 200, 
           cmap = 'spring', 
           alpha = 0.5, 
           edgecolor = 'darkgrey')
ax.set_xlabel('Education', 
              fontsize = 16)
ax.set_ylabel('Age', 
              fontsize = 16)
ax.set_zlabel('Income', 
              fontsize = 16)

plt.savefig('3d_plot.png')
plt.show()

 
