  • Nikhil Adithyan

Building and Visualizing Decision Tree in Python

Updated: Apr 21, 2021

Learn to build and visualize a decision tree model with scikit-learn in Python

Decision Tree

Decision trees are the building blocks of some of the most powerful supervised learning methods in use today, such as random forests. A decision tree is a flowchart-like tree in which each internal node splits a group of observations according to the value of one feature variable; scikit-learn's implementation always makes binary (two-way) splits. For classification, the goal is to split the data into groups that are as pure as possible, ideally with every element of a group belonging to the same category. Decision trees can also approximate a continuous target variable. In that case, the tree chooses the splits that give the lowest mean squared error within the resulting groups.
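To make "as pure as possible" concrete, here is a minimal sketch that scores a candidate split by the weighted Gini impurity of the two groups it creates (lower is better). Gini impurity is the default ‘criterion’ of scikit-learn's ‘DecisionTreeClassifier’; the helper names ‘gini’ and ‘split_impurity’ and the toy labels are hypothetical, for illustration only.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions (0 = pure group)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def split_impurity(left, right):
    """Weighted average Gini impurity of the two groups produced by a split."""
    n = len(left) + len(right)
    return len(left) / n * gini(left) + len(right) / n * gini(right)


# Hypothetical labels: a split that separates the classes perfectly ...
print(split_impurity(["a", "a", "a"], ["b", "b", "b"]))  # 0.0
# ... versus a split that leaves both groups mixed (roughly 0.444).
print(split_impurity(["a", "b", "a"], ["b", "a", "b"]))
```

A tree-growing algorithm evaluates many such candidate splits at each node and keeps the one with the lowest score.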

One of the great properties of decision trees is that they are easy to interpret. You do not need any background in machine learning to follow what a decision tree is doing: its graph reads as a series of simple yes/no questions about the features.

Pros and Cons

The pros of decision tree methods are:

  • Decision trees are able to generate understandable rules.

  • Once trained, decision trees classify new observations with very little computation, since each prediction only follows one path of comparisons from the root to a leaf.

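  • Decision trees can, in principle, handle both continuous and categorical variables, although scikit-learn's implementation expects categorical values to be encoded as numbers first (see Step-3).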

  • Decision trees provide a clear indication of which fields are most important for prediction or classification.
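The first and last points can be seen directly in scikit-learn: ‘export_text’ prints a fitted tree as readable if/else rules, and the ‘feature_importances_’ attribute reports how much each feature contributed to the splits. This sketch uses the Iris dataset bundled with scikit-learn as a stand-in; ‘max_depth=2’ is an arbitrary choice to keep the rules short.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Iris ships with scikit-learn, so no download is needed.
iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

# The fitted tree as nested rules that a non-specialist can read.
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)

# Impurity-based importance of each feature (the values sum to 1).
for name, importance in zip(iris.feature_names, clf.feature_importances_):
    print(f"{name}: {importance:.2f}")
```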

The cons of decision tree methods are:

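  • Decision trees are less appropriate for estimation tasks where the goal is to predict the value of a continuous attribute, because a tree predicts a single constant value in each leaf, so its predictions form a step function.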

  • Decision trees are prone to errors in classification problems with many classes and relatively small numbers of training examples.

  • Decision trees can be computationally expensive to train. At each node, each candidate splitting field must be sorted before its best split can be found. Some algorithms also use combinations of fields, which requires searching for optimal combining weights. Pruning algorithms can be expensive too, since many candidate subtrees must be built and compared.
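The cost of pruning shows up clearly in scikit-learn's minimal cost-complexity pruning: ‘cost_complexity_pruning_path’ returns one ‘ccp_alpha’ value per candidate subtree, and comparing them means fitting and scoring one tree per value. This is a sketch on the bundled breast cancer dataset (a stand-in for any tabular data); the seeds are arbitrary.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Effective alphas of the pruning sequence: larger alphas give smaller subtrees.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas

# One full fit per candidate subtree is what makes this search costly.
results = []
for alpha in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(X_train, y_train)
    results.append((alpha, clf.get_n_leaves(), clf.score(X_test, y_test)))

# For brevity this keeps the subtree with the best held-out accuracy;
# cross-validation would give a more reliable choice.
best_alpha, best_leaves, best_score = max(results, key=lambda r: r[2])
print(f"{len(ccp_alphas)} candidate subtrees compared")
print(f"best ccp_alpha={best_alpha:.5f}: {best_leaves} leaves, accuracy {best_score:.3f}")
```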

Python for Decision Tree

Python is a general-purpose programming language that offers data scientists powerful machine learning packages and tools. In this article, we will build our decision tree model with Python’s best-known machine learning package, ‘scikit-learn’. We will create the model with scikit-learn’s ‘DecisionTreeClassifier’ class and then visualize it with the ‘plot_tree’ function. Let’s do it!

Step-1: Importing the packages

Our primary packages for building the model are pandas, scikit-learn, and NumPy; the ‘plot_tree’ function also needs Matplotlib to draw the tree. Follow the code to import the required packages in Python.

Python Implementation:
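The original listing for this step is not reproduced here. A minimal import block covering the tools this walkthrough names (pandas, NumPy, ‘train_test_split’, ‘DecisionTreeClassifier’, ‘accuracy_score’ and ‘plot_tree’) could look like the following; the exact imports in the author’s code may differ.

```python
# Data handling
import numpy as np
import pandas as pd

# Modelling and evaluation (scikit-learn)
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree

# plot_tree draws onto a Matplotlib figure
import matplotlib.pyplot as plt
```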

After importing all the required packages, it’s time to import the data and do some exploratory data analysis (EDA) on it.

Step-2: Importing data and EDA

In this step, we will use the ‘pandas’ package to import the data and do some EDA on it. The dataset we will use to build our decision tree model records which drug was prescribed to each patient, along with the patient’s age, sex, blood pressure (BP), cholesterol level, and sodium-to-potassium ratio (Na_to_K). Let’s import the data in Python!

Python Implementation:
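The original listing is not shown, and neither is the data file’s name. As a self-contained sketch, the code below reads the five sample rows printed in this step from an in-memory CSV string with ‘pd.read_csv’; with the real dataset you would pass its file path instead (‘drug.csv’ in the comment is a hypothetical placeholder).

```python
import io

import pandas as pd

# Stand-in for the real file: the first five rows printed in this step.
# With the actual data you would call pd.read_csv("drug.csv"), where
# "drug.csv" is a hypothetical path to your copy of the dataset.
csv_text = """Age,Sex,BP,Cholesterol,Na_to_K,Drug
23,F,HIGH,HIGH,25.355,drugY
47,M,LOW,HIGH,13.093,drugC
47,M,LOW,HIGH,10.114,drugC
28,F,NORMAL,HIGH,7.798,drugX
61,F,LOW,HIGH,18.043,drugY
"""

df = pd.read_csv(io.StringIO(csv_text))
print(df.head())  # first five rows, as in the output of this step
```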


   Age Sex      BP Cholesterol  Na_to_K   Drug
0   23   F    HIGH        HIGH   25.355  drugY
1   47   M     LOW        HIGH   13.093  drugC
2   47   M     LOW        HIGH   10.114  drugC
3   28   F  NORMAL        HIGH    7.798  drugX
4   61   F     LOW        HIGH   18.043  drugY

Now we have a first look at our dataset. Next, let’s get some basic information about the data using the ‘info’ method. It reports the number of entries, the index range, the column names, the count of non-null values in each column, each column’s data type (dtype), and the memory usage.

Python Implementation:
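‘info’ is a DataFrame method that prints its summary rather than returning it. Here is a short sketch on the five sample rows shown earlier; the two follow-up checks (‘isnull().sum()’ and ‘value_counts()’) are common additions, not part of the step as described.

```python
import pandas as pd

# A few rows in the same layout as the drug data (the head shown earlier).
df = pd.DataFrame({
    "Age": [23, 47, 47, 28, 61],
    "Sex": ["F", "M", "M", "F", "F"],
    "BP": ["HIGH", "LOW", "LOW", "NORMAL", "LOW"],
    "Cholesterol": ["HIGH"] * 5,
    "Na_to_K": [25.355, 13.093, 10.114, 7.798, 18.043],
    "Drug": ["drugY", "drugC", "drugC", "drugX", "drugY"],
})

df.info()                          # entries, columns, non-null counts, dtypes, memory
print(df.isnull().sum())           # missing values per column
print(df["Drug"].value_counts())   # how often each target class appears
```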


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 200 entries, 0 to 199
Data columns (total 6 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   Age          200 non-null    int64  
 1   Sex          200 non-null    object 
 2   BP           200 non-null    object 
 3   Cholesterol  200 non-null    object 
 4   Na_to_K      200 non-null    float64
 5   Drug         200 non-null    object 
dtypes: float64(1), int64(1), object(4)
memory usage: 9.5+ KB

Step-3: Data Processing

We can see that the Sex, BP, and Cholesterol attributes are categorical and stored with the ‘object’ dtype. The problem is that scikit-learn’s decision tree implementation only accepts numeric X variables, so it cannot use ‘object’ (string) columns directly. It is therefore necessary to convert these ‘object’ values into numeric codes. Let’s do it in Python!

Python Implementation:
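A minimal sketch of one way to do this conversion, using explicit dictionaries with ‘Series.map’. The BP codes match the output that follows (LOW → 0, NORMAL → 1, HIGH → 2) and the Cholesterol codes follow the text (LOW → 0, HIGH → 1); mapping F → 0 and M → 1 for Sex is an assumption. The ‘mappings’ dictionary is a hypothetical name, and unlike a silent replace, the check stops with an error if a label is missing from its dictionary.

```python
import pandas as pd

# Raw sample rows (the head shown in Step-2).
df = pd.DataFrame({
    "Age": [23, 47, 47, 28, 61],
    "Sex": ["F", "M", "M", "F", "F"],
    "BP": ["HIGH", "LOW", "LOW", "NORMAL", "LOW"],
    "Cholesterol": ["HIGH"] * 5,
    "Na_to_K": [25.355, 13.093, 10.114, 7.798, 18.043],
    "Drug": ["drugY", "drugC", "drugC", "drugX", "drugY"],
})

# Explicit label -> integer codes; BP keeps its natural order LOW < NORMAL < HIGH.
# Extend a dict if your copy of the data uses other labels.
mappings = {
    "Sex": {"F": 0, "M": 1},
    "BP": {"LOW": 0, "NORMAL": 1, "HIGH": 2},
    "Cholesterol": {"LOW": 0, "HIGH": 1},
}

for column, mapping in mappings.items():
    encoded = df[column].map(mapping)
    # map() turns any label missing from the dict into NaN; fail loudly instead.
    unknown = df.loc[encoded.isna(), column].unique()
    if len(unknown) > 0:
        raise ValueError(f"Unmapped labels in {column}: {list(unknown)}")
    df[column] = encoded.astype(int)

print(df)
```

Each category keeps a distinct code, which is the property the tree needs in order to split on it.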


     Age  Sex  BP  Cholesterol  Na_to_K   Drug
0     23    1   2            1   25.355  drugY
1     47    1   0            1   13.093  drugC
2     47    1   0            1   10.114  drugC
3     28    1   1            1    7.798  drugX
4     61    1   0            1   18.043  drugY
..   ...  ...  ..          ...      ...    ...
195   56    1   0            1   11.567  drugC
196   16    1   0            1   12.006  drugC
197   52    1   1            1    9.894  drugX
198   23    1   1            1   14.020  drugX
199   40    1   0            1   11.349  drugX

[200 rows x 6 columns]

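We can observe that all the ‘object’ values have been converted into integer codes that represent the categories: the two-level attributes use 0 and 1, while BP uses 0 (LOW), 1 (NORMAL), and 2 (HIGH). For example, in the Cholesterol attribute, ‘LOW’ values are encoded as 0 and ‘HIGH’ values as 1. Note, however, that Sex is 1 in every row printed above, even though rows 1 and 2 are male and rows 0, 3, and 4 are female in the raw data; the two sexes should receive different codes (for example F → 0 and M → 1), so the encoding step is worth checking. Now we are ready to create the dependent and independent variables from our data.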

Step-4: Splitting the data

Now that our data has the right structure, we can define the ‘X’ variable (the independent variables) and the ‘Y’ variable (the dependent variable). Let’s do it in Python!

Python Implementation:
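A sketch of this step using the first five rows of the processed table above. The column order (Sex, BP, Age, Cholesterol, Na_to_K) is inferred from the printed X samples, and the names ‘feature_cols’, ‘X’ and ‘y’ are this sketch’s assumptions.

```python
import pandas as pd

# First five rows of the processed table above (categorical columns already encoded).
df = pd.DataFrame({
    "Age": [23, 47, 47, 28, 61],
    "Sex": [1, 1, 1, 1, 1],
    "BP": [2, 0, 0, 1, 0],
    "Cholesterol": [1, 1, 1, 1, 1],
    "Na_to_K": [25.355, 13.093, 10.114, 7.798, 18.043],
    "Drug": ["drugY", "drugC", "drugC", "drugX", "drugY"],
})

# Independent variables, in the column order seen in the printed samples.
feature_cols = ["Sex", "BP", "Age", "Cholesterol", "Na_to_K"]
X = df[feature_cols].to_numpy()   # 2-D float array, one row per patient
y = df["Drug"].to_numpy()         # 1-D array of drug labels (dependent variable)

print("X variable samples :", X[:5])
print("Y variable samples :", y[:5])
```

Because the selected columns mix integers and floats, ‘to_numpy’ returns a single float array, which is why the samples print as 1., 2., 23. and so on.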


X variable samples : [[ 1.     2.    23.     1.    25.355]
 [ 1.     0.    47.     1.    13.093]
 [ 1.     0.    47.     1.    10.114]
 [ 1.     1.    28.     1.     7.798]
 [ 1.     0.    61.     1.    18.043]]
Y variable samples : ['drugY' 'drugC' 'drugC' 'drugX' 'drugY']

We can now split our X and Y variables into a training set and a testing set using the ‘train_test_split’ function from scikit-learn. Follow the code to split the data in Python.

Python Implementation:
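A sketch of the split, run on synthetic data of the same size (200 rows, 5 features) so it is self-contained; the random features and labels are placeholders, not the drug data. A ‘test_size’ of 0.2 is inferred from the 160/40 shapes shown below, and the ‘random_state’ value is arbitrary.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in with the dataset's size: 200 rows, 5 features.
rng = np.random.default_rng(0)
X = rng.random((200, 5))
y = rng.choice(["drugC", "drugX", "drugY"], size=200)

# Hold out 20% of the rows (40 of 200) for testing; random_state makes
# the shuffle reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

print("X_train shape :", X_train.shape)  # (160, 5)
print("X_test shape :", X_test.shape)    # (40, 5)
print("y_train shape :", y_train.shape)  # (160,)
print("y_test shape :", y_test.shape)    # (40,)
```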


X_train shape : (160, 5)
X_test shape : (40, 5)
y_train shape : (160,)
y_test shape : (40,)

Now we have all the components to build our decision tree model. So, let’s proceed to build our model in python.

Step-5: Building the model & Predictions

Building a decision tree is straightforward with the ‘DecisionTreeClassifier’ class provided by scikit-learn. After training, we can use the model to make predictions and then measure how accurate those predictions are with the ‘accuracy_score’ evaluation metric. Let’s do this in Python!

Python Implementation:
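A self-contained sketch of the build, predict and score steps, reusing the variable names ‘model’ and ‘pred_model’ from the explanation that follows. Iris stands in for the drug data, so the printed accuracy will not be the 88% reported here, and the ‘criterion’ and ‘max_depth’ values are illustrative choices, not the author’s settings.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Iris stands in for the drug data so the example runs on its own.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Hyperparameters here are illustrative assumptions.
model = DecisionTreeClassifier(criterion="entropy", max_depth=4, random_state=0)
model.fit(X_train, y_train)                      # learn the splits from the training set

pred_model = model.predict(X_test)               # predicted class for each test row
accuracy = accuracy_score(y_test, pred_model)    # fraction of correct predictions
print(f"Accuracy of the model is {accuracy:.0%}")
```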


Accuracy of the model is 88%

In the first step of our code, we define a variable called ‘model’ that stores the DecisionTreeClassifier model. Next, we train (fit) the model on our training set. After that, we define a variable called ‘pred_model’ that stores the model’s predictions for the test set. Finally, we compare the predicted values with the actual values using accuracy_score, which gives an accuracy of 88%.

Step-6: Visualizing the model

Now that we have our decision tree model, let’s visualize it with the ‘plot_tree’ function provided by scikit-learn. Follow the code to produce a beautiful tree diagram of your decision tree model in Python.

Python Implementation:
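The original plotting code and figure are not reproduced here. This sketch draws a small tree fitted on Iris (a stand-in for the drug model) with ‘plot_tree’; ‘filled=True’ colours each node by its majority class. It renders off-screen and saves a PNG into memory; the figure size and ‘max_depth’ are arbitrary choices.

```python
import io

import matplotlib
matplotlib.use("Agg")  # render off-screen; remove to display a window with plt.show()
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(12, 6))
annotations = plot_tree(
    model,
    feature_names=list(iris.feature_names),   # label splits with column names
    class_names=list(iris.target_names),      # label leaves with class names
    filled=True,                              # colour nodes by majority class
    rounded=True,
    ax=ax,
)

buffer = io.BytesIO()
fig.savefig(buffer, format="png")  # or fig.savefig("tree.png") to write a file
plt.close(fig)
```

With the drug model, you would pass its feature column names and drug labels instead.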



There are many techniques for tuning decision trees and avoiding overfitting, such as pruning. Decision trees are also often unstable, meaning a small change in the data can lead to a very different tree structure; even so, their simplicity makes them a strong candidate for a wide range of applications. Before neural networks became popular, decision trees were among the leading algorithms in machine learning. Ensemble models built from many trees, such as random forests, are usually more accurate than a single vanilla decision tree, but a single tree remains attractive because of its simplicity and interpretability. Decision trees and random forests are widely used in industry for user signup modeling, credit scoring, failure prediction, medical diagnostics, and more.

With that, we come to the end. If you missed any of the coding parts, don’t worry: I’ve provided the full code for this article.
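To try the ensemble comparison yourself, here is a minimal sketch that scores a single tree and a random forest with 5-fold cross-validation on scikit-learn’s bundled breast cancer dataset (a stand-in; the article’s drug data is not used). The forest size and seeds are arbitrary assumptions, and which model scores higher depends on the data.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

models = {
    "single decision tree": DecisionTreeClassifier(random_state=0),
    "random forest (100 trees)": RandomForestClassifier(n_estimators=100, random_state=0),
}

results = {}
for name, estimator in models.items():
    # 5-fold cross-validated accuracy; each fold trains a fresh copy of the estimator.
    scores = cross_val_score(estimator, X, y, cv=5)
    results[name] = scores
    print(f"{name}: mean accuracy {scores.mean():.3f} (std {scores.std():.3f})")
```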

Happy Machine Learning!

Full code:
