# main.py
# A demonstration of the Gaussian Naive Bayes classification algorithm.
#
# Before running, you may need to install scikit-learn and pandas:
# pip install scikit-learn pandas
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score, classification_report
import pandas as pd
print("--- Starting Gaussian Naive Bayes Classification Demonstration ---")

# ===== Section 1: dataset loading and train/test split =====
# The Iris measurements are continuous values, which suits the per-feature
# Gaussian likelihood that GaussianNB estimates for each class.
print("\n--- 1. Loading the Iris Dataset ---")
iris = load_iris()
X, y = iris.data, iris.target  # feature matrix and class labels

# Show what the feature columns and the label values stand for.
print(f"Features: {iris.feature_names}")
print(f"Target Classes: {iris.target_names}")

# Hold out 30% of the rows for evaluation; the fixed seed makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
print(f"\nData split into {len(X_train)} training samples and {len(X_test)} testing samples.")

# ===== Section 2: model fitting =====
print("\n--- 2. Training the Gaussian Naive Bayes Model ---")
# scikit-learn's fit() returns the estimator itself, so construction and
# training can be chained into a single expression.
gnb = GaussianNB().fit(X_train, y_train)
print("Model training complete.")

# ===== Section 3: prediction on held-out data and scoring =====
print("\n--- 3. Evaluating the Model ---")
y_pred = gnb.predict(X_test)

# Overall accuracy of the predictions on the held-out test set.
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
print(f"\nModel Accuracy: {accuracy:.4f}")

# Per-class precision, recall and F1-score, labelled with the species names.
print("\nClassification Report:")
print(classification_report(y_true=y_test, y_pred=y_pred, target_names=iris.target_names))

print("\n--- End of Demonstration ---")