Source code for library.phases.phases_implementation.EDA.EDA

import math
import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

from library.phases.phases_implementation.dataset.dataset import Dataset

from library.utils.miscellaneous.save_or_store_plot import save_or_store_plot


[docs] class EDA: """ We will be using 'composition' desing pattern to create plots from the dataframe object that is an instance of the Dataset class This design pattern allows for two classes to be able to share data (e.g: dataset object) """ def __init__(self, dataset: Dataset) -> None: self.dataset = dataset
[docs] def plot_correlation_matrix(self, size: str = "small", numerical_df: pd.DataFrame = None, title: str = "", save_plots: bool = False, save_path: str = "", **kwargs) -> None: """ Plots the correlation matrix of the dataframe Parameters ---------- size : str The size of the plot. Taken on ["s", "m", "l", "auto"] Returns ------- None """ corr = numerical_df.corr() mask = np.triu(np.ones_like(corr, dtype=bool)) # avoid redundancy if size == "s": f, ax = plt.subplots(figsize=(5, 3)) elif size == "m": f, ax = plt.subplots(figsize=(10, 6)) elif size == "l": f, ax = plt.subplots(figsize=(20, 15)) elif size == "auto": f, ax = plt.subplots() cmap = sns.diverging_palette(230, 20, as_cmap=True) vmin, vmax = corr.min().min(), corr.max().max() sns.heatmap(corr, mask=mask, cmap=cmap, center=0, square=True, linewidths=.5, cbar_kws={"shrink": .8}, vmin=vmin, vmax=vmax, xticklabels=corr.columns, yticklabels=corr.index, **kwargs) plt.title(f"{title}") save_or_store_plot(f, save_plots, save_path + "/feature_selection/manual/multicollinearity", f"{title}.png")
[docs] def plot_categorical_distributions(self, features: list[str], n_cols: int = 2): """ Plots the distributions of specified categorical features as count plots. Parameters ---------- features : list of str List of categorical feature names to plot. n_cols : int, optional Number of columns for the subplot grid. Default is 2. Returns ------- None """ n_rows = len(features) fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 5*n_rows)) axes = axes.flatten() for i, feature in enumerate(features): ax_count = axes[i] sns.countplot(x=feature, data=self.dataset.df, ax=ax_count) ax_count.set_title(f"{feature} - Count Plot") for j in range(len(features), len(axes)): axes[j].set_visible(False) plt.tight_layout() plt.show()
[docs] def count_boxplot_descriptive(self, features: list[str]): """ Plots the distribution histogram, boxplot, and descriptive statistics summary for each specified feature. Parameters ---------- features : list of str List of feature names to analyze and plot. Returns ------- None """ n_rows = len(features) n_cols = 3 fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 5*n_rows)) for i, feature in enumerate(features): ax_hist = axes[i, 0] ax_box = axes[i, 1] ax_text = axes[i, 2] self.dataset.df[feature].hist(ax=ax_hist) ax_hist.set_title(f"{feature} - Distribution") self.dataset.df[feature].plot(kind="box", ax=ax_box) ax_box.set_title(f"{feature} - Boxplot") # Generate summary statistics text using describe() summary_text = self.dataset.df[feature].describe().to_string() ax_text.axis('off') # Place the text; using a monospaced font helps align the numbers ax_text.text(0.5, 0.5, summary_text, ha='center', va='center', fontfamily='monospace', fontsize=10) ax_text.set_title(f"{feature} - Summary Statistics") plt.tight_layout() plt.show()
[docs] def lineplot_bivariate(self, features: list[str], target: str, n_cols: int = 3): """ Plots the line plot of a feature against the target with maximized x-axis ticks and stretched figure size. """ n_rows = math.ceil(len(features) / n_cols) fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 8 * n_rows)) # Increased figure size if n_cols != 1: axes = axes.flatten() # Flatten the array to 1D for i, feature in enumerate(features): sns.lineplot(x=feature, y=target, data=self.dataset.df, ax=axes[i]) axes[i].set_title(f"{feature} vs {target}") axes[i].set_xlabel(feature) axes[i].set_ylabel(target) # Increase number of x-axis ticks x_values = self.dataset.df[feature] n_ticks = min(len(x_values.unique()), 20) # Cap at 20 ticks to avoid overcrowding axes[i].set_xticks(x_values.unique()) axes[i].tick_params(axis='x', rotation=45) # Rotate labels for better readability # Hide any unused subplots for j in range(len(features), len(axes)): axes[j].set_visible(False) else: sns.lineplot(x=features[0], y=target, data=self.dataset.df) plt.title(f"{features[0]} vs {target}") plt.xlabel(features[0]) plt.ylabel(target) plt.tight_layout() plt.show()
[docs] def scatterplot_bivariate(self, features: list[str], target: str, n_cols: int = 3): """ Plots line plots for each specified feature against the target variable. The plots have an expanded figure size and enhanced x-axis ticks for better readability. If multiple features are provided, plots are arranged in a grid with the specified number of columns. Parameters ---------- features : list of str List of feature names to plot on the x-axis. target : str The target variable name to plot on the y-axis. n_cols : int, optional (default=3) Number of columns in the subplot grid. Returns ------- None """ n_rows = math.ceil(len(features) / n_cols) fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 8 * n_rows)) # Increased figure size if n_cols != 1: axes = axes.flatten() # Flatten the array to 1D for i, feature in enumerate(features): sns.scatterplot(x=feature, y=target, data=self.dataset.df, ax=axes[i]) axes[i].set_title(f"{feature} vs {target}") axes[i].set_xlabel(feature) axes[i].set_ylabel(target) # Increase number of x-axis ticks x_values = self.dataset.df[feature] n_ticks = min(len(x_values.unique()), 20) # Cap at 20 ticks to avoid overcrowding axes[i].set_xticks(x_values.unique()) axes[i].tick_params(axis='x', rotation=45) # Rotate labels for better readability # Hide any unused subplots for j in range(len(features), len(axes)): axes[j].set_visible(False) else: sns.scatterplot(x=features[0], y=target, data=self.dataset.df) plt.title(f"{features[0]} vs {target}") plt.xlabel(features[0]) plt.ylabel(target) plt.tight_layout() plt.show()
[docs] def barplot_bivariate(self, features: list[str], target: str, n_cols: int = 3): """ Plots bar plots for each specified feature against the target variable. The function adjusts the figure size for better visibility and optimizes the x-axis ticks, including handling interval-type features by converting them to strings. Plots are arranged in a grid layout based on the specified number of columns. Parameters ---------- features : list of str List of feature names to plot on the x-axis. target : str The target variable name to plot on the y-axis. n_cols : int, optional (default=3) Number of columns in the subplot grid. Returns ------- None """ n_rows = math.ceil(len(features) / n_cols) fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 8 * n_rows)) # Increased figure size if n_cols != 1: axes = axes.flatten() # Flatten the array to 1D for i, feature in enumerate(features): # Convert interval data to strings if needed if pd.api.types.is_interval_dtype(self.dataset.df[feature]): x_values = self.dataset.df[feature].astype(str) else: x_values = self.dataset.df[feature] sns.barplot(x=x_values, y=target, data=self.dataset.df, ax=axes[i]) axes[i].set_title(f"{feature} vs {target}") axes[i].set_xlabel(feature) axes[i].set_ylabel(target) # Increase number of x-axis ticks n_ticks = min(len(x_values.unique()), 20) # Cap at 20 ticks to avoid overcrowding axes[i].set_xticks(range(len(x_values.unique()))) axes[i].set_xticklabels(x_values.unique()) axes[i].tick_params(axis='x', rotation=45) # Rotate labels for better readability # Hide any unused subplots for j in range(len(features), len(axes)): axes[j].set_visible(False) else: # Convert interval data to strings if needed if pd.api.types.is_interval_dtype(self.dataset.df[features[0]]): x_values = self.dataset.df[features[0]].astype(str) else: x_values = self.dataset.df[features[0]] sns.barplot(x=x_values, y=target, data=self.dataset.df) plt.title(f"{features[0]} vs {target}") plt.xlabel(features[0]) plt.ylabel(target) plt.tight_layout() plt.show()