I've used rpart to build a decision tree on a dataset that has categorical variables with hundreds of levels. The tree splits these variables on selected subsets of their levels. I would like to examine the labels on which each split is made.

If I just print the fitted tree, the listing of splits in the console gets truncated, and in any case it is not in an easily interpretable format (the levels are run together, separated by commas). Is there a way to access this as an R object? I'm open to using another package to build the tree.

If you haven't already, I'd start by searching through the results of str(myTree). - lmo
There is almost certainly a way to get at the information which is being printed in truncated form. Looking at the manual is probably a good way to find out how. - John Coleman
Try the rpart.plot package. - knb
The rpart.utils package's rpart.subrules.table function will help you out. - phiver

1 Answer

One issue here is that some of the functions in the rpart package are not exported. It appears you're looking to capture the output of the function rpart:::print.rpart. So, beginning with a reproducible example:

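library(rpart)

set.seed(1)
# toy data: binary response and two binary predictors
df1 <- data.frame(y=rbinom(n=100, size=1, prob=0.5),
                  x1=rbinom(n=100, size=1, prob=0.25),
                  x2=rbinom(n=100, size=1, prob=0.75))
# outer parentheses print the fitted tree
(r1 <- rpart(y ~ ., data=df1))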

giving

n= 100 

node), split, n, deviance, yval
      * denotes terminal node

1) root 100 24.960000 0.4800000  
  2) x1< 0.5 78 19.179490 0.4358974  
    4) x2>=0.5 66 15.954550 0.4090909 *
    5) x2< 0.5 12  2.916667 0.5833333 *
  3) x1>=0.5 22  5.090909 0.6363636  
    6) x2< 0.5 7  1.714286 0.4285714 *
    7) x2>=0.5 15  2.933333 0.7333333 *

Now, looking at rpart:::print.rpart, we see a call to rpart:::labels.rpart, which gives us the splits (the names of the 'rows' in the output above). The values of n, deviance, yval and more are stored in r1$frame, which can be seen by inspecting the output from unclass(r1).

Thus we could extract the above with

(df2 <- data.frame(split=rpart:::labels.rpart(r1), n=r1$frame$n))

giving

    split   n
1    root 100
2 x1< 0.5  78
3 x2>=0.5  66
4 x2< 0.5  12
5 x1>=0.5  22
6 x2< 0.5   7
7 x2>=0.5  15
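
Since the question is about factor predictors with many levels, here is a minimal sketch of the same idea for a categorical split. The data frame toy and the variable grp are made up for illustration. It relies on the minlength argument of the labels method for rpart objects: minlength = 0 switches off the abbreviation of factor levels, so each label comes back as "variable=level1,level2,...". Splitting on commas assumes that no level name itself contains a comma.

```r
library(rpart)

set.seed(1)
# hypothetical data: one factor predictor with 10 levels
toy <- data.frame(grp = factor(sample(letters[1:10], 200, replace = TRUE)))
# response is roughly 1 for levels a, b, c and roughly 0 otherwise
toy$y <- as.numeric(toy$grp %in% c("a", "b", "c")) + rnorm(200, sd = 0.1)

fit <- rpart(y ~ grp, data = toy)

# one label per row of fit$frame; minlength = 0 keeps full level names
labs <- labels(fit, minlength = 0)

# categorical labels look like "grp=a,b,c"; numeric ones use "<" or ">="
is_cat <- grepl("^[^<>=]+=", labs)

# drop the "variable=" prefix and split the rest into a character vector
levels_by_node <- lapply(sub("^[^<>=]+=", "", labs[is_cat]),
                         function(s) strsplit(s, ",", fixed = TRUE)[[1]])
# name each element by the node number it leads into
names(levels_by_node) <- rownames(fit$frame)[is_cat]

levels_by_node
```

Each element of levels_by_node is then an ordinary character vector of the levels sent down the branch into that node, so it can be inspected with length(), %in% and so on, without any console truncation.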