In this post, we'll briefly learn how to classify data with a random forest model in R. We'll classify the iris dataset in two ways: with the 'randomForest' function and with the 'train' method from the 'caret' package. The tutorial covers:
- Preparing data
- Classification with randomForest
- Classification with caret
library(randomForest)
library(caret)
data(iris)
Preparing data
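First, we'll split the iris dataset into train and test parts. To split the data, we'll use the createDataPartition function from the caret package. It samples within each class, so with p = .90 about 90 percent of the rows of every species go to the train part and the rest go to the test part.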
indexes = createDataPartition(iris$Species, p = .90, list = FALSE)
train = iris[indexes, ]
test = iris[-indexes, ]
print(test)
    Sepal.Length Sepal.Width Petal.Length Petal.Width    Species
5            5.0         3.6          1.4         0.2     setosa
6            5.4         3.9          1.7         0.4     setosa
10           4.9         3.1          1.5         0.1     setosa
37           5.5         3.5          1.3         0.2     setosa
50           5.0         3.3          1.4         0.2     setosa
65           5.6         2.9          3.6         1.3 versicolor
68           5.8         2.7          4.1         1.0 versicolor
77           6.8         2.8          4.8         1.4 versicolor
83           5.8         2.7          3.9         1.2 versicolor
95           5.6         2.7          4.2         1.3 versicolor
101          6.3         3.3          6.0         2.5  virginica
110          7.2         3.6          6.1         2.5  virginica
112          6.4         2.7          5.3         1.9  virginica
140          6.9         3.1          5.4         2.1  virginica
147          6.3         2.5          5.0         1.9  virginica
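Because createDataPartition samples within each class, the split should keep the three species balanced. The following is a minimal sketch for checking that, assuming a fixed seed (any value would do; it only makes the split repeatable):

```r
library(caret)

data(iris)
set.seed(123)  # assumption: fixed seed so the split is repeatable

# Stratified 90/10 split on the class label
indexes <- createDataPartition(iris$Species, p = 0.90, list = FALSE)
train <- iris[indexes, ]
test <- iris[-indexes, ]

# Number of rows per class in each part
train_counts <- table(train$Species)
test_counts <- table(test$Species)
print(train_counts)
print(test_counts)
```

With 50 rows per species and p = .90, each class contributes 45 rows to the train part and 5 rows to the test part, which matches the 15-row test set printed above.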
Classification with randomForest
Next, we'll fit the model on the train data.
model = randomForest(Species~., data = train)
print(model)

Call:
 randomForest(formula = Species ~ ., data = train)
               Type of random forest: classification
                     Number of trees: 500
No. of variables tried at each split: 2

        OOB estimate of  error rate: 5.19%
Confusion matrix:
           setosa versicolor virginica class.error
setosa         45          0         0  0.00000000
versicolor      0         41         4  0.08888889
virginica       0          3        42  0.06666667
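The OOB (out-of-bag) error rate in this printout can be recomputed from the confusion matrix stored in the fitted object. The sketch below is one way to do it; to stay short it fits on the full iris data rather than the train part, and the seed is an assumption added for repeatability:

```r
library(randomForest)

data(iris)
set.seed(123)  # assumption: fixed seed for repeatability

# Assumption: fit on all of iris to keep the sketch self-contained
model <- randomForest(Species ~ ., data = iris)

# model$confusion holds the OOB confusion matrix plus a class.error column;
# keep only the three class columns
conf <- model$confusion[, levels(iris$Species)]

# Overall OOB error: share of rows that fall off the diagonal
oob_error <- 1 - sum(diag(conf)) / sum(conf)

# The same figure from the per-row OOB predictions stored in the model
oob_error_check <- mean(model$predicted != iris$Species)

print(oob_error)
print(oob_error_check)
```

Applied to the output above, 4 + 3 = 7 of the 135 train rows are misclassified, and 7 / 135 is about 5.19%.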
The model is ready, so we can predict the classes of the test data.
pred = predict(model, test)
Finally, we'll check the prediction accuracy. The confusionMatrix function takes the predicted classes first and the true (reference) classes second.
cm = confusionMatrix(pred, test$Species)
print(cm)
Confusion Matrix and Statistics

            Reference
Prediction   setosa versicolor virginica
  setosa          5          0         0
  versicolor      0          5         0
  virginica       0          0         5

Overall Statistics

               Accuracy : 1
                 95% CI : (0.782, 1)
    No Information Rate : 0.3333
    P-Value [Acc > NIR] : 6.969e-08

                  Kappa : 1
 Mcnemar's Test P-Value : NA

Statistics by Class:

                     Class: setosa Class: versicolor Class: virginica
Sensitivity                 1.0000            1.0000           1.0000
Specificity                 1.0000            1.0000           1.0000
Pos Pred Value              1.0000            1.0000           1.0000
Neg Pred Value              1.0000            1.0000           1.0000
Prevalence                  0.3333            0.3333           0.3333
Detection Rate              0.3333            0.3333           0.3333
Detection Prevalence        0.3333            0.3333           0.3333
Balanced Accuracy           1.0000            1.0000           1.0000
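The accuracy figure in this report is simply the share of test rows on the diagonal of the confusion table. As a rough sketch (the seed is an assumption, so the exact split may differ from the one shown above), it can be computed by hand and compared with the value caret stores:

```r
library(randomForest)
library(caret)

data(iris)
set.seed(123)  # assumption: fixed seed for repeatability
indexes <- createDataPartition(iris$Species, p = 0.90, list = FALSE)
train <- iris[indexes, ]
test <- iris[-indexes, ]

model <- randomForest(Species ~ ., data = train)
pred <- predict(model, test)

# Cross-tabulate predictions against the true labels
tab <- table(Prediction = pred, Reference = test$Species)
print(tab)

# Accuracy = correctly classified rows / all rows
accuracy <- sum(diag(tab)) / sum(tab)
print(accuracy)

# caret keeps the same figure in the "overall" vector of the result
cm <- confusionMatrix(pred, test$Species)
print(cm$overall["Accuracy"])
```

Keep in mind that the test part has only 15 rows, which is why the 95% confidence interval in the report is wide even when every row is classified correctly.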
Classification with caret
The 'caret' package provides a 'train' method that fits a model and tunes its parameters by resampling. First, we'll define the train control, here 10-fold cross-validation repeated 3 times, then fit the model on the train data. The method "rf" selects the random forest algorithm.
tc = trainControl(method = "repeatedcv", number = 10, repeats = 3)
cmodel = train(Species~., data=train, method="rf", trControl=tc)
print(cmodel)
Random Forest

135 samples
  4 predictor
  3 classes: 'setosa', 'versicolor', 'virginica'

No pre-processing
Resampling: Cross-Validated (10 fold, repeated 3 times)
Summary of sample sizes: 123, 120, 121, 121, 121, 120, ...
Resampling results across tuning parameters:

  mtry  Accuracy   Kappa
  2     0.9458791  0.9188639
  3     0.9482601  0.9224537
  4     0.9388950  0.9085940

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was mtry = 3.
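In the output above, caret chose the mtry candidates (2, 3, 4) itself. If you want to control which values are tried, one option is to pass an explicit grid through the tuneGrid argument. The sketch below assumes a lighter 5-fold cross-validation and a fixed seed, both chosen only to keep the example quick and repeatable:

```r
library(caret)  # method = "rf" also requires the randomForest package

data(iris)
set.seed(123)  # assumption: fixed seed for repeatability
indexes <- createDataPartition(iris$Species, p = 0.90, list = FALSE)
train <- iris[indexes, ]

# Assumption: plain 5-fold CV instead of repeated 10-fold, to run faster
tc <- trainControl(method = "cv", number = 5)

# Explicit grid of mtry values to evaluate (iris has 4 predictors)
grid <- expand.grid(mtry = 1:4)

cmodel <- train(Species ~ ., data = train, method = "rf",
                trControl = tc, tuneGrid = grid)

# Resampled accuracy for every candidate, and the selected value
print(cmodel$results[, c("mtry", "Accuracy", "Kappa")])
print(cmodel$bestTune)
```

The differences between candidates are usually small on iris, so the selected mtry can change from run to run when no seed is set.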
Now, we can predict the test data.
cpred = predict(cmodel, test)
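By default, predict returns class labels for a caret model. When you also want to see how confident the forest is, type = "prob" returns one column of class probabilities per species. A short sketch, again assuming a fixed seed and a lighter 5-fold cross-validation to keep it quick:

```r
library(caret)  # method = "rf" also requires the randomForest package

data(iris)
set.seed(123)  # assumption: fixed seed for repeatability
indexes <- createDataPartition(iris$Species, p = 0.90, list = FALSE)
train <- iris[indexes, ]
test <- iris[-indexes, ]

# Assumption: plain 5-fold CV, chosen only for speed
tc <- trainControl(method = "cv", number = 5)
cmodel <- train(Species ~ ., data = train, method = "rf", trControl = tc)

# Class labels (the default) and per-class probabilities
cpred <- predict(cmodel, test)
cprob <- predict(cmodel, test, type = "prob")

print(head(cprob))

# Each row of probabilities should sum to 1
print(range(rowSums(cprob)))
```

For a random forest, these probabilities are the fractions of trees voting for each class.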
Finally, we'll check the prediction accuracy, again passing the predicted classes first and the true classes second.
cm = confusionMatrix(cpred, test$Species)
print(cm)
Confusion Matrix and Statistics

            Reference
Prediction   setosa versicolor virginica
  setosa          5          0         0
  versicolor      0          5         0
  virginica       0          0         5

Overall Statistics

               Accuracy : 1
                 95% CI : (0.782, 1)
    No Information Rate : 0.3333
    P-Value [Acc > NIR] : 6.969e-08

                  Kappa : 1
 Mcnemar's Test P-Value : NA

Statistics by Class:

                     Class: setosa Class: versicolor Class: virginica
Sensitivity                 1.0000            1.0000           1.0000
Specificity                 1.0000            1.0000           1.0000
Pos Pred Value              1.0000            1.0000           1.0000
Neg Pred Value              1.0000            1.0000           1.0000
Prevalence                  0.3333            0.3333           0.3333
Detection Rate              0.3333            0.3333           0.3333
Detection Prevalence        0.3333            0.3333           0.3333
Balanced Accuracy           1.0000            1.0000           1.0000
In this post, we've briefly learned how to classify data with a random forest model using the 'randomForest' function and the 'train' method of the 'caret' package.
The full source code is listed below.
library(randomForest)
library(caret)
data(iris)
set.seed(123)
indexes = createDataPartition(iris$Species, p = .90, list = FALSE)
train = iris[indexes, ]
test = iris[-indexes, ]
# randomForest method
model = randomForest(Species~., data = train)
print(model)
pred = predict(model, test)
cm = confusionMatrix(pred, test$Species)
print(cm)
# caret method
tc = trainControl(method = "repeatedcv", number = 10, repeats = 3)
cmodel = train(Species~., data=train, method="rf", trControl=tc)
print(cmodel)
cpred = predict(cmodel, test)
cm = confusionMatrix(cpred, test$Species)
print(cm)