library.pipeline.analysis.neuralNets.neuralNetsPlots module

class library.pipeline.analysis.neuralNets.neuralNetsPlots.NeuralNetsPlots(model_sklearn: object)

Bases: object

plot_per_epoch_progress(metrics: list[str], phase: str, n_cols: int = 2, save_plots: bool = False, save_path: str = None) → None

Plot the per-epoch progress of the feedforward neural network.

Parameters:

metrics: list[str]

Names of the metrics whose per-epoch values are plotted.

phase: str

Name of the phase whose per-epoch metrics are plotted.

n_cols: int

Number of columns in the grid of subplots. Defaults to 2.

save_plots: bool

Whether to save the plots to disk. Defaults to False.

save_path: str

Path where the plots are saved when save_plots is True. Defaults to None.

Returns:

None
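The parameter descriptions leave the subplot layout implicit. The following is a minimal sketch of one plausible way n_cols could arrange the requested metrics, assuming one subplot per metric filled row by row (left to right, then top to bottom). The helper `subplot_grid` is hypothetical and not part of this module; it uses only the standard library, so no plotting backend is needed to run it.

```python
import math


def subplot_grid(metrics: list[str], n_cols: int = 2) -> tuple[int, int, dict[str, tuple[int, int]]]:
    """Return (n_rows, n_cols, positions) for a row-major grid of metric subplots.

    Hypothetical helper: assumes one subplot per metric, which is a common
    layout but is not confirmed for NeuralNetsPlots.
    """
    if n_cols < 1:
        raise ValueError("n_cols must be at least 1")
    if not metrics:
        # Nothing to plot, so the grid is empty.
        return 0, 0, {}
    # Design choice of this sketch: never use more columns than there are metrics.
    cols = min(n_cols, len(metrics))
    rows = math.ceil(len(metrics) / cols)
    # divmod(i, cols) gives (row, column) for the i-th metric in row-major order.
    positions = {name: divmod(i, cols) for i, name in enumerate(metrics)}
    return rows, cols, positions


rows, cols, pos = subplot_grid(["loss", "accuracy", "f1"], n_cols=2)
print(rows, cols, pos)  # → 2 2 {'loss': (0, 0), 'accuracy': (0, 1), 'f1': (1, 0)}
```

In an actual plotting routine, the row and column counts would typically be passed to something like matplotlib's `plt.subplots(nrows, ncols)`, with each metric drawn on the axes at its computed position.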