# main.py
# A demonstration of Linear and Polynomial Regression using scikit-learn.
#
# Before running, you may need to install the required packages:
# pip install numpy scikit-learn matplotlib
import os

import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures
print("--- Starting Regression Models Demonstration ---")

# --- Setup: Generate Sample Non-Linear Data ---
# Draw noisy samples from the quadratic y = -2x^2 + x, with additive Gaussian
# noise of mean -3 and standard deviation 3. A straight line will underfit it.
N_SAMPLES = 100
np.random.seed(42)  # fixed seed -> identical data on every run

# x values: 2 - 3 * N(0, 1), i.e. normally distributed with mean 2, std 3.
# The x draw must happen before the noise draw to keep the seeded sequence.
x_values = 2 - 3 * np.random.normal(0, 1, N_SAMPLES)
noise = np.random.normal(-3, 3, N_SAMPLES)
y = x_values - 2 * (x_values ** 2) + noise

# scikit-learn estimators expect a 2-D feature matrix of shape
# (n_samples, n_features), so turn the 1-D x vector into a single column.
X = x_values.reshape(-1, 1)
print("\n[Sample data generated]")
# --- Section 1: Simple Linear Regression ---
# Fit an ordinary least-squares straight line to the (curved) sample data.
# Its predictions on the training inputs are kept for plotting later.
print("\n--- 1. Training a Simple Linear Regression Model ---")
try:
    # `fit` returns the fitted estimator, so construction and fitting chain.
    linear_reg = LinearRegression().fit(X, y)
    y_linear_pred = linear_reg.predict(X)
except Exception as exc:
    print(f"An error occurred during linear regression: {exc}")
else:
    print("Linear Regression model trained successfully.")
# --- Section 2: Polynomial Regression ---
# Fit a 2nd-degree polynomial (quadratic) curve to the sample data.
print("\n--- 2. Training a Polynomial Regression Model ---")
try:
    # Pipeline: expand x into polynomial features up to degree 2, then run
    # linear regression on them -- a quadratic fit y = ax^2 + bx + c.
    # `fit` returns the fitted pipeline, so it chains onto construction.
    polynomial_reg = make_pipeline(
        PolynomialFeatures(degree=2),
        LinearRegression(),
    ).fit(X, y)
    # Evaluate on ascending x values so the plotted curve is drawn smoothly
    # from left to right.
    X_sorted = np.sort(X, axis=0)
    y_poly_pred = polynomial_reg.predict(X_sorted)
except Exception as exc:
    print(f"An error occurred during polynomial regression: {exc}")
else:
    print("Polynomial Regression model trained successfully.")
# --- Section 3: Visualization ---
# Plot the original data points and both regression fits so they can be
# compared. The figure is saved to a PNG, shown, then the PNG is removed.
print("\n--- 3. Visualizing the Results ---")
# Defined before `try` so the cleanup in `finally` can always reference it.
plot_filename = 'regression_comparison.png'
try:
    plt.figure(figsize=(12, 7))
    # Original data points
    plt.scatter(X, y, color='blue', s=20, label='Actual Data')
    # Linear regression line. Draw the points in ascending-x order so the line
    # is traced once from left to right instead of zig-zagging over itself.
    order = np.argsort(X[:, 0])
    plt.plot(X[order], y_linear_pred[order], color='red', linewidth=2,
             label='Linear Regression Fit')
    # Polynomial regression curve (already evaluated on sorted x values)
    plt.plot(X_sorted, y_poly_pred, color='green', linewidth=2,
             label='Polynomial Regression Fit (Degree 2)')
    plt.title('Linear vs. Polynomial Regression')
    plt.xlabel('X (Independent Variable)')
    plt.ylabel('y (Dependent Variable)')
    plt.legend()
    plt.grid(True)
    # Save the plot to a file
    plt.savefig(plot_filename)
    print(f"Comparison plot saved as '{plot_filename}'")
    plt.show()  # Display the plot
except Exception as e:
    # Also covers a missing y_linear_pred / y_poly_pred if an earlier section failed.
    print(f"An error occurred during visualization: {e}")
finally:
    # Close the current figure even when plotting failed, so it is not left open.
    plt.close()
    # --- Clean up the created image file ---
    # Remove the PNG saved above so the demo leaves no file behind.
    # Requires `import os` at the top of the file.
    if os.path.exists(plot_filename):
        os.remove(plot_filename)
        print(f"\n--- Clean up: Removed '{plot_filename}' ---")
print("\n--- End of Demonstration ---")