Classification Trees

Introduction

Classification trees are grown by recursively partitioning the predictor space, and candidate splits are evaluated using either the Gini index or the entropy.

The Gini index is defined by \[G = \sum_{k=1}^K \hat{p}_{mk}(1 - \hat{p}_{mk}),\] a measure of total variance across the K classes. Here \(\hat{p}_{mk}\) represents the proportion of training observations in the \(m^{th}\) region that are from the \(k^{th}\) class.

It is not hard to see that the Gini index takes on a small value if all of the \(\hat{p}_{mk}\)'s are close to zero or one. For this reason the Gini index is referred to as a measure of node purity: a small value indicates that a node contains predominantly observations from a single class.

An alternative to the Gini index is entropy, given by

\[ D = - \sum_{k=1}^K \hat{p}_{mk} \log \hat{p}_{mk}\]

Since \(0 \le \hat{p}_{mk} \le 1\), it follows that \(0 \le -\hat{p}_{mk} \log \hat{p}_{mk}\). One can show that the entropy will take on a value near zero if the \(\hat{p}_{mk}\)'s are all near zero or near one. Therefore, like the Gini index, the entropy will take on a small value if the \(m^{th}\) node is pure. In fact, it turns out that the Gini index and the entropy are quite similar numerically.
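To see this numerical similarity concretely, the short sketch below (an illustration, not part of the original analysis) plots both impurity measures for a two-class node as the class proportion varies:

```r
# Gini index and entropy for a two-class node as a function of the class proportion p
p <- seq(0.01, 0.99, by = 0.01)
gini    <- 2 * p * (1 - p)                        # p(1 - p) + (1 - p)p
entropy <- -(p * log(p) + (1 - p) * log(1 - p))

plot(p, entropy, type = "l", xlab = "p", ylab = "impurity")
lines(p, gini, lty = 2)
legend("top", legend = c("entropy", "Gini"), lty = c(1, 2))
```

Both curves are maximized at p = 0.5 and fall to zero as the node becomes pure.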

Fitting Classification Trees

We now fit a classification tree to a data set of children's car seat sales. We will attempt to predict Sales (child car seat sales) at 400 locations based on a number of predictors. We first transform the numeric Sales variable into a factor indicating whether sales were high.
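A minimal sketch of this step, assuming the Carseats data from the ISLR package and the tree package; the cutoff of 8 (thousand units) used to define "high" sales is an assumption, not stated above:

```r
library(ISLR)
library(tree)

# binary response: "Yes" if Sales exceeds 8 (thousand units), "No" otherwise (assumed cutoff)
Carseats$High <- factor(ifelse(Carseats$Sales > 8, "Yes", "No"))

# fit a classification tree using all predictors except Sales itself
tree.carseats <- tree(High ~ . - Sales, data = Carseats)
summary(tree.carseats)
```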

We see that the training error rate is 9%. For classification trees, the deviance reported in the output of summary() is given by

\[-2 \sum_{m} \sum_{k} n_{mk} \log \hat{p}_{mk},\]

where \(n_{mk}\) is the number of training observations in the \(m^{th}\) terminal node that belong to the \(k^{th}\) class.
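As a sketch of how one node contributes to this quantity (the class counts below are made up for illustration):

```r
# deviance contribution of a single terminal node with hypothetical class counts
n_mk <- c(No = 40, Yes = 10)          # made-up counts n_mk for node m
p_mk <- n_mk / sum(n_mk)              # class proportions in the node
-2 * sum(n_mk * log(p_mk))            # this node's contribution to the deviance
```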

Plotting Classification Trees

Now, let’s plot the tree stored in tree.carseats:
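A sketch using base graphics, assuming tree.carseats is the tree fitted above:

```r
plot(tree.carseats)                  # draw the tree structure
text(tree.carseats, pretty = 0)      # add split labels with full category names
```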

Using rpart() instead of tree(), together with the rpart.plot package, we can get a nicer-looking tree output:
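A sketch of the rpart/rpart.plot version; the object name rpart.carseats is only illustrative:

```r
library(rpart)
library(rpart.plot)

rpart.carseats <- rpart(High ~ . - Sales, data = Carseats, method = "class")
rpart.plot(rpart.carseats)
```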

Fitting a Classification Tree using a Test Set

Instead of reporting the in-sample performance (an accuracy of 91%), we should now estimate the out-of-sample performance:
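One way to do this is a simple validation-set split; the seed and the 200/200 split below are assumptions:

```r
set.seed(2)
train <- sample(1:nrow(Carseats), 200)       # half of the 400 stores for training
Carseats.test <- Carseats[-train, ]
High.test <- Carseats.test$High

tree.carseats <- tree(High ~ . - Sales, data = Carseats, subset = train)
tree.pred <- predict(tree.carseats, Carseats.test, type = "class")
table(tree.pred, High.test)                  # confusion matrix on the test set
mean(tree.pred == High.test)                 # out-of-sample accuracy
```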

Fitting a Classification Tree using Pruning and Cross Validation

Next, we consider whether pruning the tree might lead to improved results. We use the argument FUN=prune.misclass in order to indicate that we want the classification error rate to guide the cross-validation and pruning process, rather than the default for the cv.tree() function, which is deviance.
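A sketch of the cross-validation step, assuming the training-set tree fitted above; the seed is arbitrary:

```r
set.seed(3)
cv.carseats <- cv.tree(tree.carseats, FUN = prune.misclass)
cv.carseats    # $dev holds the cross-validated error counts for each tree size

# CV error as a function of tree size and of the cost-complexity parameter k
par(mfrow = c(1, 2))
plot(cv.carseats$size, cv.carseats$dev, type = "b")
plot(cv.carseats$k, cv.carseats$dev, type = "b")
```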

Apply Pruning

We now apply the prune.misclass() function in order to prune the tree and obtain the nine-node tree.
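A sketch, assuming nine terminal nodes gave the lowest cross-validated error above:

```r
prune.carseats <- prune.misclass(tree.carseats, best = 9)
plot(prune.carseats)
text(prune.carseats, pretty = 0)
```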

Evaluate Pruned Tree

How well does this pruned tree perform on the test data set? Once again, we apply the predict() function.
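A sketch, reusing the held-out Carseats.test and High.test from the validation split:

```r
tree.pred <- predict(prune.carseats, Carseats.test, type = "class")
table(tree.pred, High.test)
mean(tree.pred == High.test)     # accuracy of the pruned tree on the test set
```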

Increase the Value of best

Let's increase the value of best to obtain a larger pruned tree. Did we increase the classification accuracy?
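A sketch with an arbitrary larger value of best (15 is just an example):

```r
prune.carseats <- prune.misclass(tree.carseats, best = 15)
plot(prune.carseats)
text(prune.carseats, pretty = 0)

tree.pred <- predict(prune.carseats, Carseats.test, type = "class")
table(tree.pred, High.test)
mean(tree.pred == High.test)     # compare with the nine-node tree's accuracy
```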