Classification Example with Naive Bayes Model in R

    Based on Bayes' theorem, Naive Bayes is a supervised learning algorithm commonly used to solve classification problems in machine learning. The model estimates the prior probability of each class and the conditional probability of each feature given the class from the training data, treats the features as independent within a class, and assigns each element to the class with the highest posterior probability. In a Gaussian Naive Bayes model, the values of each numeric feature within a class are assumed to follow a Gaussian distribution.

    In this tutorial, you'll briefly learn how to implement a Naive Bayes model in R using the naiveBayes() function of the 'e1071' package. The tutorial covers:

  1. Preparing data
  2. Fitting the model and prediction
  3. Source code listing

    We'll start by loading the required packages. Here, the caret package helps us prepare the data and assess the predictions.

 
library(e1071) 
library(caret) 
  


Preparing data

  We use the Iris dataset as the classification data in this tutorial. First, we'll load the dataset and split it into training and test parts, keeping 90 percent of the rows for training.

  
data(iris)
 
set.seed(124)
 
indexes = createDataPartition(iris$Species, p = .9, list = FALSE)
train = iris[indexes, ]
test = iris[-indexes, ]
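
Before fitting, it can help to confirm that the split is balanced. The sketch below repeats the split so that it runs on its own and then counts the species in each part. With 50 rows per species and p = 0.9, one would expect 45 rows per species in the training part and 5 in the test part, although the specific rows selected can differ between R versions.

```r
library(caret)

data(iris)
set.seed(124)

# Stratified 90/10 split on the class label, as in the tutorial
indexes <- createDataPartition(iris$Species, p = 0.9, list = FALSE)
train <- iris[indexes, ]
test  <- iris[-indexes, ]

# Number of rows of each species in the two parts
print(table(train$Species))
print(table(test$Species))
```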
   


Fitting the model and prediction

Next, we'll define the Naive Bayes model and fit it on the training data.

 
 
nb = naiveBayes(Species~., data = train)
print(nb)

Naive Bayes Classifier for Discrete Predictors

Call:
naiveBayes.default(x = X, y = Y, laplace = laplace)

A-priori probabilities:
Y
    setosa versicolor  virginica 
 0.3333333  0.3333333  0.3333333 

Conditional probabilities:
            Sepal.Length
Y                [,1]      [,2]
  setosa     5.015556 0.3649132
  versicolor 5.913333 0.5224940
  virginica  6.640000 0.6035501

            Sepal.Width
Y                [,1]      [,2]
  setosa     3.440000 0.3927757
  versicolor 2.757778 0.3250796
  virginica  2.975556 0.3269387

            Petal.Length
Y                [,1]      [,2]
  setosa     1.468889 0.1648997
  versicolor 4.257778 0.4745758
  virginica  5.584444 0.5468736

            Petal.Width
Y                 [,1]      [,2]
  setosa     0.2422222 0.1033284
  versicolor 1.3333333 0.2011332
  virginica  2.0000000 0.2602446 
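
For numeric predictors, the two columns in each table are the per-class mean ([,1]) and standard deviation ([,2]) of the feature. The following sketch shows how such tables could be turned into a prediction by hand: it adds the log prior to the sum of Gaussian log densities for one observation and then normalizes. To stay self-contained it fits the model on the full iris data, which is an assumption made only for this illustration, so its numbers differ from the tables printed above. The variable names x, log_post and post are hypothetical.

```r
library(e1071)

data(iris)
# Fitted on all rows only to keep the sketch short (assumption, not the tutorial's split)
nb <- naiveBayes(Species ~ ., data = iris)

x <- iris[1, 1:4]                 # one observation to score by hand
prior <- prop.table(nb$apriori)   # class counts turned into prior probabilities

# Log prior plus the sum of Gaussian log densities, one value per class
log_post <- sapply(levels(iris$Species), function(cl) {
  log_lik <- sapply(names(x), function(f) {
    m <- nb$tables[[f]][cl, 1]    # class-conditional mean
    s <- nb$tables[[f]][cl, 2]    # class-conditional standard deviation
    dnorm(x[[f]], mean = m, sd = s, log = TRUE)
  })
  log(as.numeric(prior[cl])) + sum(log_lik)
})

# Normalize to posterior probabilities (subtracting the max avoids underflow)
post <- exp(log_post - max(log_post))
post <- post / sum(post)
print(round(post, 4))

# Posterior probabilities reported by the package for the same row
print(predict(nb, x, type = "raw"))
```

If the two printed results agree, that suggests the fitted tables are being read correctly.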
  

The model is ready, and now we can predict the test data.


pred = predict(nb, test, type="class") 
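
Besides class labels, predict() for a naiveBayes model also accepts type = "raw", which returns the posterior probability of each class. The sketch below repeats the setup so that it runs on its own, prints these probabilities for the first test rows, and computes the share of correct labels with base R. The name prob is introduced here for illustration.

```r
library(e1071)
library(caret)

data(iris)
set.seed(124)
indexes <- createDataPartition(iris$Species, p = 0.9, list = FALSE)
train <- iris[indexes, ]
test  <- iris[-indexes, ]

nb <- naiveBayes(Species ~ ., data = train)

# Class labels, as in the tutorial
pred <- predict(nb, test, type = "class")

# Posterior probability of each class for every test row
prob <- predict(nb, test, type = "raw")
print(round(head(prob), 4))

# Share of test rows whose predicted label matches the true one
print(mean(pred == test$Species))
```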
  

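We'll check the accuracy of the predictions with caret's confusionMatrix() function. Note that the call below passes the true labels first and the predictions second, whereas caret documents the order as confusionMatrix(data = predictions, reference = true labels). As a result, the rows labeled "Prediction" in the output are the actual species and the columns labeled "Reference" are the predicted ones. Overall accuracy and kappa are unaffected by the order, but per-class statistics such as sensitivity and positive predictive value trade places.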

 
cm = confusionMatrix(test$Species, pred)
print(cm)
 
Confusion Matrix and Statistics

            Reference
Prediction   setosa versicolor virginica
  setosa          5          0         0
  versicolor      0          5         0
  virginica       0          1         4

Overall Statistics
                                          
               Accuracy : 0.9333          
                 95% CI : (0.6805, 0.9983)
    No Information Rate : 0.4             
    P-Value [Acc > NIR] : 2.523e-05       
                                          
                  Kappa : 0.9             
 Mcnemar's Test P-Value : NA              

Statistics by Class:

                     Class: setosa Class: versicolor Class: virginica
Sensitivity                 1.0000            0.8333           1.0000
Specificity                 1.0000            1.0000           0.9091
Pos Pred Value              1.0000            1.0000           0.8000
Neg Pred Value              1.0000            0.9000           1.0000
Prevalence                  0.3333            0.4000           0.2667
Detection Rate              0.3333            0.3333           0.2667
Detection Prevalence        0.3333            0.3333           0.3333
Balanced Accuracy           1.0000            0.9167           0.9545 
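
caret documents the arguments of confusionMatrix() as data (the predicted classes) followed by reference (the true classes). The sketch below repeats the setup so that it runs on its own, calls the function in that order, and cross-checks the table against a base R cross-tabulation. The exact counts depend on which rows fall into the test part.

```r
library(e1071)
library(caret)

data(iris)
set.seed(124)
indexes <- createDataPartition(iris$Species, p = 0.9, list = FALSE)
train <- iris[indexes, ]
test  <- iris[-indexes, ]

nb <- naiveBayes(Species ~ ., data = train)
pred <- predict(nb, test, type = "class")

# Predictions first, true labels second
cm <- confusionMatrix(data = pred, reference = test$Species)
print(cm$table)

# Base R cross-tabulation with the same orientation
print(table(Prediction = pred, Reference = test$Species))
```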
 

    In this tutorial, we've briefly learned how to use the naiveBayes() function of the 'e1071' package to classify data. The full source code is listed below.

Source code listing

 
library(e1071)
library(caret)
 
data(iris)
 
set.seed(124)
 
indexes = createDataPartition(iris$Species, p = .9, list = FALSE)
train = iris[indexes, ]
test = iris[-indexes, ]
 
nb = naiveBayes(Species~., data = train)
print(nb)
 
pred = predict(nb, test, type="class")
 
cm = confusionMatrix(test$Species, pred)
print(cm) 
  
