Updated: Apr 21
Learn to build and visualize a decision tree model with scikit-learn in Python
Decision trees are the building blocks of some of the most powerful supervised learning methods used today. A decision tree is essentially a binary-tree flowchart in which each node splits a group of observations according to the value of one feature. The goal of a decision tree is to split your data into groups that are as pure as possible, ideally so that every element in a group belongs to the same category. Decision trees can also be used to predict a continuous target variable. In that case, the tree chooses splits that give each resulting group the lowest mean squared error.
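To make "groups where every element belongs to the same category" concrete, the sketch below scores a group with Gini impurity (the default splitting criterion of scikit-learn's DecisionTreeClassifier) and, for the regression case, with mean squared error. It is a minimal illustration, not scikit-learn's implementation; the helper names gini, weighted_gini and mse are hypothetical, and the labels are the Drug values of the first five rows of the dataset used later in this article.

```python
from collections import Counter


def gini(labels):
    """Gini impurity 1 - sum(p_k^2); 0.0 means every element has the same label."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def weighted_gini(left, right):
    """Impurity of a split: each child's Gini weighted by its share of the samples."""
    n = len(left) + len(right)
    return len(left) / n * gini(left) + len(right) / n * gini(right)


def mse(values):
    """Mean squared error around the group mean (the regression analogue)."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


parent = ["drugY", "drugC", "drugC", "drugX", "drugY"]
print(round(gini(parent), 3))  # → 0.64
print(round(weighted_gini(["drugY", "drugY"], ["drugC", "drugC", "drugX"]), 3))  # → 0.267
print(round(mse([1.0, 2.0, 3.0]), 3))  # → 0.667
```

Separating the two drugY rows lowers the impurity from 0.64 to about 0.267; a tree greedily picks, at each node, the split with the largest such drop.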
One of the great properties of decision trees is that they are very easy to interpret. You do not need to be familiar with machine learning techniques at all to understand what a decision tree is doing: a plotted tree reads like a flowchart of yes/no questions.
Pros and Cons
Pros of decision tree methods are:
Decision trees are able to generate understandable rules.
Decision trees classify new observations without requiring much computation once they are trained.
Decision trees are able to handle both continuous and categorical variables.
Decision trees provide a clear indication of which fields are most important for prediction or classification.
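Two of these pros can be checked directly in scikit-learn: export_text prints a fitted tree as indented if/else rules, and the feature_importances_ attribute reports how much each feature reduced impurity. A minimal sketch, assuming the ten encoded rows printed in Step-3 (rows 0-4 and 195-199) as a tiny stand-in for the full 200-row dataset; with so few rows the rules are only illustrative.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

# Ten encoded rows from the processed table in Step-3 (illustration only)
df = pd.DataFrame({
    "Age": [23, 47, 47, 28, 61, 56, 16, 52, 23, 40],
    "Sex": [1] * 10,
    "BP": [2, 0, 0, 1, 0, 0, 0, 1, 1, 0],
    "Cholesterol": [1] * 10,
    "Na_to_K": [25.355, 13.093, 10.114, 7.798, 18.043,
                11.567, 12.006, 9.894, 14.020, 11.349],
    "Drug": ["drugY", "drugC", "drugC", "drugX", "drugY",
             "drugC", "drugC", "drugX", "drugX", "drugX"],
})
features = ["Sex", "BP", "Age", "Cholesterol", "Na_to_K"]

tree = DecisionTreeClassifier(random_state=0).fit(df[features], df["Drug"])

# Human-readable rules, one indented line per decision
rules = export_text(tree, feature_names=features)
print(rules)

# Share of the total impurity reduction contributed by each feature
for name, score in zip(features, tree.feature_importances_):
    print(f"{name:12s} {score:.2f}")
```

Columns that never vary in these rows (Sex and Cholesterol) are never used for a split, so their importance is 0.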
Cons of decision tree methods are:
Decision trees are less appropriate for estimation tasks where the goal is to predict the value of a continuous attribute.
Decision trees are prone to errors in classification problems with many classes and relatively small numbers of training examples.
Training a decision tree can be computationally expensive. At each node, each candidate splitting field must be sorted before its best split can be found. In some algorithms, combinations of fields are used, and a search must be made for optimal combining weights. Pruning algorithms can also be expensive, since many candidate subtrees must be formed and compared.
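The per-node sorting step can be sketched for a single numeric feature: sort the values once, try a threshold halfway between each pair of consecutive distinct values, and keep the one with the lowest weighted Gini impurity. This is a simplified illustration with hypothetical helper names; for clarity it rebuilds both child lists at every candidate, whereas production implementations update class counts incrementally during the sweep. The sample uses the Na_to_K and Drug values of the first five rows of the dataset used in this article.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels (0.0 for a pure group)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_threshold(values, labels):
    """Return (threshold, weighted Gini) of the best 'value <= threshold' split."""
    pairs = sorted(zip(values, labels))  # the sort performed at every node
    n = len(pairs)
    best = (None, float("inf"))
    for i in range(1, n):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # equal values cannot be separated by a threshold
        threshold = (pairs[i - 1][0] + pairs[i][0]) / 2
        left = [label for _, label in pairs[:i]]
        right = [label for _, label in pairs[i:]]
        score = i / n * gini(left) + (n - i) / n * gini(right)
        if score < best[1]:
            best = (threshold, score)
    return best


na_to_k = [25.355, 13.093, 10.114, 7.798, 18.043]
drug = ["drugY", "drugC", "drugC", "drugX", "drugY"]
threshold, score = best_threshold(na_to_k, drug)
print(round(threshold, 3), round(score, 3))  # → 15.568 0.267
```

The winning split, Na_to_K <= 15.568, sends both drugY rows to one side. Repeating this scan for every feature at every node is what makes growing a tree costly on large data.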
Python for Decision Tree
Python is a general-purpose programming language that offers data scientists powerful machine learning packages and tools. In this article, we will build our decision tree model using Python's most popular machine learning package, ‘scikit-learn’. We will create the model with the ‘DecisionTreeClassifier’ class provided by scikit-learn and then visualize it using the ‘plot_tree’ function. Let’s do it!
Step-1: Importing the packages
Our primary packages for building the model are pandas, scikit-learn, and NumPy. Follow the code to import the required packages in Python.
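The import code itself is not shown here, so the following is a sketch of what the later steps use. matplotlib is an added assumption: it is not named in the text, but plot_tree draws onto a matplotlib figure.

```python
import numpy as np                                          # numerical arrays
import pandas as pd                                         # loading the data and EDA
import matplotlib.pyplot as plt                             # figure for plot_tree (assumption)
from sklearn.tree import DecisionTreeClassifier, plot_tree  # model and visualization
from sklearn.model_selection import train_test_split        # train/test split
from sklearn.metrics import accuracy_score                  # evaluation metric
```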
After importing all the required packages for building our model, it’s time to import the data and do some EDA on it.
Step-2: Importing data and EDA
In this step, we will use the ‘Pandas’ package to import the data and do some EDA on it. The dataset we will use to build our decision tree model is a drug dataset: it records which drug was prescribed to each patient based on certain criteria (age, sex, blood pressure, cholesterol, and sodium-to-potassium ratio). Let’s import the data in Python!
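A minimal sketch of the import step with pandas.read_csv. The dataset's file name isn't given, so the example reads the five sample rows printed next from an in-memory string; with the real file you would pass its path instead (the name drug.csv in the comment is hypothetical).

```python
import io

import pandas as pd

# Stand-in for the dataset file: the first five rows printed in this article.
# With the real data: df = pd.read_csv("drug.csv")  # hypothetical file name
csv_text = """Age,Sex,BP,Cholesterol,Na_to_K,Drug
23,F,HIGH,HIGH,25.355,drugY
47,M,LOW,HIGH,13.093,drugC
47,M,LOW,HIGH,10.114,drugC
28,F,NORMAL,HIGH,7.798,drugX
61,F,LOW,HIGH,18.043,drugY
"""

df = pd.read_csv(io.StringIO(csv_text))
print(df.head())  # first five rows
```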
   Age Sex      BP Cholesterol  Na_to_K   Drug
0   23   F    HIGH        HIGH   25.355  drugY
1   47   M     LOW        HIGH   13.093  drugC
2   47   M     LOW        HIGH   10.114  drugC
3   28   F  NORMAL        HIGH    7.798  drugX
4   61   F     LOW        HIGH   18.043  drugY
Now we have a clear idea of what our dataset looks like. Next, let’s get some basic information on the data using the ‘info’ function. It reports the number of entries, the index range, the column names, the non-null count of each column, each column's data type, and the memory usage.
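A sketch of the EDA calls for this step. It rebuilds the five sample rows shown above in memory so it runs on its own; on the full dataset, df.info() produces the 200-entry summary printed next. The isnull().sum() and value_counts() calls are additions: they list missing values per column and how often each drug occurs.

```python
import io

import pandas as pd

# The five sample rows shown above, standing in for the full 200-row file
csv_text = """Age,Sex,BP,Cholesterol,Na_to_K,Drug
23,F,HIGH,HIGH,25.355,drugY
47,M,LOW,HIGH,13.093,drugC
47,M,LOW,HIGH,10.114,drugC
28,F,NORMAL,HIGH,7.798,drugX
61,F,LOW,HIGH,18.043,drugY
"""
df = pd.read_csv(io.StringIO(csv_text))

df.info()                         # entries, columns, non-null counts, dtypes, memory
print(df.isnull().sum())          # missing values per column (all zero here)
print(df["Drug"].value_counts())  # class balance of the target
```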
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 200 entries, 0 to 199
Data columns (total 6 columns):
 #   Column       Non-Null Count  Dtype
---  ------       --------------  -----
 0   Age          200 non-null    int64
 1   Sex          200 non-null    object
 2   BP           200 non-null    object
 3   Cholesterol  200 non-null    object
 4   Na_to_K      200 non-null    float64
 5   Drug         200 non-null    object
dtypes: float64(1), int64(1), object(4)
memory usage: 9.5+ KB
Step-3: Data Processing
We can see that attributes like Sex, BP, and Cholesterol are categorical and stored with the ‘object’ type. The problem is that the decision tree implementation in scikit-learn expects numeric input features, so it cannot use these ‘object’ (string) columns as X variables directly. It is therefore necessary to convert these ‘object’ values into numeric integer codes. Let’s do it in Python!
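A sketch of the conversion using pandas Series.map with explicit dictionaries. The BP codes (LOW 0, NORMAL 1, HIGH 2) match the processed table that follows, and the Cholesterol codes (LOW 0, HIGH 1) follow the description in this step; the Sex codes (F 0, M 1) are an assumption, so that column may not match the author's output. Any category missing from a dictionary raises an error instead of silently becoming NaN. scikit-learn's OrdinalEncoder is an alternative, but by default it orders categories alphabetically unless you pass its categories argument.

```python
import pandas as pd

# First five raw rows from the sample shown earlier
df = pd.DataFrame({
    "Age": [23, 47, 47, 28, 61],
    "Sex": ["F", "M", "M", "F", "F"],
    "BP": ["HIGH", "LOW", "LOW", "NORMAL", "LOW"],
    "Cholesterol": ["HIGH"] * 5,
    "Na_to_K": [25.355, 13.093, 10.114, 7.798, 18.043],
    "Drug": ["drugY", "drugC", "drugC", "drugX", "drugY"],
})

# Category-to-integer mappings (Sex codes are an assumption)
mappings = {
    "Sex": {"F": 0, "M": 1},
    "BP": {"LOW": 0, "NORMAL": 1, "HIGH": 2},
    "Cholesterol": {"LOW": 0, "HIGH": 1},
}
for column, mapping in mappings.items():
    df[column] = df[column].map(mapping)
    # map() turns any level missing from the dictionary into NaN; fail loudly instead
    if df[column].isna().any():
        raise ValueError(f"unmapped category in column {column!r}")

print(df)
```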
     Age  Sex  BP  Cholesterol  Na_to_K   Drug
0     23    1   2            1   25.355  drugY
1     47    1   0            1   13.093  drugC
2     47    1   0            1   10.114  drugC
3     28    1   1            1    7.798  drugX
4     61    1   0            1   18.043  drugY
..   ...  ...  ..          ...      ...    ...
195   56    1   0            1   11.567  drugC
196   16    1   0            1   12.006  drugC
197   52    1   1            1    9.894  drugX
198   23    1   1            1   14.020  drugX
199   40    1   0            1   11.349  drugX

[200 rows x 6 columns]
We can observe that all the ‘object’ values have been converted into integer codes that represent the categories. For example, in the Cholesterol attribute, ‘LOW’ is encoded as 0 and ‘HIGH’ as 1, and in the BP attribute, ‘LOW’, ‘NORMAL’, and ‘HIGH’ become 0, 1, and 2. Now we are ready to create the dependent and independent variables from our data.
Step-4: Splitting the data
After processing our data into the right structure, we are now set to define the ‘X’ variable (the independent variables) and the ‘Y’ variable (the dependent variable). Let’s do it in Python!
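A sketch of how X and Y might be defined. The column order Sex, BP, Age, Cholesterol, Na_to_K is inferred from the printed X samples that follow, whose first row reads 1, 2, 23, 1, 25.355; the example rebuilds the first five encoded rows from Step-3 so it runs on its own. to_numpy() returns a float array because the selected columns mix integers and floats.

```python
import pandas as pd

# First five encoded rows from the processed table in Step-3
df = pd.DataFrame({
    "Age": [23, 47, 47, 28, 61],
    "Sex": [1, 1, 1, 1, 1],
    "BP": [2, 0, 0, 1, 0],
    "Cholesterol": [1, 1, 1, 1, 1],
    "Na_to_K": [25.355, 13.093, 10.114, 7.798, 18.043],
    "Drug": ["drugY", "drugC", "drugC", "drugX", "drugY"],
})

# Independent variables, in the column order seen in the printed samples
X = df[["Sex", "BP", "Age", "Cholesterol", "Na_to_K"]].to_numpy()
# Dependent variable: the prescribed drug
y = df["Drug"].to_numpy()

print("X variable samples :", X[:5])
print("Y variable samples :", y[:5])
```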
X variable samples : [[ 1.     2.    23.     1.    25.355]
 [ 1.     0.    47.     1.    13.093]
 [ 1.     0.    47.     1.    10.114]
 [ 1.     1.    28.     1.     7.798]
 [ 1.     0.    61.     1.    18.043]]
Y variable samples : ['drugY' 'drugC' 'drugC' 'drugX' 'drugY']
We can now split our data into a training set and a testing set with our defined X and Y variables by using the ‘train_test_split’ function in scikit-learn. Follow the code to split the data in Python.
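A sketch of the split. test_size=0.2 is inferred from the shapes printed next (40 of 200 rows held out); random_state=0 is an assumption that only makes the shuffle reproducible. To stay self-contained, the example uses the ten encoded rows printed in Step-3, so it produces an 8/2 split rather than 160/40.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Ten encoded rows from Step-3 (rows 0-4 and 195-199), a stand-in for all 200
df = pd.DataFrame({
    "Age": [23, 47, 47, 28, 61, 56, 16, 52, 23, 40],
    "Sex": [1] * 10,
    "BP": [2, 0, 0, 1, 0, 0, 0, 1, 1, 0],
    "Cholesterol": [1] * 10,
    "Na_to_K": [25.355, 13.093, 10.114, 7.798, 18.043,
                11.567, 12.006, 9.894, 14.020, 11.349],
    "Drug": ["drugY", "drugC", "drugC", "drugX", "drugY",
             "drugC", "drugC", "drugX", "drugX", "drugX"],
})
X = df[["Sex", "BP", "Age", "Cholesterol", "Na_to_K"]].to_numpy()
y = df["Drug"].to_numpy()

# Hold out 20% of the rows for testing; the rest is used for training
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)

print("X_train shape :", X_train.shape)
print("X_test shape :", X_test.shape)
print("y_train shape :", y_train.shape)
print("y_test shape :", y_test.shape)
```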
X_train shape : (160, 5)
X_test shape : (40, 5)
y_train shape : (160,)
y_test shape : (40,)
Now we have all the components to build our decision tree model. So, let’s proceed to build our model in python.
Step-5: Building the model & Predictions
Building a decision tree is straightforward with the ‘DecisionTreeClassifier’ class provided by the scikit-learn package. After training it, we can use the model to make predictions on the test set. Finally, we measure the share of predicted labels that match the actual ones with the ‘accuracy_score’ evaluation metric. Let’s do this process in Python!
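A sketch of this step using the variable names mentioned in the walkthrough (‘model’, ‘pred_model’). The hyperparameters behind the 88% result aren't stated, so the example keeps DecisionTreeClassifier's defaults (Gini impurity, no depth limit) apart from random_state=0, an assumption for reproducibility; max_depth and criterion are the usual parameters to tune. It trains on the ten sample rows from Step-3, so its accuracy says nothing about the full dataset.

```python
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Ten encoded rows from Step-3, a stand-in for the full 200-row dataset
df = pd.DataFrame({
    "Age": [23, 47, 47, 28, 61, 56, 16, 52, 23, 40],
    "Sex": [1] * 10,
    "BP": [2, 0, 0, 1, 0, 0, 0, 1, 1, 0],
    "Cholesterol": [1] * 10,
    "Na_to_K": [25.355, 13.093, 10.114, 7.798, 18.043,
                11.567, 12.006, 9.894, 14.020, 11.349],
    "Drug": ["drugY", "drugC", "drugC", "drugX", "drugY",
             "drugC", "drugC", "drugX", "drugX", "drugX"],
})
X = df[["Sex", "BP", "Age", "Cholesterol", "Na_to_K"]].to_numpy()
y = df["Drug"].to_numpy()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

model = DecisionTreeClassifier(random_state=0)  # default hyperparameters
model.fit(X_train, y_train)                     # learn the splits from the training set

pred_model = model.predict(X_test)              # predicted drug for each test row
acc = accuracy_score(y_test, pred_model)        # fraction of correct predictions
print(f"Accuracy of the model is {acc:.0%}")
```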
Accuracy of the model is 88%
In the first step of our code, we define a variable called ‘model’ in which we store the DecisionTreeClassifier model. Next, we fit (train) the model on our training set. After that, we define a variable called ‘pred_model’ in which we store the values our model predicts for the test set. Finally, we compare the predicted values with the actual values using accuracy_score, which gives an accuracy of 88%.
Step-6: Visualizing the model
Now that we have our decision tree model, let’s visualize it using the ‘plot_tree’ function provided by the scikit-learn package. Follow the code to produce a tree diagram of your decision tree model in Python.
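A sketch of the plotting code. plot_tree draws onto a matplotlib Axes; feature_names, class_names, filled and rounded are real plot_tree parameters, while the figure size, the Agg backend line and the output file name decision_tree.png are assumptions. The model is refit on the ten Step-3 sample rows so the block runs on its own.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen so the script also runs without a display
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Ten encoded rows from Step-3, a stand-in for the full dataset
df = pd.DataFrame({
    "Age": [23, 47, 47, 28, 61, 56, 16, 52, 23, 40],
    "Sex": [1] * 10,
    "BP": [2, 0, 0, 1, 0, 0, 0, 1, 1, 0],
    "Cholesterol": [1] * 10,
    "Na_to_K": [25.355, 13.093, 10.114, 7.798, 18.043,
                11.567, 12.006, 9.894, 14.020, 11.349],
    "Drug": ["drugY", "drugC", "drugC", "drugX", "drugY",
             "drugC", "drugC", "drugX", "drugX", "drugX"],
})
features = ["Sex", "BP", "Age", "Cholesterol", "Na_to_K"]
model = DecisionTreeClassifier(random_state=0).fit(df[features], df["Drug"])

fig, ax = plt.subplots(figsize=(12, 8))
annotations = plot_tree(
    model,
    feature_names=features,             # label each split with its column name
    class_names=list(model.classes_),   # label each node with its majority drug
    filled=True,                        # colour nodes by majority class
    rounded=True,                       # rounded node boxes
    ax=ax,
)
fig.savefig("decision_tree.png", dpi=100)  # hypothetical output file name
```

In a Jupyter notebook you can drop the matplotlib.use line and call plt.show() instead of saving to a file.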
There are many techniques for tuning decision trees and avoiding overfitting, such as pruning. Decision trees are often unstable, meaning that a small change in the data can lead to large changes in the optimal tree structure, yet their simplicity makes them a strong candidate for a wide range of applications. Before neural networks became popular, decision trees were among the state-of-the-art algorithms in machine learning. Ensemble models such as random forests, which combine many trees, are much more powerful than a single vanilla decision tree, but a single tree remains valuable because of its simplicity and interpretability. Decision trees and random forests are widely used in industry for user signup modeling, credit scoring, failure prediction, medical diagnostics, and more. With that, we come to the end of this article. If you missed any of the coding parts, don’t worry: I’ve provided the full code for this article.
Happy Machine Learning!