How can I measure the prediction performance of an mlr3 model on a test dataset? For example, say I create a knn model using mlr3 like so:

library("mlr3")
library("mlr3learners")
 
# get data and split into training and test
aq <- na.omit(airquality)
train <- sample(nrow(aq), round(.7*nrow(aq))) # split 70-30
aqTrain <- aq[train, ]
aqTest <- aq[-train, ]


# create model
aqT <- TaskRegr$new(id = "knn", backend = aqTrain, target = "Ozone")
aqL <- lrn("regr.kknn")
aqMod <- aqL$train(aqT)

I can measure the mean squared error of the model's predictions on the training data with something like:

prediction <- aqL$predict(aqT)
measure <- msr("regr.mse")
prediction$score(measure)

But how do I incorporate the test data into this? That is, how do I measure the performance of predictions on the test data?

In the previous version of mlr I could get predictions for the test dataset and then measure performance with, say, the MSE or R-squared, like so:

pred <- predict(aqMod, newdata = aqTest)
performance(pred, measures = list(mse, rsq))

Any suggestions as to how I can do this in mlr3?

You can do this in much the same way in mlr3, see the relevant chapter of the mlr3 book. – Lars Kotthoff
That chapter of the book tells me I should do something like: `prediction = aqL$predict(aqT, row_ids = testRows)`, where testRows are the indices of the rows in the training data. However, this produces an error: Error: DataBackend did not return the queried rows correctly: 33 requested, 17 received – Electrino
You should also be able to use newdata in the same way as for mlr. – Lars Kotthoff
There are a lot of examples in the linked manual. Usually you want to use cross-validation for unbiased performance estimates. – pat-s
You either need to pass the complete task and work with row IDs for train/test, or split the data beforehand and work with newdata. – Michel
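
To illustrate the first option from the last comment (pass the complete task and work with row IDs), here is a minimal sketch. It assumes the task is built on the full `aq` data rather than on `aqTrain` only. The error quoted above looks consistent with asking a training-only task for row IDs it does not contain, though that is a guess. The seed and the names `task`, `train_ids`, `test_ids` and `score` are illustrative, not from the original post.

```r
library("mlr3")
library("mlr3learners")  # provides regr.kknn (requires the kknn package)

set.seed(1)  # illustrative seed, only to make the split reproducible

aq <- na.omit(airquality)

# Build the task on the complete data so it knows about every row
task <- TaskRegr$new(id = "aq", backend = aq, target = "Ozone")

# Split the task's own row IDs 70-30
train_ids <- sample(task$row_ids, round(0.7 * task$nrow))
test_ids <- setdiff(task$row_ids, train_ids)

# Train on the training rows only
learner <- lrn("regr.kknn")
learner$train(task, row_ids = train_ids)

# Predict on the held-out rows only, then score
pred <- learner$predict(task, row_ids = test_ids)
score <- pred$score(msr("regr.mse"))
print(score)
```

Because the row IDs come from `task$row_ids`, every ID requested at prediction time is guaranteed to exist in the task's backend.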

1 Answer

You should try this code, which predicts on the held-out test data and scores the result:

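# predict on the test data frame (it must contain the target column to be scored)
pred <- aqMod$predict_newdata(aqTest)
# score the test-set predictions with MSE and RMSE
pred$score(list(msr("regr.mse"), msr("regr.rmse")))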
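
For anyone who wants to check what the score means, here is a self-contained sketch of the same `predict_newdata` route. It repeats the split from the question, scores with MSE and R-squared (mirroring `list(mse, rsq)` from the old mlr call), and recomputes the MSE by hand from the prediction object. The seed and the names `scores` and `mse_manual` are my own additions for illustration.

```r
library("mlr3")
library("mlr3learners")  # provides regr.kknn (requires the kknn package)

set.seed(1)  # illustrative seed, only to make the split reproducible

# Split the data beforehand, as in the question
aq <- na.omit(airquality)
train <- sample(nrow(aq), round(0.7 * nrow(aq)))
aqTrain <- aq[train, ]
aqTest <- aq[-train, ]

# Train on the training data only
aqT <- TaskRegr$new(id = "knn", backend = aqTrain, target = "Ozone")
aqL <- lrn("regr.kknn")
aqL$train(aqT)

# Predict on the unseen test rows and score with MSE and R-squared
pred <- aqL$predict_newdata(aqTest)
scores <- pred$score(msrs(c("regr.mse", "regr.rsq")))
print(scores)

# Sanity check: MSE is the mean squared difference between truth and response
mse_manual <- mean((pred$truth - pred$response)^2)
print(all.equal(unname(scores[["regr.mse"]]), mse_manual))
```

Keep in mind that a single 70-30 split gives a fairly noisy estimate on a dataset this small, which is why the comments recommend cross-validation.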