This is Part 3 of my decision trees series. This time around we are going to code a decision tree in Python. So I’m going to try to make this code as understandable as possible, but if you are not familiar with Object Oriented Programming (OOP) or recursion you might have a tougher time.
To make data handling easier, we are going to be using the wonderful `pandas` package, so if you don't know how to use it I highly recommend you learn by reading their intro to pandas (there are also very good tutorials by Chris Albon, but they are more focused on specific features). I'll just quickly go over the most important points here. In `pandas`, tabular data is stored in a `DataFrame`, and these objects allow us to have named columns and rows and to easily subset data using those names. Ok, so let's create a `DataFrame` from a subset of our iris data.

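Something along these lines works; here I'm hand-typing a few illustrative rows rather than loading the full dataset, so the exact values are placeholders:

```python
import pandas as pd

# A small, hand-typed subset of the iris data (values are illustrative)
df = pd.DataFrame({
    "sepal length": [5.1, 4.9, 7.0, 6.4, 6.3, 5.8],
    "sepal width":  [3.5, 3.0, 3.2, 3.2, 3.3, 2.7],
    "petal length": [1.4, 1.4, 4.7, 4.5, 6.0, 5.1],
    "petal width":  [0.2, 0.2, 1.4, 1.5, 2.5, 1.9],
    "species": ["setosa", "setosa", "versicolor",
                "versicolor", "virginica", "virginica"],
})
print(df.shape)  # (6, 5)
```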

Ok, so we have our `DataFrame`, and now we can select just the species column, for example with:

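In `pandas` that is just indexing by column name; a minimal sketch (with a tiny stand-in `df` so the snippet runs on its own):

```python
import pandas as pd

# a tiny stand-in for the iris DataFrame from above
df = pd.DataFrame({
    "sepal width": [3.5, 3.0, 3.2],
    "species": ["setosa", "setosa", "versicolor"],
})

species = df["species"]  # a pandas Series holding the species column
print(species.tolist())  # ['setosa', 'setosa', 'versicolor']
```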

Or we can get the samples which have \(sepal\ width \leq 3.1\):

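Boolean indexing does the trick; another sketch with a tiny stand-in `df`:

```python
import pandas as pd

df = pd.DataFrame({
    "sepal width": [3.5, 3.0, 3.2],
    "species": ["setosa", "setosa", "versicolor"],
})

# boolean indexing: keep only the rows satisfying the condition
subset = df[df["sepal width"] <= 3.1]
print(subset["species"].tolist())  # ['setosa']
```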

So it is easy to see how this makes understanding what the code is doing easier, and that's why I'm using `DataFrame`s instead of arrays and indices. This is, however, far from all `pandas` can do (duh…), it's just the basics of the basics.
So what is a tree?
A tree is a set of nodes, so we need to define a `Node` object which will store some properties about the node, as well as its left and right children nodes.
What does a node of our decision tree need?
A `pandas` `DataFrame` to represent the training data at this node, and the target feature (in the case of our iris example that would be the “species” feature). We need to have a left and a right child as well, and if the node is a leaf node it needs to be able to give a prediction. This gives us the following `Node` class definition, where we feed the dataframe as well as the target feature upon initialization.

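A minimal sketch of such a class (the attribute names are my own choices, not necessarily the original ones):

```python
class Node:
    def __init__(self, data, target):
        self.data = data        # pandas DataFrame: training data at this node
        self.target = target    # name of the outcome column, e.g. "species"
        self.left = None        # left child Node (None for a leaf)
        self.right = None       # right child Node (None for a leaf)
        self.prediction = None  # set on leaf nodes only
```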

Ok, so now we have our `Node` class, but how do we represent a tree? Where do we keep all the nodes we will need? Well actually, we do not need to create a `Tree` class: since each node stores its left and right children, we can access any and all nodes of the tree from the root node. So the whole tree can just be represented by the root node. For example, if we want to get the versicolor node in our simple tree we can access it from the root node with: `root.right.left`
Ok, so we can just concentrate on the `Node` class. If you remember how the algorithm works (if not, here is the post where I explain it), we are going to need a way to find splits in our data. As we said in previous parts, there are categorical and numerical splits, so we need a way to determine if a feature is categorical or numerical. Fortunately `pandas` has us covered, and we can make a simple function:

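A sketch using `pandas.api.types.is_numeric_dtype` (one of several ways to do this):

```python
import pandas as pd
from pandas.api.types import is_numeric_dtype

def is_categorical(series):
    # a feature is "categorical" for our purposes if its dtype isn't numeric
    return not is_numeric_dtype(series)

df = pd.DataFrame({"outlook": ["sunny", "rainy"], "humidity": [85, 96]})
print(is_categorical(df["outlook"]), is_categorical(df["humidity"]))  # True False
```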

So now that we can distinguish the two, we can write a function that gets all possible splits in our dataset and returns them as a dictionary:

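A self-contained sketch of what `get_splits()` could look like, with the per-type loops written inline (the names and details are my reconstruction):

```python
import pandas as pd
from pandas.api.types import is_numeric_dtype

def get_splits(df, target):
    """Sketch: map (feature, value, kind) -> left-side DataFrame."""
    splits = {}
    for feature in df.columns.drop(target):  # every feature except the outcome
        if is_numeric_dtype(df[feature]):
            # numerical: one split per candidate threshold (all but the max,
            # which would send every row to the left side)
            for v in sorted(df[feature].unique())[:-1]:
                splits[(feature, v, "numerical")] = df[df[feature] <= v]
        else:
            # categorical: one split per single level
            for level in df[feature].unique():
                splits[(feature, level, "categorical")] = df[df[feature] == level]
    return splits

df = pd.DataFrame({"outlook": ["sunny", "rainy", "sunny"],
                   "humidity": [85, 96, 70],
                   "play": ["no", "yes", "yes"]})
print(len(get_splits(df, "play")))  # 2 categorical + 2 numerical = 4
```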

As you can see, it gets all the features in our dataset (except the outcome), loops over them and checks if each feature is categorical or numerical. Depending on the feature type it calls the corresponding splitting method. So how do we get our categorical and numerical splits? For the numerical splits we have this function:

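A sketch of such a function, written standalone here rather than as a method so it runs on its own:

```python
import pandas as pd

def get_numerical_splits(df, feature):
    """Sketch: key = (feature, threshold, "numerical"), value = left-side data."""
    splits = {}
    # every unique value except the largest is a candidate threshold;
    # splitting on the max would keep the entire dataset on the left
    for v in sorted(df[feature].unique())[:-1]:
        splits[(feature, v, "numerical")] = df[df[feature] <= v]
    return splits

df = pd.DataFrame({"sepal width": [3.5, 3.0, 3.2],
                   "species": ["a", "a", "b"]})
splits = get_numerical_splits(df, "sepal width")
print(sorted(k[1] for k in splits))  # [3.0, 3.2]
```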

This returns all possible numerical splits in a dictionary where each key is a `tuple` of the feature name, the value on which the split is done, and the type of split, and the value is the data that goes to the left side of the split (the data that respects the split condition).
For categorical features we are not going to follow exactly what I said in part 2. Indeed, the total number of splits is \(2^{k-1} - 1\), with \(k\) the number of possible values of our feature, and this can get huge very quickly. For example, a categorical feature with 25 levels (25 brands of car, or 25 different languages, whatever…), which can easily be reached in some datasets, would result in \(2^{24} - 1 = 16777215\) splits to evaluate, and that's just for one feature in one node. So this can get out of hand very quickly and slow our program to a crawl. So I'm going to make an executive decision here and say we will only consider splits made by a single level, for example `brand = Ford`, and eliminate all splits made by combinations of levels: `(brand = Ford) or (brand = Chevrolet)`. This brings us back to a nice \(k\) possible splits. So we can add this method to get categorical splits:

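Again a standalone sketch, following the same dictionary convention as the numerical version:

```python
import pandas as pd

def get_categorical_splits(df, feature):
    """Sketch: one split per single level of the feature."""
    splits = {}
    for level in df[feature].unique():
        # left side = the rows where the feature equals this level
        splits[(feature, level, "categorical")] = df[df[feature] == level]
    return splits

df = pd.DataFrame({"outlook": ["sunny", "rainy", "overcast", "sunny"],
                   "play": ["no", "yes", "yes", "no"]})
splits = get_categorical_splits(df, "outlook")
print(len(splits))  # 3 levels -> 3 splits
```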

Ok, all done. Now at a single node we can get all the dataset splits we want to evaluate by calling the `get_splits()` method.
Next we need a way to calculate the impurity of a split, which in our case is going to be the Gini index. As a reminder, the Gini index \(G\) for a node \(t\) is defined as:
$$ G(t) = 1 - \sum^k_{i=1} p_i^2 $$
Where \(p_i\) is the proportion of samples of class \(i\) in the node data, and \(k\) is the number of different classes. So let's add a method to our `Node` class to do just that:

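A sketch of that computation, written standalone here (the `pandas` calls are the important part):

```python
import pandas as pd

def gini(df, target):
    # proportions of each class in the node data
    p = df[target].value_counts(normalize=True)
    return 1 - (p ** 2).sum()

df = pd.DataFrame({"species": ["setosa", "setosa", "versicolor", "versicolor"]})
print(gini(df, "species"))  # 1 - (0.5^2 + 0.5^2) = 0.5
```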

`value_counts()` is a `pandas` method that gets all unique values in a given column and returns their counts; the `normalize` option makes it return proportions instead of counts. The next step is computing the decrease in impurity \(\Delta i\) (see part 2 for the formula).

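A standalone sketch of how \(\Delta i\) could be computed: parent impurity minus the size-weighted impurities of the two children (the function names are my own):

```python
import pandas as pd

def gini(df, target):
    p = df[target].value_counts(normalize=True)
    return 1 - (p ** 2).sum()

def get_delta_i(df, target, subset):
    """Impurity decrease for splitting df into subset (left) and the rest."""
    right = df.drop(subset.index)  # right side = node data minus the left rows
    n = len(df)
    return (gini(df, target)
            - (len(subset) / n) * gini(subset, target)
            - (len(right) / n) * gini(right, target))

df = pd.DataFrame({"x": [1, 2, 3, 4],
                   "species": ["a", "a", "b", "b"]})
left = df[df["x"] <= 2]                  # a perfect split
print(get_delta_i(df, "species", left))  # 0.5 - 0 - 0 = 0.5
```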

Here `subset` is the data in the entries of the dictionary given by `get_splits()`, so it's just the left side of the splits. To get the right side we take the whole data of the node and get rid of all the rows that are in the left split (`subset`).
So now we can get the best split by looping over all possible splits, and returning the one with the highest value of \(\Delta i\):

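A sketch, again standalone; `get_best_split` is my own name for it:

```python
import pandas as pd

def gini(df, target):
    p = df[target].value_counts(normalize=True)
    return 1 - (p ** 2).sum()

def get_delta_i(df, target, subset):
    right = df.drop(subset.index)
    n = len(df)
    return (gini(df, target)
            - (len(subset) / n) * gini(subset, target)
            - (len(right) / n) * gini(right, target))

def get_best_split(df, target, splits):
    """Return the (key, left_data) pair maximizing the impurity decrease."""
    return max(splits.items(),
               key=lambda item: get_delta_i(df, target, item[1]))

df = pd.DataFrame({"x": [1, 2, 3, 4], "species": ["a", "a", "b", "b"]})
splits = {("x", v, "numerical"): df[df["x"] <= v] for v in [1, 2, 3]}
best_key, best_left = get_best_split(df, "species", splits)
print(best_key)  # ('x', 2, 'numerical') -- the perfect split
```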

Ok, so now for the fun part: we are going to build the tree recursively. To do that we choose the best split at our root node, then apply the split function on each of the resulting subtrees, and in turn each of these will also be split… And this needs to keep happening until some conditions are met: the stop condition.
So when do we want to stop splitting the data at a given node? Well, the simplest answer would be to stop when the node is pure (or when there are no more possible splits), so when it contains data points of only one class. However, since, as we said in previous parts, we want to avoid overfitting, we will also add some other stop conditions: a maximum tree depth, and a minimum number of samples in a node below which we stop splitting.
To implement these two stopping options we need to add more parameters to our `Node` class, so let's modify our `__init__` method:

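A sketch of the modified `__init__` (the parameter names and defaults are my guesses):

```python
class Node:
    def __init__(self, data, target, max_depth=3, min_samples=2, level=0):
        self.data = data                # training data at this node
        self.target = target            # outcome column name
        self.max_depth = max_depth      # stop: don't split deeper than this
        self.min_samples = min_samples  # stop: don't split smaller nodes
        self.level = level              # depth of this node in the tree
        self.left = None
        self.right = None
        self.prediction = None
```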

Ok, so here I have just added parameters to determine when we want our tree to stop splitting, as well as a `level` value that is just going to allow us to keep track of the depth of a given node in the tree. Now that we have these parameters, we need to implement methods that allow us to check if any of the stopping conditions are met:

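A sketch of what those checks could look like on a minimal `Node` (method names are my own):

```python
import pandas as pd

class Node:
    def __init__(self, data, target, max_depth=3, min_samples=2, level=0):
        self.data = data
        self.target = target
        self.max_depth = max_depth
        self.min_samples = min_samples
        self.level = level

    def is_pure(self):
        # only one class left in this node's data
        return self.data[self.target].nunique() == 1

    def reached_max_depth(self):
        return self.level >= self.max_depth

    def too_small(self):
        # fewer samples than min_samples: don't split any further
        return len(self.data) < self.min_samples

df = pd.DataFrame({"species": ["setosa", "versicolor"]})
node = Node(df, "species", max_depth=3, min_samples=2, level=3)
print(node.is_pure(), node.reached_max_depth(), node.too_small())  # False True False
```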

So if any of these methods returns true we will stop splitting. Ok, now that we have defined our stop condition, we can write up our recursive splitting method:

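Putting the pieces together, here is a self-contained sketch of the recursive method; all names and details are my reconstruction rather than the post's exact code:

```python
import pandas as pd
from pandas.api.types import is_numeric_dtype

def gini(df, target):
    p = df[target].value_counts(normalize=True)
    return 1 - (p ** 2).sum()

def get_splits(df, target):
    splits = {}
    for feature in df.columns.drop(target):
        if is_numeric_dtype(df[feature]):
            for v in sorted(df[feature].unique())[:-1]:
                splits[(feature, v, "numerical")] = df[df[feature] <= v]
        else:
            for level in df[feature].unique():
                splits[(feature, level, "categorical")] = df[df[feature] == level]
    return splits

class Node:
    def __init__(self, data, target, max_depth=3, min_samples=2, level=0):
        self.data, self.target = data, target
        self.max_depth, self.min_samples, self.level = max_depth, min_samples, level
        self.left = self.right = self.prediction = None

    def should_stop(self):
        return (self.data[self.target].nunique() == 1     # pure node
                or self.level >= self.max_depth           # deep enough
                or len(self.data) < self.min_samples)     # too few samples

    def split(self):
        splits = get_splits(self.data, self.target)
        if self.should_stop() or not splits:
            # leaf: predict the most frequent class in the node
            self.prediction = self.data[self.target].mode()[0]
            return

        def delta_i(subset):
            # impurity decrease for this candidate split
            right = self.data.drop(subset.index)
            n = len(self.data)
            return (gini(self.data, self.target)
                    - len(subset) / n * gini(subset, self.target)
                    - len(right) / n * gini(right, self.target))

        key, left_data = max(splits.items(), key=lambda kv: delta_i(kv[1]))
        right_data = self.data.drop(left_data.index)
        if len(left_data) == 0 or len(right_data) == 0:
            self.prediction = self.data[self.target].mode()[0]
            return
        self.left = Node(left_data, self.target, self.max_depth,
                         self.min_samples, self.level + 1)
        self.right = Node(right_data, self.target, self.max_depth,
                          self.min_samples, self.level + 1)
        self.left.split()   # recurse down both sides
        self.right.split()

df = pd.DataFrame({"petal length": [1.4, 1.3, 4.7, 5.0],
                   "species": ["setosa", "setosa", "versicolor", "versicolor"]})
root = Node(df, "species")
root.split()
print(root.left.prediction, root.right.prediction)  # setosa versicolor
```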

Ok, so it might seem like a long function, but it is actually quite simple: we just keep splitting the data with the best possible split (maximizing \(\Delta i\)), and when one of our stop conditions is met we set the prediction that this node will make: the most frequent class in the node.
All right, we're done with the important bits, so let's test our program out and see what kind of trees we get. To be able to visualize our trees I blatantly ripped off this StackOverflow answer, which gives us super nice trees! And I added a `value` property for my nodes where I put a string describing the split if the node is a split node, and the predicted class if the node is a leaf node.
If we try out our code with the iris data we get:


Hey, that tree looks super familiar: it's the exact same one as in the previous parts, our method worked! How about if we want a deeper tree?


We get a tree that's one level deeper, so everything seems to be working fine. However, our iris dataset only has numerical data (lengths and widths), so we don't really know if our tree building method works with categorical data. To check, I'm going to use the golfing dataset, which has a handful of features, and where the target value is whether a game of golf is played or not. This dataset is very small, so I can show you all of it here:
| id | outlook  | temperature | humidity | windy | play |
|----|----------|-------------|----------|-------|------|
| 1  | sunny    | 85          | 85       | FALSE | no   |
| 2  | sunny    | 80          | 90       | TRUE  | no   |
| 3  | overcast | 83          | 86       | FALSE | yes  |
| 4  | rainy    | 70          | 96       | FALSE | yes  |
| 5  | rainy    | 68          | 80       | FALSE | yes  |
| 6  | rainy    | 65          | 70       | TRUE  | no   |
| 7  | overcast | 64          | 65       | TRUE  | yes  |
| 8  | sunny    | 72          | 95       | FALSE | no   |
| 9  | sunny    | 69          | 70       | FALSE | yes  |
| 10 | rainy    | 75          | 80       | FALSE | yes  |
| 11 | sunny    | 75          | 70       | TRUE  | yes  |
| 12 | overcast | 72          | 90       | TRUE  | yes  |
| 13 | overcast | 81          | 75       | FALSE | yes  |
| 14 | rainy    | 71          | 91       | TRUE  | no   |
That's it, that's the whole dataset, but as you can see we have a nice mix of categorical and numerical data. Ok, so let's see how our CART implementation handles this:


Yay everything works!
You might have noticed that we only have classification trees in this example, and you'd be right. I haven't implemented the regression part yet because I'm too lazy, but it would work exactly the same way: you would need to add an RSS function that you could plug into the `get_delta_i()` method and into the `split()` method, and when a leaf node is reached, set the prediction value to the mean of the dataset outcomes instead of the most frequent one. So I'll put it in eventually, but I won't make a separate post on that. All of the code is on my github so you can play with it if you want.
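Such an RSS function could be as simple as this sketch (the name and signature are my own):

```python
import pandas as pd

def rss(df, target):
    # residual sum of squares around the node mean: the regression
    # analogue of the Gini impurity for measuring node "purity"
    return ((df[target] - df[target].mean()) ** 2).sum()

df = pd.DataFrame({"y": [1.0, 2.0, 3.0]})
print(rss(df, "y"))  # (1-2)^2 + (2-2)^2 + (3-2)^2 = 2.0
```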
One last thing: we haven't implemented the full CART algorithm here, because there is no pruning method to avoid overfitting. That will come in a future part, so stay tuned!