QDA Classification with R

   Quadratic Discriminant Analysis (QDA) is a classification method used in statistics and machine learning. QDA is an extension of Linear Discriminant Analysis (LDA). Unlike LDA, which assumes that all classes share a common covariance matrix, QDA estimates a separate covariance matrix for each class, so the decision boundaries between classes are quadratic rather than linear.
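To make "its own covariance matrix" concrete, the short base R sketch below computes the sample covariance matrix of the four iris measurements separately for each species. These per-class matrices are the kind of estimates QDA works with, whereas LDA would pool them into a single matrix. The variable name class_cov is just chosen for this illustration.

```r
data(iris)

# Split the four numeric measurements by species and compute
# a separate sample covariance matrix for each class
class_cov = lapply(split(iris[, 1:4], iris$Species), cov)

# Each element is a 4 x 4 matrix; compare the Petal.Length variance across classes
sapply(class_cov, function(S) S["Petal.Length", "Petal.Length"])
```

The petal-length variance of virginica is roughly ten times that of setosa, which is the kind of situation where a single shared covariance matrix fits poorly.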
   In this tutorial, we'll learn how to classify data with the QDA method in R. The tutorial covers:
  1. Preparing data
  2. Prediction with the qda() function
  3. Prediction with caret's train() and the qda method
We'll start by loading the required libraries and dataset.

library(MASS)
library(caret)
data(iris) 


Preparing data

Next, we'll split the iris dataset into train (90%) and test (10%) parts. Setting a seed makes the random split reproducible.

set.seed(123)
indexes = createDataPartition(iris$Species, p = .9, list = FALSE)
train = iris[indexes, ]
test = iris[-indexes, ]
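createDataPartition() samples within each level of the outcome when it is a factor, so each species keeps the same share in both parts. A quick check with table() is sketched below; with 50 flowers per species and p = .9 we expect 45 training and 5 test rows per class, which matches the 15-row test set used later.

```r
library(caret)
data(iris)
set.seed(123)

# Stratified split: sampling is done within each Species level
indexes = createDataPartition(iris$Species, p = .9, list = FALSE)
train = iris[indexes, ]
test = iris[-indexes, ]

# Class counts in each part
table(train$Species)
table(test$Species)
```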


Prediction with the qda() function

We'll fit the model on the training data with the qda() function from the MASS package.

fit_qda = qda(Species~., data=train)
# A qda object has no linear discriminant coefficients, so coef() returns NULL here.
# Printing the model shows the prior probabilities and group means instead.
fit_qda
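The fitted qda object stores the class priors in fit_qda$prior, the group means in fit_qda$means, and per-class sphering matrices in fit_qda$scaling (one per class, rather than a single set of discriminant coefficients). The sketch below checks that, with qda()'s default settings, the priors are the class proportions of the training data and the means are ordinary column means per class; manual_prior and manual_means are illustrative names.

```r
library(MASS)
library(caret)
data(iris)
set.seed(123)
indexes = createDataPartition(iris$Species, p = .9, list = FALSE)
train = iris[indexes, ]

fit_qda = qda(Species ~ ., data = train)

# Priors default to the class proportions in the training data
fit_qda$prior
# Group means: one row per class, one column per predictor
fit_qda$means

# The same quantities computed directly from the training data
manual_prior = as.vector(table(train$Species)) / nrow(train)
manual_means = t(sapply(split(train[, 1:4], train$Species), colMeans))
```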

Next, we'll predict the test data and check the accuracy with a confusion matrix.

pred_qda = predict(fit_qda,test[,-5])
data.frame(original=test$Species, pred=pred_qda$class)
     original       pred
1      setosa     setosa
2      setosa     setosa
3      setosa     setosa
4      setosa     setosa
5      setosa     setosa
6  versicolor versicolor
7  versicolor versicolor
8  versicolor versicolor
9  versicolor versicolor
10 versicolor versicolor
11  virginica  virginica
12  virginica  virginica
13  virginica  virginica
14  virginica  virginica
15  virginica  virginica
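QDA assigns each point x to the class with the largest quadratic discriminant score, delta_k(x) = log(pi_k) - 0.5 * log|Sigma_k| - 0.5 * (x - mu_k)' Sigma_k^(-1) (x - mu_k), where pi_k, mu_k and Sigma_k are the prior, mean vector and covariance matrix of class k. The sketch below computes these scores by hand, assuming qda()'s defaults (class-proportion priors and the usual n_k - 1 sample covariance). qda_scores() is a hypothetical helper written only for this illustration and hard-codes the iris column layout; if the assumptions hold, its predictions should agree with predict(fit_qda).

```r
library(MASS)
library(caret)
data(iris)
set.seed(123)
indexes = createDataPartition(iris$Species, p = .9, list = FALSE)
train = iris[indexes, ]
test = iris[-indexes, ]
fit_qda = qda(Species ~ ., data = train)

# Hypothetical helper: quadratic discriminant score of each row of newdata
# for each class, using the four iris measurements (columns 1-4)
qda_scores = function(train, newdata) {
  X = as.matrix(train[, 1:4])
  newX = as.matrix(newdata[, 1:4])
  classes = levels(train$Species)
  sapply(classes, function(k) {
    Xk = X[train$Species == k, , drop = FALSE]
    mu_k = colMeans(Xk)                      # class mean vector
    S_k = cov(Xk)                            # class covariance, n_k - 1 denominator
    prior_k = nrow(Xk) / nrow(X)             # class proportion as the prior
    logdet_k = as.numeric(determinant(S_k, logarithm = TRUE)$modulus)
    # mahalanobis() returns the squared distance (x - mu_k)' S_k^-1 (x - mu_k)
    log(prior_k) - 0.5 * logdet_k - 0.5 * mahalanobis(newX, mu_k, S_k)
  })
}

scores = qda_scores(train, test)
# Pick the class with the largest score for each test row
manual_class = factor(colnames(scores)[max.col(scores, ties.method = "first")],
                      levels = levels(train$Species))
table(manual = manual_class, qda = predict(fit_qda, test[, -5])$class)
```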
 
# confusionMatrix() expects the predicted classes first and the observed classes second
confusionMatrix(pred_qda$class, test$Species)
Confusion Matrix and Statistics

            Reference
Prediction   setosa versicolor virginica
  setosa          5          0         0
  versicolor      0          5         0
  virginica       0          0         5

Overall Statistics
                                    
               Accuracy : 1         
                 95% CI : (0.782, 1)
    No Information Rate : 0.3333    
    P-Value [Acc > NIR] : 6.969e-08 
                                    
                  Kappa : 1         
 Mcnemar's Test P-Value : NA        

Statistics by Class:

                     Class: setosa Class: versicolor Class: virginica
Sensitivity                 1.0000            1.0000           1.0000
Specificity                 1.0000            1.0000           1.0000
Pos Pred Value              1.0000            1.0000           1.0000
Neg Pred Value              1.0000            1.0000           1.0000
Prevalence                  0.3333            0.3333           0.3333
Detection Rate              0.3333            0.3333           0.3333
Detection Prevalence        0.3333            0.3333           0.3333
Balanced Accuracy           1.0000            1.0000           1.0000
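The Overall Statistics block can be reproduced from the counts alone. The sketch below assumes caret's usual definitions: accuracy as the share of the diagonal, an exact (Clopper-Pearson) binomial interval for it, the No Information Rate as the share of the largest reference class, and a one-sided binomial test of accuracy against that rate. The counts are the ones shown in the confusion matrix above.

```r
# Counts from the confusion matrix above: 15 test rows, all correct
cm = matrix(c(5, 0, 0,
              0, 5, 0,
              0, 0, 5), nrow = 3, byrow = TRUE,
            dimnames = list(Prediction = c("setosa", "versicolor", "virginica"),
                            Reference  = c("setosa", "versicolor", "virginica")))

correct = sum(diag(cm))
n = sum(cm)
accuracy = correct / n

# Exact (Clopper-Pearson) 95% interval for the accuracy
acc_ci = binom.test(correct, n)$conf.int

# No Information Rate: share of the largest class in the reference labels
nir = max(colSums(cm)) / n

# One-sided test of whether the accuracy exceeds the NIR
p_value = binom.test(correct, n, p = nir, alternative = "greater")$p.value

signif(c(accuracy = accuracy, ci_low = acc_ci[1], ci_high = acc_ci[2],
         nir = nir, p_value = p_value), 4)
```

With 15 correct out of 15, the lower bound is 0.025^(1/15), about 0.782, and the p-value is (1/3)^15, about 6.969e-08, matching the output above.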


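Prediction with caret train() and the qda method

Using caret's train() function, we'll fit the QDA model with 5-fold cross-validation. Because the qda method has no tuning parameters, cross-validation here only estimates the model's accuracy; the final model is then fit on the whole training set.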

trCtrl = trainControl(method = "cv", number = 5)
fit_car = train(Species~., data=train, method="qda", 
                trControl = trCtrl, metric = "Accuracy")
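trainControl(method = "cv", number = 5) splits the training set into five folds, fits the model on four of them and scores it on the held-out fold, rotating through all five. A rough manual equivalent using caret's createFolds() (which stratifies by class for a factor outcome) and MASS::qda() is sketched below. The folds will not be the same ones train() draws internally, so the per-fold accuracies are only illustrative; print(fit_car) reports caret's own resampled accuracy.

```r
library(MASS)
library(caret)
data(iris)
set.seed(123)
indexes = createDataPartition(iris$Species, p = .9, list = FALSE)
train = iris[indexes, ]

# createFolds() returns the held-out row indices for each of the k folds
folds = createFolds(train$Species, k = 5)

# Fit on four folds, score on the fifth, for every fold
fold_acc = sapply(folds, function(hold) {
  fit = qda(Species ~ ., data = train[-hold, ])
  pred = predict(fit, train[hold, -5])$class
  mean(pred == train$Species[hold])
})

fold_acc        # accuracy on each held-out fold
mean(fold_acc)  # cross-validated accuracy estimate
```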


We'll predict the test data and check the accuracy with a confusion matrix.

pred_car = predict(fit_car, test[,-5])
data.frame(original=test$Species, pred=pred_car)
     original       pred
1      setosa     setosa
2      setosa     setosa
3      setosa     setosa
4      setosa     setosa
5      setosa     setosa
6  versicolor versicolor
7  versicolor versicolor
8  versicolor versicolor
9  versicolor versicolor
10 versicolor versicolor
11  virginica  virginica
12  virginica  virginica
13  virginica  virginica
14  virginica  virginica
15  virginica  virginica
 
# Predicted classes first, observed classes second
confusionMatrix(pred_car, test$Species)
Confusion Matrix and Statistics

            Reference
Prediction   setosa versicolor virginica
  setosa          5          0         0
  versicolor      0          5         0
  virginica       0          0         5

Overall Statistics
                                    
               Accuracy : 1         
                 95% CI : (0.782, 1)
    No Information Rate : 0.3333    
    P-Value [Acc > NIR] : 6.969e-08 
                                    
                  Kappa : 1         
 Mcnemar's Test P-Value : NA        

Statistics by Class:

                     Class: setosa Class: versicolor Class: virginica
Sensitivity                 1.0000            1.0000           1.0000
Specificity                 1.0000            1.0000           1.0000
Pos Pred Value              1.0000            1.0000           1.0000
Neg Pred Value              1.0000            1.0000           1.0000
Prevalence                  0.3333            0.3333           0.3333
Detection Rate              0.3333            0.3333           0.3333
Detection Prevalence        0.3333            0.3333           0.3333
Balanced Accuracy           1.0000            1.0000           1.0000
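Both approaches give the same confusion matrix here, which is expected: caret refits its final model on all of train, and for method = "qda" that final model is a MASS qda object. The sketch below compares the two sets of predictions directly; it reuses the tutorial's code and assumes the same seed.

```r
library(MASS)
library(caret)
data(iris)
set.seed(123)
indexes = createDataPartition(iris$Species, p = .9, list = FALSE)
train = iris[indexes, ]
test = iris[-indexes, ]

fit_qda = qda(Species ~ ., data = train)
trCtrl = trainControl(method = "cv", number = 5)
fit_car = train(Species ~ ., data = train, method = "qda",
                trControl = trCtrl, metric = "Accuracy")

# caret's final model is refit on all of train with MASS::qda()
class(fit_car$finalModel)

# Class predictions from both routes, cross-tabulated
pred_qda = predict(fit_qda, test[, -5])$class
pred_car = predict(fit_car, test[, -5])
table(qda = pred_qda, caret = pred_car)
```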

   In this post, we've briefly learned how to classify data with the QDA method in R. The full source code is listed below.


Source code listing
library(MASS)
library(caret)
data(iris)

set.seed(123)
indexes = createDataPartition(iris$Species, p = .9, list = FALSE)
train = iris[indexes, ]
test = iris[-indexes, ]

fit_qda = qda(Species~., data=train)
fit_qda
pred_qda = predict(fit_qda, test[,-5])
data.frame(original=test$Species, pred=pred_qda$class)
confusionMatrix(pred_qda$class, test$Species)

# caret training
trCtrl = trainControl(method = "cv", number = 5)
fit_car = train(Species~., data=train, method="qda",
                trControl = trCtrl, metric = "Accuracy")
pred_car = predict(fit_car, test[,-5])
data.frame(original=test$Species, pred=pred_car)
confusionMatrix(pred_car, test$Species)
