# main.py
# A demonstration of data visualization using Matplotlib and Pandas.
#
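# Before running, you may need to install pandas and matplotlib:
# pip install pandas matplotlib
# (openpyxl is only needed if you load the location data from a real Excel
# file instead of recreating it inline as this script does.)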
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
print("--- Starting Data Visualization Demonstration ---")

# --- Setup: Recreate the cleaned and merged DataFrame from the previous script ---
# In a real-world scenario you would load this from a file. Here we recreate
# it inline so the example is self-contained.

# 1. Create the initial employee data; 'Age' and 'Salary' each have a NaN.
csv_data = {
    'EmployeeID': ['E01', 'E02', 'E03', 'E04', 'E05', 'E06', 'E07'],
    'Name': ['Alice', 'Bob', 'Charlie', 'David', 'Eve', 'Frank', 'Grace'],
    'Age': [34, 45, 28, np.nan, 40, 52, 30],
    'Department': ['HR', 'Engineering', 'Sales', 'Engineering', 'HR', 'Sales', 'Marketing'],
    'Salary': [70000, 95000, 65000, 80000, np.nan, 120000, 68000],
}
df_employees = pd.DataFrame(csv_data)

# 2. Clean the data.
# Assign the filled column back rather than calling
# `df_employees['Col'].fillna(..., inplace=True)`: that chained in-place form
# raises a FutureWarning in pandas 2.x and does not modify the DataFrame at all
# under Copy-on-Write (the pandas 3.0 default).
# Missing ages get the overall mean age.
df_employees['Age'] = df_employees['Age'].fillna(df_employees['Age'].mean())
# Missing salaries get the median salary of the employee's own department.
df_employees['Salary'] = df_employees['Salary'].fillna(
    df_employees.groupby('Department')['Salary'].transform('median')
)

# 3. Create the department -> location lookup and merge it in.
excel_data = {
    'Department': ['HR', 'Engineering', 'Sales', 'Marketing'],
    'Location': ['New York', 'London', 'Tokyo', 'Paris'],
}
df_locations = pd.DataFrame(excel_data)
# how='left' keeps every employee even if their department has no location.
df = pd.merge(df_employees, df_locations, on='Department', how='left')

print("\n[Final DataFrame ready for visualization]")
print(df.head())
# --- Section 1: Bar Chart ---
# Compare categorical data: how many employees each department has.
print("\n--- 1. Creating a Bar Chart (Employees per Department) ---")
try:
    # Count employees in each department.
    department_counts = df['Department'].value_counts()
    # Create a new figure; the pandas plot call below draws onto it.
    plt.figure(figsize=(10, 6))
    department_counts.plot(kind='bar', color='skyblue')
    plt.title('Number of Employees per Department')
    plt.xlabel('Department')
    plt.ylabel('Number of Employees')
    plt.xticks(rotation=45)  # Rotate x-axis labels for better readability
    plt.tight_layout()  # Adjust layout so rotated labels are not cut off
    # Save the plot to a file before show(), which may block until the
    # window is closed.
    bar_chart_filename = 'department_barchart.png'
    plt.savefig(bar_chart_filename)
    print(f"Bar chart saved as '{bar_chart_filename}'")
    plt.show()  # Display the plot
except Exception as e:
    print(f"An error occurred while creating the bar chart: {e}")
finally:
    # Close in `finally` so the figure is released even if plotting or saving
    # failed; otherwise it stays open for the rest of the script.
    plt.close()
# --- Section 2: Histogram ---
# Show the distribution of a numerical variable (salary) across 5 bins.
print("\n--- 2. Creating a Histogram (Salary Distribution) ---")
try:
    plt.figure(figsize=(10, 6))
    plt.hist(df['Salary'], bins=5, color='lightgreen', edgecolor='black')
    plt.title('Salary Distribution')
    plt.xlabel('Salary')
    plt.ylabel('Frequency')
    plt.tight_layout()
    # Save before show(), which may block until the window is closed.
    histogram_filename = 'salary_histogram.png'
    plt.savefig(histogram_filename)
    print(f"Histogram saved as '{histogram_filename}'")
    plt.show()  # Display the plot
except Exception as e:
    print(f"An error occurred while creating the histogram: {e}")
finally:
    # Release the figure on both the success and the error path.
    plt.close()
# --- Section 3: Pie Chart ---
# Show the proportion of employees at each location.
print("\n--- 3. Creating a Pie Chart (Employees by Location) ---")
try:
    # value_counts() drops NaN by default, so any employee whose department
    # had no match in the left merge is not represented in the pie.
    location_counts = df['Location'].value_counts()
    plt.figure(figsize=(8, 8))
    plt.pie(
        location_counts,
        labels=location_counts.index,
        autopct='%1.1f%%',  # Print each wedge's percentage with one decimal
        startangle=140,
        colors=['gold', 'yellowgreen', 'lightcoral', 'lightskyblue'],
    )
    plt.title('Proportion of Employees by Location')
    plt.axis('equal')  # Equal aspect ratio ensures the pie is drawn as a circle
    # Save before show(), which may block until the window is closed.
    pie_chart_filename = 'location_piechart.png'
    plt.savefig(pie_chart_filename)
    print(f"Pie chart saved as '{pie_chart_filename}'")
    plt.show()  # Display the plot
except Exception as e:
    print(f"An error occurred while creating the pie chart: {e}")
finally:
    # Release the figure on both the success and the error path.
    plt.close()
# --- Section 4: Scatter Plot ---
# Explore the relationship between two numerical variables (age and salary).
print("\n--- 4. Creating a Scatter Plot (Age vs. Salary) ---")
try:
    plt.figure(figsize=(10, 6))
    plt.scatter(df['Age'], df['Salary'], alpha=0.7, color='purple')
    plt.title('Age vs. Salary')
    plt.xlabel('Age')
    plt.ylabel('Salary')
    plt.grid(True)  # Add a grid for easier reading
    plt.tight_layout()
    # Save before show(), which may block until the window is closed.
    scatter_plot_filename = 'age_salary_scatterplot.png'
    plt.savefig(scatter_plot_filename)
    print(f"Scatter plot saved as '{scatter_plot_filename}'")
    plt.show()  # Display the plot
except Exception as e:
    print(f"An error occurred while creating the scatter plot: {e}")
finally:
    # Release the figure on both the success and the error path.
    plt.close()

# --- Clean up the created image files ---
# This is its own step rather than the `finally` of Section 4, because it
# removes the output of every section. Each section catches its own errors,
# so this is reached even if some charts failed. The filenames are listed
# literally because a section may have failed before assigning its variable.
print("\n--- Cleaning up created image files ---")
for filename in ('department_barchart.png', 'salary_histogram.png',
                 'location_piechart.png', 'age_salary_scatterplot.png'):
    try:
        os.remove(filename)
    except FileNotFoundError:
        # The section that would have created this file failed; nothing to remove.
        continue
    print(f"Removed '{filename}'")

print("\n--- End of Demonstration ---")