LDA Classification in R

Linear Discriminant Analysis (LDA) is a classification method that is commonly used for multiclass problems. The LDA model estimates the mean of each class and a covariance matrix shared by all classes, and uses them to separate the classes. To make a prediction, the model applies Bayes' theorem to estimate the probability that the input belongs to each class and picks the most probable class. In this post, we learn how to fit an LDA model and predict data with R.
In this tutorial, we use the iris dataset, and to fit the model we use the lda() function from the MASS package and caret's train() function.
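To make the Bayes step concrete, here is a minimal sketch that computes the LDA scores by hand in base R. It assumes the textbook formulation (class priors taken from class proportions, one pooled covariance matrix); the variable names are ours and are not part of any package.

```r
# Hand-rolled LDA on iris (a sketch of the textbook formulation)
data(iris)
X <- as.matrix(iris[, 1:4])
y <- iris$Species
classes <- levels(y)

# Class priors, class means (4 x 3), and the pooled covariance matrix
priors <- table(y) / length(y)
means  <- sapply(classes, function(k) colMeans(X[y == k, ]))
pooled <- Reduce(`+`, lapply(classes, function(k) {
  (sum(y == k) - 1) * cov(X[y == k, ])
})) / (nrow(X) - length(classes))

# Discriminant score for class k:
#   x' S^-1 mu_k - 0.5 * mu_k' S^-1 mu_k + log(prior_k)
Sinv_mu <- solve(pooled, means)
scores  <- X %*% Sinv_mu
scores  <- sweep(scores, 2,
                 0.5 * colSums(means * Sinv_mu) - log(as.numeric(priors)))

# Posterior probabilities: normalized exponentials of the scores
post <- exp(scores - apply(scores, 1, max))
post <- post / rowSums(post)

# Predicted class is the one with the highest score
manual_class <- classes[max.col(scores, ties.method = "first")]
table(manual = manual_class, actual = y)

# Optional check against MASS; the two are expected to agree
library(MASS)
all(manual_class == predict(lda(Species ~ ., data = iris))$class)
```

In practice we let lda() do this work, as shown in the rest of the post.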

First, we load the required libraries. The lda() function comes from the MASS package.

library(MASS)
library(caret)
data(iris)

Next, we split the iris data into train and test parts. Here, the test data contains about 5 percent of the iris dataset (6 rows).

set.seed(123)
indexes <- createDataPartition(iris$Species, p = 0.95, list = FALSE)
train <- iris[indexes, ]
test <- iris[-indexes, ]
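It can be worth checking the split before fitting. The sketch below repeats the split so that it runs on its own, then prints the size of each part and the class counts. createDataPartition() samples within each class, so the counts should be balanced.

```r
library(caret)
data(iris)

set.seed(123)
indexes <- createDataPartition(iris$Species, p = 0.95, list = FALSE)
train <- iris[indexes, ]
test  <- iris[-indexes, ]

# Rows and columns in each part
dim(train)
dim(test)

# Class counts in each part
table(train$Species)
table(test$Species)
```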

Fitting the model with the lda() function

fit_lda <- lda(Species~., data = train)
coef(fit_lda)
                    LD1         LD2
Sepal.Length  0.8678284 -0.06364228
Sepal.Width   1.5001177  2.18697821
Petal.Length -2.2233847 -0.87055421
Petal.Width  -2.7067315  2.77745136
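Besides the coefficients, the fitted object stores the estimated priors and class means. The sketch below prints them along with the share of between-class variance captured by each discriminant. For brevity it fits on the full iris data rather than on the train split, so the numbers will differ slightly from the ones above.

```r
library(MASS)
data(iris)

# Fitted on the full data here (an assumption made to keep the sketch short)
fit <- lda(Species ~ ., data = iris)

# Prior probability and mean vector of each class
fit$prior
fit$means

# Proportion of between-class variance explained by LD1 and LD2
prop_trace <- fit$svd^2 / sum(fit$svd^2)
round(prop_trace, 4)
```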

Predicting test data.

pred_lda <- predict(fit_lda, test[,-5])
data.frame(original = test$Species, pred = pred_lda$class)
    original       pred
1     setosa     setosa
2     setosa     setosa
3 versicolor versicolor
4 versicolor versicolor
5  virginica  virginica
6  virginica  virginica
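The object returned by predict() also holds the posterior probability of each class, which is what the Bayes step produces. As a rough sketch, we can print those probabilities and the share of correct predictions. The code repeats the split and the fit so that it runs on its own.

```r
library(MASS)
library(caret)
data(iris)

set.seed(123)
indexes <- createDataPartition(iris$Species, p = 0.95, list = FALSE)
train <- iris[indexes, ]
test  <- iris[-indexes, ]

fit_lda  <- lda(Species ~ ., data = train)
pred_lda <- predict(fit_lda, test[, -5])

# Posterior probability of each class for every test row
round(pred_lda$posterior, 3)

# Share of test rows classified correctly
accuracy <- mean(pred_lda$class == test$Species)
accuracy
```

With only a handful of test rows, this accuracy is a very coarse estimate.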


Fitting the model with caret

For caret training, we apply 5-fold cross-validation.

trCtrl <- trainControl(method = "cv", number = 5)
fit_car <- train(Species ~ ., data = train, method = "lda",
                 trControl = trCtrl, metric = "Accuracy")
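After training, the cross-validation results can be read from the fitted object. The sketch below repeats the setup so that it runs on its own, then prints the averaged metrics and the per-fold values.

```r
library(caret)
data(iris)

set.seed(123)
indexes <- createDataPartition(iris$Species, p = 0.95, list = FALSE)
train <- iris[indexes, ]

trCtrl  <- trainControl(method = "cv", number = 5)
fit_car <- train(Species ~ ., data = train, method = "lda",
                 trControl = trCtrl, metric = "Accuracy")

# Accuracy and Kappa averaged over the folds
fit_car$results

# Accuracy and Kappa for each fold
fit_car$resample
```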


Predicting test data.

pred_car <- predict(fit_car, test[,-5])
data.frame(original = test$Species, pred = pred_car)
    original       pred
1     setosa     setosa
2     setosa     setosa
3 versicolor versicolor
4 versicolor versicolor
5  virginica  virginica
6  virginica  virginica
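Instead of comparing the columns by eye, we can summarize the predictions with caret's confusionMatrix() function. This is a small sketch that repeats the earlier steps so that it runs on its own.

```r
library(caret)
data(iris)

set.seed(123)
indexes <- createDataPartition(iris$Species, p = 0.95, list = FALSE)
train <- iris[indexes, ]
test  <- iris[-indexes, ]

trCtrl  <- trainControl(method = "cv", number = 5)
fit_car <- train(Species ~ ., data = train, method = "lda",
                 trControl = trCtrl, metric = "Accuracy")
pred_car <- predict(fit_car, test[, -5])

# Cross-tabulate predicted and actual classes
cm <- confusionMatrix(pred_car, test$Species)
cm$table
cm$overall["Accuracy"]
```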


That is it! I hope you've found the post useful!

