
I'm taking my first steps in machine learning in general, and with R in particular. I've used Python's sklearn before, but I'm completely new to R. For a university project, I'm trying a random forest on a gene expression dataset for educational purposes. I'm trying to predict mental disorders (bipolar disorder, depression or schizophrenia) using the gene expression of various brain cells. My script currently looks like this:

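library(randomForest)

# Randomly select 75% of the samples for training
train_ind <- sample.int(n = nrow(GSMdata),
                        size = floor(0.75 * nrow(GSMdata)),
                        replace = FALSE)
# Samples are the columns of `data`, so split by column
RFtrainSet <- data[, train_ind]
RFtestSet <- data[, -train_ind]
RFtrainLabel <- GSMdata$Disease_State[train_ind]
RFtestLabel <- GSMdata$Disease_State[-train_ind]

# randomForest expects one sample per row, hence t()
RFmodel <- randomForest(x = t(RFtrainSet),
                        y = RFtrainLabel,
                        ntree = 100)

# Confusion matrix: true test labels (rows) vs. predicted labels (columns)
table(RFtestLabel, predict(object = RFmodel,
                           newdata = t(RFtestSet)))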

Here data is a large matrix in which each column holds the gene expression values of one sample, and GSMdata is a data frame describing the features of each sample (one row per sample). The output of the table function looks like this:

RFtestLabel                 bipolar disorder control major depressive disorder schizophrenia
  bipolar disorder                         0       7                         6             7
  control                                  0       7                         6             0
  major depressive disorder                0       5                         2             2
  schizophrenia                            0       1                         7             2

Frequently, when I sample the data, a class does not appear in the test set, as you can see in the sample above. Is this a problem? If so, is there a function that helps me obtain test sets with homogeneous class proportions?

Data Example

data matrix:

          GSM1304852  GSM1304853  GSM1304854 GSM1304855 GSM1304856
1007_s_at  2.3945368  2.27518369  2.16116298  1.9641833  2.1322526
1053_at    0.1051084  0.06160802  0.34217618  0.3593916  0.2235696
117_at    -0.4597124 -0.52310349 -0.44360591 -0.6370277 -0.3511470
121_at     0.9333566  1.13180904  0.99756999  1.0079778  0.9720455
1255_g_at -0.2399138  0.10112324 -0.04087979 -0.2185137 -0.2991786

GSMdata example:

                   title geo_accession Age    Disease_State Gender  pH  PMI Race RIN      tissue
GSM1304852 bipolar_hip_10    GSM1304852  52 bipolar disorder      M 6.7 23.5    W 6.3 hippocampus
GSM1304853 bipolar_hip_11    GSM1304853  50 bipolar disorder      F 6.4 11.7    W 6.8 hippocampus
GSM1304854 bipolar_hip_12    GSM1304854  28 bipolar disorder      F 6.3 22.3    W 7.7 hippocampus
GSM1304855 bipolar_hip_13    GSM1304855  55 bipolar disorder      F 6.4 17.5    W 7.6 hippocampus
GSM1304856 bipolar_hip_14    GSM1304856  58 bipolar disorder      M 6.8 27.7    W 7.0 hippocampus
I would have thought the bigger problem would be where the test set has classes that are not in the training set. You could try sampling within classes. - IRTFM
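Following the suggestion to sample within classes, here is a minimal base R sketch adapted to the question's layout (samples in columns, one label per sample). The objects expr and disease_state, the helper stratified_train_idx(), and the class sizes are hypothetical stand-ins for data and GSMdata$Disease_State, not the real dataset.

```r
set.seed(42)

# Hypothetical stand-ins for the question's objects: `expr` plays the role of
# `data` (genes in rows, samples in columns) and `disease_state` the role of
# GSMdata$Disease_State (one label per sample).
classes <- c("bipolar disorder", "control",
             "major depressive disorder", "schizophrenia")
disease_state <- factor(rep(classes, times = c(20, 20, 12, 12)))
expr <- matrix(rnorm(50 * length(disease_state)), nrow = 50,
               dimnames = list(paste0("gene", 1:50),
                               paste0("GSM", seq_along(disease_state))))

# Hypothetical helper: return training indices drawn separately within each class
stratified_train_idx <- function(y, frac = 0.75) {
  idx_by_class <- split(seq_along(y), y)
  unlist(lapply(idx_by_class, function(i) {
    # index into i instead of calling sample(i) directly, which misbehaves
    # when a class has a single member
    i[sample.int(length(i), size = floor(frac * length(i)))]
  }), use.names = FALSE)
}

train_ind <- stratified_train_idx(disease_state)
train_set <- expr[, train_ind]   # samples are columns, so split by column
test_set  <- expr[, -train_ind]

# Each class is now present in both splits in roughly equal proportions
table(disease_state[train_ind])
table(disease_state[-train_ind])
```

With these toy class sizes (20, 20, 12 and 12 samples), the split gives 15/5, 15/5, 9/3 and 9/3 train/test samples per class regardless of the seed, because only which samples are drawn is random, not how many.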

2 Answers


Here's a quick dplyr solution to sample within each class; no special function is needed. I'm using the iris dataset as an example, but you can quickly adapt it to your data.

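library(dplyr)
data(iris)
labels <- iris %>%
    mutate(row_id = row_number()) %>%   # keep each row's original position
    dplyr::select(row_id, Species) %>%
    sample_frac(1) %>%                  # shuffle all rows
    group_by(Species) %>%
    # within each class, label rows train, train, train, test, ... (3:1 split)
    mutate(set = rep(c(rep("train", 3), "test"), length.out = n())) %>%
    ungroup()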

table(labels$Species, labels$set)

             test train
  setosa       12    38
  versicolor   12    38
  virginica    12    38

I would also recommend the ranger package for random forests, as it is faster than randomForest.
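A sketch of how the two suggestions could fit together, assuming the ranger package is installed. To stay self-contained it repeats the 3:1 within-class split in base R, keeping the original row indices; the seed and num.trees = 100 (mirroring the question's ntree = 100) are arbitrary choices.

```r
library(ranger)

set.seed(1)
data(iris)

# Within-class 3:1 train/test assignment, keeping original row indices
train_idx <- unlist(lapply(split(seq_len(nrow(iris)), iris$Species), function(i) {
  i <- i[sample.int(length(i))]                                # shuffle within class
  i[rep(c(TRUE, TRUE, TRUE, FALSE), length.out = length(i))]   # keep 3 of every 4
}), use.names = FALSE)

train <- iris[train_idx, ]
test  <- iris[-train_idx, ]

# Species is a factor, so ranger grows a classification forest
rf <- ranger(Species ~ ., data = train, num.trees = 100)

# predict.ranger takes `data` (not `newdata`); predictions are in $predictions
pred <- predict(rf, data = test)$predictions
table(observed = test$Species, predicted = pred)
```

Note that ranger's predict method uses a data argument rather than the newdata argument used by randomForest.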


One way to do this is to use stratified() from the splitstackshape package together with sqldf (to run SQL queries), as below:

set.seed(1231)
data(iris)

data <- iris
# Add a unique row ID. iris contains duplicate rows, and SQL's EXCEPT compares
# whole rows (returning only distinct ones), so without an ID some duplicated
# observations could be silently dropped from the test set.
data$ID <- seq.int(nrow(data))

# Stratified training sample: 0.5 = fraction of each class used for training
m_trn <- data.frame(splitstackshape::stratified(data, "Species", 0.5))
# Test set: every row of `data` that is not in the training set
m_tst <- sqldf::sqldf('SELECT * FROM data EXCEPT SELECT * FROM m_trn')
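If you would rather not use SQL, the test set can also be taken in base R by excluding the training IDs. A sketch, assuming splitstackshape is installed; because rows are matched by their unique ID rather than by their values, duplicate rows in iris cannot be lost.

```r
set.seed(1231)
data(iris)

data <- iris
data$ID <- seq_len(nrow(data))   # unique identifier for every row

# Stratified 50% training sample per class (splitstackshape)
m_trn <- data.frame(splitstackshape::stratified(data, "Species", 0.5))

# Test set: all rows whose ID was not drawn into the training set
m_tst <- data[!data$ID %in% m_trn$ID, ]

table(m_trn$Species)
table(m_tst$Species)
```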