# Source code for library.phases.phases_implementation.modelling.shallow.model_definition.model_base
from abc import ABC, abstractmethod
import time
from library.phases.phases_implementation.dataset.dataset import Dataset
import matplotlib.pyplot as plt
import pandas as pd
from library.phases.phases_implementation.modelling.shallow.model_definition.model_states.model_state import PreTuningState, PostTuningState, InTuningState
# [docs]
class Model(ABC):
    """Abstract base class for all model objects.

    On construction it normalises the results header and creates one state
    object per tuning phase ("pre", "in" and "post"). ``fit`` and ``predict``
    delegate to the state selected by ``current_phase``. Subclasses must
    implement ``evaluate``.
    """

    # Model families accepted by the constructor.
    _VALID_MODEL_TYPES = ("classical", "neural_network", "stacking")
    # Columns appended to every cleaned results header.
    _EXTRA_RESULT_COLUMNS = ("predictions_val", "predictions_train", "predictions_test", "model_sklearn")

    def __init__(self, modelName: str, model_sklearn: object, model_type: str, results_header: list[str], dataset: Dataset):
        """
        Initialise the different tuning states that back the fitting and predicting methods.

        Args:
            modelName: Name of the model; forwarded to every tuning state.
            model_sklearn: The underlying model object; must not be ``None``.
            model_type: One of "classical", "neural_network" or "stacking".
            results_header: Raw results columns; per-split metric columns are
                collapsed before being stored in ``self.results_header``.
            dataset: Dataset forwarded to every tuning state.

        Raises:
            AssertionError: If ``model_type`` is not supported or
                ``model_sklearn`` is ``None``.
        """
        # NOTE: validation is done with asserts (skipped under ``python -O``);
        # kept as asserts so callers still observe AssertionError.
        assert model_type in self._VALID_MODEL_TYPES, (
            "Model type must be one of the following: classical, neural_network, stacking"
        )
        assert model_sklearn is not None, "Model sklearn must be provided"

        self.dataset = dataset
        self.modelName = modelName
        self.model_sklearn = model_sklearn
        self.model_type = model_type

        # Remove the duplicate per-split metrics from the header, then add the
        # prediction/model columns.
        self.results_header = self._clean_results_header(results_header) + list(self._EXTRA_RESULT_COLUMNS)

        # One state object per tuning phase, all built from the same arguments.
        self.tuning_states = {
            "pre": PreTuningState(model_sklearn, modelName, model_type, dataset, self.results_header),
            "in": InTuningState(model_sklearn, modelName, model_type, dataset, self.results_header),
            "post": PostTuningState(model_sklearn, modelName, model_type, dataset, self.results_header),
        }
        self.optimizer_type = None

    @staticmethod
    def _clean_results_header(results_header: list[str]) -> list[str]:
        """Collapse per-split metric columns into a single column per metric.

        Columns ending in ``_test`` are dropped, columns ending in ``_val``
        lose only that suffix, and every other column is kept unchanged.
        Stripping just the suffix (rather than keeping the text before the
        first underscore) preserves multi-word metric names such as
        ``f1_score`` and avoids collapsing distinct metrics into duplicates.
        """
        cleaned = []
        for column in results_header:
            if column.endswith("_test"):
                continue
            # ``removesuffix`` is a no-op for columns that do not end in "_val".
            cleaned.append(column.removesuffix("_val"))
        return cleaned

    @abstractmethod
    def evaluate(self, modelName: str):
        """Evaluate the model; must be implemented by every subclass."""
        pass

    def fit(self, modelName: str, current_phase: str, **kwargs):
        """Fit the model through the tuning state registered for ``current_phase``.

        Args:
            modelName: Name used in the progress message.
            current_phase: Key of ``self.tuning_states`` selecting the state.
            **kwargs: Forwarded unchanged to the state's ``fit``.

        Raises:
            AssertionError: If ``current_phase`` is not a known tuning state.
        """
        assert current_phase in self.tuning_states, "Current phase must be one of the tuning states"
        print(f"=> Fitting {modelName} model")
        self.tuning_states[current_phase].fit(**kwargs)

    def predict(self, modelName: str, current_phase: str):
        """Run prediction through the tuning state registered for ``current_phase``.

        Args:
            modelName: Name used in the progress message.
            current_phase: Key of ``self.tuning_states`` selecting the state.

        Raises:
            AssertionError: If ``current_phase`` is not a known tuning state.
        """
        assert current_phase in self.tuning_states, "Current phase must be one of the tuning states"
        print(f"=> Predicting {modelName} model")
        self.tuning_states[current_phase].predict()