# main.py
# A demonstration of advanced statistical visualization using Seaborn.
#
# Before running, you may need to install seaborn, pandas, scikit-learn, and matplotlib:
# pip install seaborn pandas scikit-learn matplotlib
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
from sklearn.datasets import load_iris
print("--- Starting Advanced Statistical Visualization Demonstration with Seaborn ---")

# --- Section 1: Load the Dataset ---
# Iris works well for a pair plot: it offers several numeric feature columns
# together with a categorical label (the species) that can drive the colouring.
print("\n--- 1. Loading the Iris Dataset ---")
iris = load_iris()

# Build the DataFrame in one expression: one column per feature name, then
# attach 'species', decoded from the integer target codes into a pandas
# Categorical whose categories are the dataset's target names.
df = pd.DataFrame(iris.data, columns=iris.feature_names).assign(
    species=pd.Categorical.from_codes(iris.target, iris.target_names)
)

print("Dataset loaded successfully. First 5 rows:")
print(df.head())
# --- Section 2: Create a Pair Plot ---
# A pair plot (or scatterplot matrix) creates a grid of axes such that each
# variable in the data is shared across the y-axes on a single row and
# the x-axes on a single column.
# The diagonal plots show the distribution of each variable (as a histogram or KDE).
# The off-diagonal plots show the relationship between pairs of variables (as a scatter plot).
print("\n--- 2. Generating a Pair Plot ---")

# Defined before the try block so the cleanup in `finally` can always refer
# to it, even if plotting fails early.
plot_filename = 'iris_pairplot.png'
# Handle to the PairGrid returned by sns.pairplot(); stays None if plotting
# fails before the grid is created, so `finally` knows whether to close it.
grid = None
try:
    # The 'hue' parameter colors the data points based on the 'species' column,
    # making it easy to see how the different species are clustered.
    grid = sns.pairplot(df, hue='species', palette='viridis')
    # Title the grid's own figure rather than relying on pyplot's notion of
    # the "current" figure. y=1.02 places it just above the top row of axes.
    grid.figure.suptitle('Pairwise Relationships in the Iris Dataset', y=1.02)
    # bbox_inches='tight' grows the saved canvas to include artists outside
    # the figure bounds; without it, the title placed at y > 1 is clipped
    # out of the PNG.
    grid.savefig(plot_filename, bbox_inches='tight')
    print(f"\nPair plot saved as '{plot_filename}'")
    plt.show()  # Display the plot
except Exception as e:
    # Top-level boundary of this demo script: report and continue to cleanup.
    print(f"An error occurred during visualization: {e}")
finally:
    # Release the figure whether or not plotting succeeded.
    if grid is not None:
        plt.close(grid.figure)
    # --- Clean up the created image file ---
    # The demo deliberately removes the PNG it wrote so that no artifacts
    # are left behind.
    print("\n--- Cleaning up created image file ---")
    if os.path.exists(plot_filename):
        os.remove(plot_filename)
        print(f"Removed '{plot_filename}'")

print("\n--- End of Demonstration ---")