# main.py
# A demonstration of advanced model evaluation for a classification algorithm.
#
# Before running, you may need to install scikit-learn, pandas, seaborn, and matplotlib:
# pip install scikit-learn pandas seaborn matplotlib
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
print("--- Starting Advanced Model Evaluation Demonstration ---")
# --- Section 1: Load and Prepare the Dataset ---
# We will use the Iris dataset.
print("\n--- 1. Loading the Iris Dataset ---")
iris = load_iris()
X = iris.data # The features
y = iris.target # The target classes
# For clarity, let's see the feature and target names
print(f"Features: {iris.feature_names}")
print(f"Target Classes: {iris.target_names}")
# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)
print(f"\nData split into {len(X_train)} training samples and {len(X_test)} testing samples.")
# --- Section 2: Train the Gaussian Naive Bayes Model ---
print("\n--- 2. Training the Gaussian Naive Bayes Model ---")
gnb = GaussianNB()
gnb.fit(X_train, y_train)
print("Model training complete.")
# --- Section 3: Advanced Model Evaluation ---
# Now we use the trained model to make predictions and evaluate its performance
# using more detailed metrics than just accuracy.
print("\n--- 3. Evaluating the Model ---")
y_pred = gnb.predict(X_test)
# a) Accuracy Score
accuracy = accuracy_score(y_test, y_pred)
print(f"\nModel Accuracy: {accuracy:.4f} (A good starting point, but can be misleading)")
# b) Classification Report
# This report is crucial, especially for imbalanced datasets, as it shows
# precision, recall, and F1-score for each individual class.
print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=iris.target_names))
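# How to read the report, per class:
#   precision = TP / (TP + FP) -> of the samples predicted as this class, how many were correct
#   recall    = TP / (TP + FN) -> of the actual samples of this class, how many were found
#   f1-score  = harmonic mean of precision and recall
#   support   = number of true samples of this class in the test set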
# c) Confusion Matrix
# A confusion matrix gives a detailed breakdown of correct and incorrect predictions
# for each class.
print("\nConfusion Matrix:")
cm = confusion_matrix(y_test, y_pred)
print(cm)
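# In scikit-learn's convention, rows are actual classes and columns are predicted
# classes, so correct predictions sit on the diagonal. As a small illustrative extra
# (not part of the report above), per-class recall can be read straight off the matrix:
print(f"Per-class recall from the diagonal: {cm.diagonal() / cm.sum(axis=1)}")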
# d) Visualizing the Confusion Matrix
# A heatmap makes the confusion matrix much easier to interpret.
print("\nVisualizing the Confusion Matrix as a Heatmap...")
try:
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=iris.target_names,
                yticklabels=iris.target_names)
    plt.title('Confusion Matrix')
    plt.ylabel('Actual Label')
    plt.xlabel('Predicted Label')
    # Save the plot to a file
    plot_filename = 'confusion_matrix.png'
    plt.savefig(plot_filename)
    print(f"Confusion matrix plot saved as '{plot_filename}'")
    plt.show()
    plt.close()
except Exception as e:
    print(f"An error occurred during visualization: {e}")
# --- Clean up the created image file ---
finally:
    if 'plot_filename' in locals() and os.path.exists(plot_filename):
        os.remove(plot_filename)
        print(f"\n--- Clean up: Removed '{plot_filename}' ---")
print("\n--- End of Demonstration ---")