
I have a categorical dependent variable that I want to model with a decision tree. It is composed of three categories with frequencies 738 (19%), 426 (15%) and 1800 (66%). As you can imagine, the predicted category is always the third one, but the purpose of the tree is descriptive, so this does not actually matter. The thing is, when plotting a tree fitted with the ctree() function (package partykit), the terminal nodes display bar plots showing the probability of occurrence of the three classes. I need to modify this output: for each class, I would like to obtain the proportion of that class's observations that falls into a terminal node, i.e. relative to the class's absolute frequency rather than to the node size. For example, what percentage of the 738 participants in class 1 belongs to a given terminal node? Each terminal node would display these values for all three classes of the dependent variable.
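
In other words (notation introduced here only for illustration): let $n_{tk}$ be the number of observations of class $k$ in terminal node $t$, $n_{t\cdot} = \sum_k n_{tk}$ the size of node $t$, and $n_{\cdot k} = \sum_t n_{tk}$ the overall size of class $k$ (e.g. 738 for class 1). The default plot shows the first ratio below; what is wanted here is the second, which for each class sums to 1 over the terminal nodes.

```latex
\hat{p}(k \mid t) = \frac{n_{tk}}{n_{t\cdot}}
\qquad\text{vs.}\qquad
\hat{p}(t \mid k) = \frac{n_{tk}}{n_{\cdot k}},
\qquad \sum_{t} \hat{p}(t \mid k) = 1 .
```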

Below is a plot of the tree, which by default reports the prevalence of each class within the terminal nodes.

[Image: default ctree() plot of the tree]

Answer


You can always define your own panel function to draw what goes into each terminal panel window. If you know a little bit about grid graphics and look at how the current terminal panel functions are defined, you will see how this works.
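
As a minimal sketch of that pattern, assuming partykit's convention (followed by node_terminal() and the other built-in panels) that a panel is a "generator": a function of the fitted tree that returns the function drawing a single node, tagged with class "grapcon_generator". The name node_nodeid is hypothetical, and the panel only prints the node ID in a shaded box.

```r
library("grid")

## Hypothetical panel generator: called once with the fitted tree (obj),
## it returns the function that plot() then calls for each terminal node.
node_nodeid <- function(obj, fill = "lightgray", ...) {
  function(node) {
    ## Draw into the current viewport, which plot() is assumed to have
    ## set up for this terminal node (as it does for the built-in panels).
    grid.rect(gp = gpar(fill = fill))
    grid.text(paste("Node", partykit::id_node(node)))
  }
}
## Tag it so that plot() first calls the generator with the tree.
class(node_nodeid) <- "grapcon_generator"
```

For a fitted ctree object it would be used as plot(ct, terminal_panel = node_nodeid, tp_args = list(fill = "white")), where tp_args is passed on to the generator.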

One panel function that ought to do what you want is node_terminal() in the partykit package (the much improved re-implementation of the old party package). However, because ctree() does not store its predictions in each terminal node, node_terminal() cannot do this out of the box at the moment. I'll try to improve the implementation in future versions so that this becomes easier. Below is a somewhat involved example that, I hope, does what you want.

First, we fit a classification tree using the iris data (for a simple reproducible example):

library("partykit")
(ct <- ctree(Species ~ ., data = iris))
## Model formula:
## Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
## 
## Fitted party:
## [1] root
## |   [2] Petal.Length <= 1.9: setosa (n = 50, err = 0.0%)
## |   [3] Petal.Length > 1.9
## |   |   [4] Petal.Width <= 1.7
## |   |   |   [5] Petal.Length <= 4.8: versicolor (n = 46, err = 2.2%)
## |   |   |   [6] Petal.Length > 4.8: versicolor (n = 8, err = 50.0%)
## |   |   [7] Petal.Width > 1.7: virginica (n = 46, err = 2.2%)
## 
## Number of inner nodes:    3
## Number of terminal nodes: 4

Then we compute a table of predicted probabilities for each terminal node:

(pred <- aggregate(predict(ct, type = "prob"),
  list(predict(ct, type = "node")), FUN = mean))
##   Group.1 setosa versicolor  virginica
## 1       2      1 0.00000000 0.00000000
## 2       5      0 0.97826087 0.02173913
## 3       6      0 0.50000000 0.50000000
## 4       7      0 0.02173913 0.97826087
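
As a cross-check, the same numbers can be obtained by cross-tabulating terminal-node membership against the observed class. Normalizing each row (margin = 1) reproduces the within-node proportions in pred, while normalizing each column (margin = 2) gives, for each class, the share of its observations that falls into each terminal node, i.e. the quantity asked for in the question. With partykit the node IDs would come from predict(ct, type = "node"); to keep this sketch self-contained, it instead recomputes them from the split rules printed above, which only works for this particular iris tree.

```r
## Terminal-node IDs recomputed from the printed split rules of the iris
## tree (assumption); with partykit, use predict(ct, type = "node") instead.
node <- with(iris, ifelse(Petal.Length <= 1.9, 2,
                   ifelse(Petal.Width > 1.7, 7,
                   ifelse(Petal.Length <= 4.8, 5, 6))))

## Counts of each class (columns) in each terminal node (rows).
tab <- table(node = node, class = iris$Species)

## Within-node proportions: each row sums to 1 (matches `pred` above).
within_node <- prop.table(tab, margin = 1)

## Share of each class falling into each node: each column sums to 1.
share_of_class <- prop.table(tab, margin = 2)
round(share_of_class, 3)
```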

Then comes the not-so-obvious part: we want to include these predicted probabilities in the terminal nodes of the tree itself. To do so, we coerce the recursive node structure to a flat list, insert the predictions (suitably formatted), and convert the list back to the node structure:

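# flatten the recursive node structure into a list (one element per node ID)
ct_node <- as.list(ct$node)
# store a formatted "class: probability" string in each terminal node's info
for(i in 1:nrow(pred)) {
  ct_node[[pred[i,1]]]$info$prediction <- paste(
    format(names(pred)[-1]),
    format(round(pred[i, -1], digits = 3), nsmall = 3)
  )
}
# rebuild the recursive node structure and put it back into the tree
ct$node <- as.partynode(ct_node)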

Then we can easily draw the tree with the node_terminal panel function, inserting our pre-formatted predictions:

plot(ct, terminal_panel = node_terminal, tp_args = list(
  FUN = function(node) c("Predictions", node$prediction)))

[Image: custom tree]

EDIT: Coercing back and forth between a list and a party is actually already implemented in the package... I just forgot about it ;-) If you do

st <- as.simpleparty(ct)

then each node of the resulting party contains more detailed information about the predictions, etc. For example, $distribution contains the absolute frequencies for each response level. These can easily be formatted as before:

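pred <- function(i) {
  tab <- i$distribution                 # absolute class frequencies in the node
  tab <- round(prop.table(tab), 3)      # convert to within-node proportions
  tab <- paste0(names(tab), ":", format(tab, nsmall = 3))
  c("Predictions", tab)                 # header line plus one line per class
}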

This can then be passed to node_terminal to create essentially the same plot as above. If you want all terminal nodes to be displayed in the bottom row, you might want to set drop_terminal = TRUE in the plot() call.

plot(st, terminal_panel = node_terminal, tp_args = list(FUN = pred))
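
Note that pred() reports within-node proportions, i.e. the same quantities as the default plot. For the class-wise shares asked for in the question, one possible variant (a sketch; the names pred_share and tot are hypothetical) divides each node's counts by the overall size of the corresponding class instead. Here the class totals are taken from the iris data; for the question's data they would be the 738, 426 and 1800 observations of the three classes.

```r
## Class totals over the whole learning sample (assumption: iris here).
tot <- table(iris$Species)

## Hypothetical variant of pred(): for each class, the share of that
## class's observations that ended up in this terminal node.
pred_share <- function(i) {
  tab <- i$distribution                               # counts in this node
  share <- as.numeric(tab) / as.numeric(tot[names(tab)])
  share <- round(share, 3)
  c("Share of class", paste0(names(tab), ":", format(share, nsmall = 3)))
}
```

It would be plugged in exactly like pred(), e.g. plot(st, terminal_panel = node_terminal, tp_args = list(FUN = pred_share)); for each class, the displayed values then sum to 1 across the terminal nodes.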