We'll start by loading the packages.
library(caret)
library(caretEnsemble)
To implement stacking with the caretEnsemble package, we need a binary classification dataset. In the caretEnsemble version used for this post, multiclass problems do not appear to be implemented yet. If your dataset has more than two classes, you will get the following error:
Error in check_caretList_model_types(list_of_models) :
Not yet implemented for multiclass problems
We'll therefore build a binary dataset from iris by dropping the setosa species, then split it into training (90%) and test (10%) sets.
ir = iris[iris$Species != "setosa", ]
ir$Species = factor(ir$Species)
indexes = createDataPartition(ir$Species, p = .90, list = FALSE)
train = ir[indexes, ]
test = ir[-indexes, ]
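Why the factor() call matters: subsetting a data frame keeps every level of a factor column, even levels that no longer have any rows. Without re-factoring, Species would still carry three levels and look like a multiclass target. The base-R sketch below shows this; droplevels() would work equally well in place of factor() here.

```r
# Subsetting keeps unused factor levels; re-factoring drops them.
ir <- iris[iris$Species != "setosa", ]
print(levels(ir$Species))          # still lists "setosa" (with zero rows)

ir$Species <- factor(ir$Species)   # same effect as droplevels(ir$Species)
print(levels(ir$Species))          # "versicolor" "virginica"
print(table(ir$Species))           # 50 rows of each class
```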
Next, we'll select the base models for the stack: gradient boosting (gbm), a decision tree (rpart), and a random forest (rf).
methods = c("gbm", "rpart", "rf")
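We'll define the resampling scheme with trainControl(): 5-fold cross-validation repeated 3 times. Setting classProbs = TRUE makes caret compute class probabilities, which the stacking model uses as its inputs.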
tc = trainControl(method = "repeatedcv", number = 5, repeats = 3, classProbs = TRUE)
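With method = "repeatedcv", number = 5 and repeats = 3, each model is evaluated on 5 × 3 = 15 resamples, which is why the resamples() summary reports 15. The base-R sketch below is an illustration only, not caret's own implementation (caret also stratifies the folds by class, which this sketch does not). For the 90 training rows, each held-out fold has 18 rows, so every fit sees 72.

```r
# Illustrative repeated k-fold CV index generation (base R only).
repeated_kfold <- function(n, k = 5, repeats = 3) {
  splits <- list()
  for (r in seq_len(repeats)) {
    fold_id <- sample(rep(seq_len(k), length.out = n))  # random fold labels
    for (f in seq_len(k)) {
      name <- sprintf("Fold%d.Rep%d", f, r)
      splits[[name]] <- which(fold_id != f)             # training rows of this resample
    }
  }
  splits
}

set.seed(42)
splits <- repeated_kfold(n = 90, k = 5, repeats = 3)
length(splits)            # 15 resamples
unique(lengths(splits))   # 72 training rows per resample
```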
Next, we'll train all the base models in one call with the caretList() function.
models = caretList(Species~., data = train, trControl = tc, methodList = methods)
We'll compare the models' resampled performance with the resamples() function.
output = resamples(models)
summary(output)

Call:
summary.resamples(object = output)

Models: gbm, rpart, rf
Number of resamples: 15

Accuracy
           Min.   1st Qu.    Median      Mean   3rd Qu. Max. NA's
gbm   0.8333333 0.9444444 0.9444444 0.9296296 0.9444444    1    0
rpart 0.7777778 0.8333333 0.8888889 0.8962963 0.9444444    1    0
rf    0.8333333 0.9166667 0.9444444 0.9333333 0.9444444    1    0

Kappa
           Min.   1st Qu.    Median      Mean   3rd Qu. Max. NA's
gbm   0.6666667 0.8888889 0.8888889 0.8592593 0.8888889    1    0
rpart 0.5555556 0.6666667 0.7777778 0.7925926 0.8888889    1    0
rf    0.6666667 0.8333333 0.8888889 0.8666667 0.8888889    1    0
dotplot(output)
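A quick sanity check on the Accuracy and Kappa tables: if the true classes in a held-out fold are exactly balanced, the chance agreement in Cohen's kappa is 0.5 whatever the model predicts, so kappa reduces to 2 × Accuracy − 1. The numbers above are consistent with that: gbm's mean accuracy of 0.9296296 gives 2 × 0.9296296 − 1 = 0.8592593, its reported mean kappa. The sketch below computes kappa from a 2×2 confusion matrix; the fold counts are hypothetical.

```r
# Cohen's kappa from a confusion matrix (rows = predicted, columns = observed).
cohen_kappa <- function(cm) {
  n  <- sum(cm)
  po <- sum(diag(cm)) / n                      # observed agreement (accuracy)
  pe <- sum(rowSums(cm) * colSums(cm)) / n^2   # agreement expected by chance
  (po - pe) / (1 - pe)
}

# Hypothetical balanced fold: 9 observed per class, 15 of 18 predicted correctly.
cm <- matrix(c(8, 2,
               1, 7), nrow = 2, byrow = TRUE)
acc <- sum(diag(cm)) / sum(cm)
c(accuracy = acc, kappa = cohen_kappa(cm), two_acc_minus_one = 2 * acc - 1)
```

Both kappa and 2 × Accuracy − 1 come out at 2/3 here, matching the rpart minimum pair (0.7777778 would give 0.5555556 the same way).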
Next, we'll combine the base models by stacking. The caretStack() function fits a generalized linear model (glm) as the meta-learner, using the base models' predicted class probabilities as its predictors.
stack = caretStack(models, method = "glm", trControl = tc)
stack

A glm ensemble of 2 base models: gbm, rpart, rf

Ensemble results:
Generalized Linear Model

270 samples
  3 predictor
  2 classes: 'versicolor', 'virginica'

No pre-processing
Resampling: Cross-Validated (5 fold, repeated 3 times)
Summary of sample sizes: 216, 216, 216, 216, 216, 216, ...
Resampling results:

  Accuracy   Kappa
  0.9222222  0.8444444

summary(stack)

Call:
NULL

Deviance Residuals:
     Min        1Q    Median        3Q       Max
-2.75683  -0.25108  -0.04183   0.23252   2.38491

Coefficients:
            Estimate Std. Error z value Pr(>|z|)
(Intercept)  -3.2571     0.4409  -7.387 1.50e-13 ***
gbm          -7.1642     2.9530  -2.426   0.0153 *
rpart        -3.5524     1.4016  -2.535   0.0113 *
rf           17.5519     3.7890   4.632 3.62e-06 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

(Dispersion parameter for binomial family taken to be 1)

    Null deviance: 374.30  on 269  degrees of freedom
Residual deviance: 103.13  on 266  degrees of freedom
AIC: 111.13

Number of Fisher Scoring iterations: 6
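The glm above is fit on the base models' out-of-fold predictions from resampling rather than on in-sample predictions, which is consistent with the 270 samples it reports (90 training rows × 3 repeats). The base-R sketch below reproduces the idea from scratch under stated assumptions: two single-feature logistic regressions stand in for the gbm, rpart and rf models, and a single stratified 5-fold split replaces the repeated CV. It illustrates the technique and is not caretEnsemble's code.

```r
# From-scratch stacking sketch in base R. Assumption: two logistic regressions on
# single features stand in for the gbm/rpart/rf base models used in the post.
set.seed(1)
dat <- iris[iris$Species != "setosa", ]
dat$Species <- factor(dat$Species)
dat$y <- as.integer(dat$Species == "virginica")   # 1 = virginica, 0 = versicolor

base_formulas <- list(
  petal = y ~ Petal.Width,
  sepal = y ~ Sepal.Length
)

# Stratified 5-fold assignment: shuffle fold labels within each class.
k <- 5
fold_id <- integer(nrow(dat))
for (cls in levels(dat$Species)) {
  idx <- which(dat$Species == cls)
  fold_id[idx] <- sample(rep(seq_len(k), length.out = length(idx)))
}

# Level 1: out-of-fold predicted probabilities from each base model.
oof <- matrix(NA_real_, nrow = nrow(dat), ncol = length(base_formulas),
              dimnames = list(NULL, names(base_formulas)))
for (f in seq_len(k)) {
  held_out <- fold_id == f
  for (m in names(base_formulas)) {
    fit <- glm(base_formulas[[m]], family = binomial, data = dat[!held_out, ])
    oof[held_out, m] <- predict(fit, newdata = dat[held_out, ], type = "response")
  }
}

# Level 2: the meta-learner is a logistic regression on the out-of-fold probabilities.
meta_data <- data.frame(y = dat$y, oof)
meta <- glm(y ~ petal + sepal, family = binomial, data = meta_data)
print(coef(meta))
```

Using out-of-fold rather than in-sample probabilities keeps the meta-learner from simply rewarding the base model that overfits the training rows most. To score new data in this sketch, each base model would be refit on all rows and its probabilities passed to meta.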
Finally, we'll predict the test data and evaluate the predictions with a confusion matrix.
pred = predict(stack, test[, 1:4])
# confusionMatrix() takes the predictions first and the true labels second
cm = confusionMatrix(pred, test$Species)
print(cm)

Confusion Matrix and Statistics

            Reference
Prediction   versicolor virginica
  versicolor          5         0
  virginica           0         5

               Accuracy : 1
                 95% CI : (0.6915, 1)
    No Information Rate : 0.5
    P-Value [Acc > NIR] : 0.0009766

                  Kappa : 1

 Mcnemar's Test P-Value : NA

            Sensitivity : 1.0
            Specificity : 1.0
         Pos Pred Value : 1.0
         Neg Pred Value : 1.0
             Prevalence : 0.5
         Detection Rate : 0.5
   Detection Prevalence : 0.5
      Balanced Accuracy : 1.0

       'Positive' Class : versicolor
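With only 10 test rows, a perfect score is weak evidence on its own, as the wide 95% CI shows. The reported interval and P-value can be reproduced with base R's exact binomial test, as in this short sketch:

```r
# 10 of 10 test rows classified correctly.
correct <- 10
n <- 10

# Exact (Clopper-Pearson) 95% CI for accuracy.
ci <- binom.test(correct, n)$conf.int
round(ci, 4)      # 0.6915 1.0000

# One-sided test of accuracy against the no-information rate
# (0.5 here, since the test set has 5 rows of each class).
p_value <- binom.test(correct, n, p = 0.5, alternative = "greater")$p.value
p_value           # 0.0009765625, i.e. 0.5^10
```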
In this post, we've briefly learned how to implement stacking for a binary classification problem in R with the caretEnsemble package.
The full source code is listed below.
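library(caret)
library(caretEnsemble)

# Binary subset of iris: drop setosa and its now-unused factor level
ir = iris[iris$Species != "setosa", ]
ir$Species = factor(ir$Species)

# Stratified 90/10 train/test split
indexes = createDataPartition(ir$Species, p = .90, list = FALSE)
train = ir[indexes, ]
test = ir[-indexes, ]

# Base models and the shared resampling scheme
methods = c("gbm", "rpart", "rf")
tc = trainControl(method = "repeatedcv", number = 5,
                  repeats = 3, classProbs = TRUE)
models = caretList(Species ~ ., data = train,
                   trControl = tc, methodList = methods)

# Compare the base models across resamples
output = resamples(models)
summary(output)
dotplot(output)

# Stack the base models with a glm meta-learner
stack = caretStack(models, method = "glm", trControl = tc)
summary(stack)

# Predict the test set; confusionMatrix(predictions, true labels)
pred = predict(stack, test[, 1:4])
cm = confusionMatrix(pred, test$Species)
print(cm)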