You can always define your own panel function to draw what goes into each terminal panel window. If you know a little bit about grid graphics and look at how the current terminal panel functions are defined, you will see how this works.
One panel function that ought to do what you want is node_terminal() in the partykit package (the much improved re-implementation of the old party package). However, because ctree() does not store its predictions in each terminal node, node_terminal() cannot do this out of the box at the moment. I'll try to improve the implementation in future versions so that this becomes easier. Below is a somewhat involved example that should hopefully do what you want.
First, we fit a classification tree using the iris data (for a simple reproducible example):
library("partykit")
(ct <- ctree(Species ~ ., data = iris))
## Model formula:
## Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
##
## Fitted party:
## [1] root
## |   [2] Petal.Length <= 1.9: setosa (n = 50, err = 0.0%)
## |   [3] Petal.Length > 1.9
## |   |   [4] Petal.Width <= 1.7
## |   |   |   [5] Petal.Length <= 4.8: versicolor (n = 46, err = 2.2%)
## |   |   |   [6] Petal.Length > 4.8: versicolor (n = 8, err = 50.0%)
## |   |   [7] Petal.Width > 1.7: virginica (n = 46, err = 2.2%)
##
## Number of inner nodes: 3
## Number of terminal nodes: 4
Then we compute a table of predicted probabilities for each terminal node:
(pred <- aggregate(predict(ct, type = "prob"),
  list(predict(ct, type = "node")), FUN = mean))
##   Group.1 setosa versicolor  virginica
## 1       2      1 0.00000000 0.00000000
## 2       5      0 0.97826087 0.02173913
## 3       6      0 0.50000000 0.50000000
## 4       7      0 0.02173913 0.97826087
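As a cross-check of the table above, here is a minimal sketch that computes the same per-node class proportions directly from the observed responses. It assumes the tree is fitted without case weights, so that the in-sample predicted probabilities are simply the relative class frequencies within each terminal node; the object names node and tab are only illustrative.

```r
library("partykit")

# Refit the tree so the snippet runs on its own
ct <- ctree(Species ~ ., data = iris)

# Terminal node id for every observation in the learning sample
node <- predict(ct, type = "node")

# Cross-tabulate node id against the observed class and convert
# each row to proportions (margin = 1 normalises within nodes)
tab <- prop.table(table(node, iris$Species), margin = 1)
round(tab, 3)
```

Each row of tab belongs to one terminal node and sums to 1, so it can be compared row by row with the aggregate() result.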
Then comes the not-so-obvious part: we want to include these predicted probabilities in the terminal nodes of the tree itself. For this, we coerce the recursive node structure to a flat list, insert the predictions (suitably formatted), and convert the list back to the node structure:
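# flatten the recursive node structure into a list indexed by node id
ct_node <- as.list(ct$node)
for(i in 1:nrow(pred)) {
  # pred[i, 1] is the terminal node id, pred[i, -1] its class probabilities
  ct_node[[pred[i,1]]]$info$prediction <- paste(
    format(names(pred)[-1]),
    format(round(pred[i, -1], digits = 3), nsmall = 3)
  )
}
# rebuild the recursive node structure from the modified list
ct$node <- as.partynode(ct_node)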
Then we can easily draw a picture of the tree by using the node_terminal panel function and inserting our pre-formatted predictions:
plot(ct, terminal_panel = node_terminal, tp_args = list(
  FUN = function(node) c("Predictions", node$prediction)))
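If you need this for more than one tree, the steps above can be wrapped into a small function. The following is only a sketch: add_node_predictions() is a hypothetical helper, not something partykit provides, and it assumes a classification tree for which predict(..., type = "prob") returns one column per class.

```r
library("partykit")

# Hypothetical helper (not part of partykit): stores formatted class
# probabilities in the info slot of every terminal node.
add_node_predictions <- function(tree, digits = 3) {
  prob <- predict(tree, type = "prob")
  node <- predict(tree, type = "node")
  # one row per terminal node, one column per class
  pred <- aggregate(prob, list(node = node), FUN = mean)
  nodes <- as.list(tree$node)
  for (i in seq_len(nrow(pred))) {
    p <- unlist(pred[i, -1])
    nodes[[pred$node[i]]]$info$prediction <-
      paste(format(names(p)), format(round(p, digits), nsmall = digits))
  }
  tree$node <- as.partynode(nodes)
  tree
}

ct2 <- add_node_predictions(ctree(Species ~ ., data = iris))

# Inspect what was stored, without drawing anything
stored <- nodeapply(ct2, ids = nodeids(ct2, terminal = TRUE),
  FUN = function(n) n$info$prediction)
stored
```

The nodeapply() call at the end is a quick way to confirm that every terminal node carries a prediction entry before passing the tree to plot().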
EDIT: The coercing back and forth between a list and a party is actually already implemented in the package... I just forgot about it ;-) If you do
st <- as.simpleparty(ct)
then the resulting party has more detailed information about the predictions etc. in each node. For example, $distribution then contains the absolute frequencies for each response level. This can easily be formatted as before:
pred <- function(i) {
  # i is the info list of a node; $distribution holds absolute frequencies
  tab <- i$distribution
  tab <- round(prop.table(tab), 3)
  tab <- paste0(names(tab), ":", format(tab, nsmall = 3))
  c("Predictions", tab)
}
And this can be passed to node_terminal to essentially create the plot above. You might want to change drop_terminal = FALSE (the default for this plot() call) to drop_terminal = TRUE if you want all terminal nodes to be displayed in the bottom row.
plot(st, terminal_panel = node_terminal, tp_args = list(FUN = pred))
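To see what pred() receives and returns without drawing the tree, the info of the terminal nodes can be inspected directly. This is a minimal sketch that repeats the definitions from above so that it runs on its own; term_ids and labels are just illustrative names.

```r
library("partykit")

ct <- ctree(Species ~ ., data = iris)
st <- as.simpleparty(ct)

# Same formatter as above; i is the info list of one node
pred <- function(i) {
  tab <- i$distribution
  tab <- round(prop.table(tab), 3)
  tab <- paste0(names(tab), ":", format(tab, nsmall = 3))
  c("Predictions", tab)
}

term_ids <- nodeids(st, terminal = TRUE)

# Absolute class frequencies stored in each terminal node
nodeapply(st, ids = term_ids,
  FUN = function(n) info_node(n)$distribution)

# Text labels that the terminal panels would display
labels <- nodeapply(st, ids = term_ids,
  FUN = function(n) pred(info_node(n)))
labels
```

Checking the labels this way makes it easier to adjust the rounding or the separator before settling on the final plot.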