Gaussian Naive Bayes (GaussianNB) is a variant of the Naive Bayes algorithm used for classification tasks. Unlike MultinomialNB, which is designed for discrete features (like word counts), GaussianNB is used when the input features are continuous numerical values that are assumed to follow a Gaussian (or normal) distribution.
Gaussian Naive Bayes: The Core Concept
The algorithm still operates on the principles of Bayes' Theorem and the "naive" assumption of feature independence. However, the way it calculates the probability of a feature's value, given a class, is different.
Instead of counting frequencies, GaussianNB estimates the mean and standard deviation of each feature within each class from the training data. When it sees a new data point, it evaluates the Gaussian probability density function (the classic "bell curve" formula) at each feature value to get the likelihood of that value under each class. Strictly speaking this is a density rather than a probability, so it can be larger than 1.
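As a minimal sketch of the density calculation described above, the function below evaluates the normal density for one feature value. The name gaussian_pdf and the per-class statistics are hypothetical values chosen for illustration, not numbers taken from a trained model.

```python
import math

def gaussian_pdf(x, mean, std):
    """Density of a normal distribution with the given mean and std, evaluated at x."""
    coeff = 1.0 / (std * math.sqrt(2.0 * math.pi))
    exponent = -((x - mean) ** 2) / (2.0 * std ** 2)
    return coeff * math.exp(exponent)

# Hypothetical statistics for one feature (say, petal length in cm) in two classes.
density_a = gaussian_pdf(1.5, mean=1.46, std=0.17)  # value lies close to class A's mean
density_b = gaussian_pdf(1.5, mean=4.26, std=0.47)  # value lies far from class B's mean

# The class whose bell curve is centered near the value gets the higher density.
print(density_a > density_b)
```

With these made-up numbers the value 1.5 sits near the center of class A's curve and far out in the tail of class B's, so class A receives much more support from this feature.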
How it Works:
1. Calculate Class Priors: It determines the overall probability of each class (e.g., P(Class A)).
2. Calculate Feature Statistics: For each feature and for each class, it calculates the mean (μ) and standard deviation (σ).
3. Make a Prediction: For a new data point, it plugs the feature values into the Gaussian density formula for each class. It then multiplies these likelihoods together (along with the class prior) to get a score proportional to the posterior probability of each class. The class with the highest score is the prediction. In practice, implementations usually add log-likelihoods instead of multiplying raw values, which avoids numerical underflow.
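The three steps above can be sketched from scratch with NumPy. This is an illustrative simplification, not scikit-learn's implementation: the function names, the eps constant, and the toy dataset are all assumptions made for this example.

```python
import numpy as np

def fit_gaussian_nb(X, y):
    """Steps 1 and 2: class priors plus per-class feature means and variances."""
    classes = np.unique(y)
    priors = np.array([np.mean(y == c) for c in classes])
    means = np.array([X[y == c].mean(axis=0) for c in classes])
    variances = np.array([X[y == c].var(axis=0) for c in classes])
    return classes, priors, means, variances

def predict_gaussian_nb(X, classes, priors, means, variances, eps=1e-9):
    """Step 3: choose the class with the highest log prior + summed log densities."""
    variances = variances + eps  # assumed small constant to guard against zero variance
    # Broadcast to shape (n_samples, n_classes, n_features).
    diff = X[:, None, :] - means[None, :, :]
    log_density = -0.5 * (np.log(2.0 * np.pi * variances)[None, :, :]
                          + diff ** 2 / variances[None, :, :])
    # Summing logs is equivalent to multiplying the per-feature densities.
    scores = np.log(priors)[None, :] + log_density.sum(axis=2)
    return classes[np.argmax(scores, axis=1)]

# Tiny made-up dataset: two features, two well-separated classes.
X_toy = np.array([[1.0, 2.0], [1.2, 1.8], [0.8, 2.2],
                  [5.0, 6.0], [5.2, 5.8], [4.8, 6.2]])
y_toy = np.array([0, 0, 0, 1, 1, 1])

params = fit_gaussian_nb(X_toy, y_toy)
toy_predictions = predict_gaussian_nb(np.array([[1.1, 2.1], [5.1, 5.9]]), *params)
print(toy_predictions)
```

The sketch works in log space, so the product in step 3 becomes a sum; the ranking of classes is unchanged because the logarithm is monotonic.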
Detailed Code Example in Python
This example uses the classic Iris dataset, whose four features (sepal and petal lengths and widths) are continuous measurements, making it a natural fit for GaussianNB.
# --- 1. Import Necessary Libraries ---
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.datasets import load_iris
# --- 2. Load and Prepare the Dataset ---
# The Iris dataset is a classic dataset in machine learning.
# It contains 3 classes of iris plants with 4 continuous features each.
iris = load_iris()
X = iris.data
y = iris.target
# Split the data for training and testing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# --- 3. Create and Train the Gaussian Naive Bayes Model ---
# Initialize the classifier. Unlike MultinomialNB, this model is designed
# for continuous features.
model = GaussianNB()
# Train the model. During this step, the model calculates the mean and
# standard deviation of each of the 4 features for each of the 3 iris classes.
model.fit(X_train, y_train)
print("--- Model Training Complete ---")
# You can inspect the learned parameters (mean and variance)
print(f"\nLearned class priors: {model.class_prior_}")
# print(f"Learned means for each feature per class:\n{model.theta_}")
# --- 4. Make Predictions and Evaluate the Model ---
# Make predictions on the test data.
y_pred = model.predict(X_test)
# Evaluate the model's performance
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred, target_names=iris.target_names)
print(f"\nModel Accuracy: {accuracy * 100:.2f}%")
print("\n--- Classification Report ---")
print(class_report)
# --- 5. Visualize the Confusion Matrix ---
# A confusion matrix is a great way to see how well the model performed
# for each class.
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=iris.target_names, yticklabels=iris.target_names)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix for GaussianNB on Iris Dataset')
plt.show()
# --- 6. Predict on a New, Unseen Flower ---
# Let's create a new flower with some measurements:
# [sepal length, sepal width, petal length, petal width]
new_flower = [[5.1, 3.5, 1.4, 0.2]] # These are typical measurements for a Setosa
# Make a prediction
prediction_index = model.predict(new_flower)
predicted_class_name = iris.target_names[prediction_index[0]]
# Get the probabilities for each class
probabilities = model.predict_proba(new_flower)
print("\n--- Prediction for a New Flower ---")
print(f"Measurements: {new_flower[0]}")
print(f"==> Predicted Species: '{predicted_class_name.upper()}'")
print("\nPrediction Probabilities:")
for i, class_name in enumerate(iris.target_names):
    print(f" - {class_name}: {probabilities[0][i]:.4%}")