Source code for library.phases.phases_implementation.data_preprocessing.class_imbalance
from library.phases.phases_implementation.dataset.dataset import Dataset
import matplotlib.pyplot as plt
import seaborn as sns
from imblearn.over_sampling import SMOTE, ADASYN
from sklearn.preprocessing import LabelEncoder
from library.utils.miscellaneous.save_or_store_plot import save_or_store_plot
[docs]
class ClassImbalance:
def __init__(self, dataset: Dataset) -> None:
self.dataset = dataset
[docs]
def class_imbalance(self, method: str = "SMOTE", save_plots: bool = False, save_path: str = None) -> str:
"""
Balances classes via SMOTE and optionally plots the distributions
before and after resampling.
Parameters
----------
method : str
The method to use for balancing classes.
save_plots : bool
Whether to save plots of class counts before/after SMOTE
save_path : str
The path to save the plots
Returns
-------
str
Summary of the balancing operation
"""
# --- Input validation ---
if not isinstance(save_plots, bool):
raise TypeError("Parameter 'save_plots' must be a boolean.")
# --- Attribute checks ---
for attr in ['X_train', 'y_train']:
if not hasattr(self.dataset, attr):
raise AttributeError(f"The dataset is missing the attribute '{attr}'.")
try:
counts_before = self.dataset.y_train.value_counts().sort_index()
except Exception as e:
raise RuntimeError(f"Could not compute class counts: {e}")
if counts_before.empty or len(counts_before) < 2:
raise ValueError("SMOTE requires at least two classes with non-zero samples.")
try:
self.imbalance_ratio = counts_before.min() / counts_before.max()
except ZeroDivisionError:
raise ValueError("Class count contains zero, cannot compute imbalance ratio.")
# --- Plot before resampling ---
fig, ax = plt.subplots(figsize=(6, 4), nrows=1, ncols=2)
ax = ax.flatten()
if save_plots:
try:
sns.barplot(
x=counts_before.index.astype(str),
y=counts_before.values,
ax=ax[0]
)
ax[0].set_title(f"Before {method} (imbalance ratio {self.imbalance_ratio:.2f}:1)")
ax[0].set_xlabel("Class")
ax[0].set_ylabel("Count")
ax[0].set_xticklabels(ax[0].get_xticklabels(), rotation=45, ha="right")
except Exception as e:
raise RuntimeError(f"An error occurred while plotting pre-SMOTE: {e}")
# --- Apply SMOTE ---
try:
if method == "SMOTE":
smote = SMOTE(random_state=42)
elif method == "ADASYN":
smote = ADASYN(random_state=42)
X_res, y_res = smote.fit_resample(self.dataset.X_train, self.dataset.y_train)
self.dataset.X_train = X_res
self.dataset.y_train = y_res
except Exception as e:
raise RuntimeError(f"An error occurred during SMOTE resampling: {e}")
# --- Plot after resampling ---
if save_plots:
try:
counts_after = self.dataset.y_train.value_counts().sort_index()
sns.barplot(
x=counts_after.index.astype(str),
y=counts_after.values,
ax=ax[1]
)
ax[1].set_title(f"After {method} (balanced 1:1)")
ax[1].set_xlabel("Class")
ax[1].set_ylabel("Count")
ax[1].set_xticklabels(ax[1].get_xticklabels(), rotation=45, ha="right")
plt.tight_layout(w_pad=3)
save_or_store_plot(fig, save_plots, save_path + "/class_imbalance", f"class_imbalance.png")
except Exception as e:
raise RuntimeError(f"An error occurred while plotting post-SMOTE: {e}")
return (
f"Successfully balanced classes via SMOTE. "
f"Started with a {self.imbalance_ratio:.2f}:1 ratio; now 1:1."
)