Functions implementing the multimodal evaluation pipeline: assessing model performance and performing risk stratification
This document provides an overview of the functions defined in src.utils.eval_utils. Each function is listed with its signature and docstring.
plot_learning_curve
def plot_learning_curve(losses_path: str = None, output_path="learning_curve.png"):
    """
    Plot the learning curve (training and validation loss) from a CSV file.

    Args:
        losses_path (str): Path to CSV file containing the training and validation loss.
        output_path (str): Path to save the learning curve plot.

    Returns:
        None
    """
plot_roc
def plot_roc(
    y_test: np.array,
    prob: np.array,
    output_path: str = "roc.png",
    result_dict: dict = None,
    outcome: str = "In-hospital Death",
):
    """
    Plot the ROC curve with AUC and 95% CI for a binary classifier.

    Args:
        y_test (np.array): Ground truth binary labels.
        prob (np.array): Predicted probabilities for the positive class.
        output_path (str): Path to save the ROC curve plot.
        result_dict (dict): Dictionary in which to store the computed performance metrics.
        outcome (str): Name of the outcome being evaluated.

    Returns:
        None
    """
plot_pr
def plot_pr(
    y_test: np.array,
    prob: np.array,
    output_path: str = "pr_curve.png",
    result_dict: dict = None,
    outcome: str = "In-hospital Death",
):
    """
    Plot the Precision-Recall curve with AUC and 95% CI for a binary classifier.

    Args:
        y_test (np.array): Ground truth binary labels.
        prob (np.array): Predicted probabilities for the positive class.
        output_path (str): Path to save the PR curve plot.
        result_dict (dict): Dictionary in which to store the computed performance metrics.
        outcome (str): Name of the outcome being evaluated.

    Returns:
        None
    """
plot_calibration_curve
def plot_calibration_curve(
    y_test: np.array,
    prob: np.array,
    output_path: str = "calib_curve.png",
    outcome: str = "In-hospital Death",
    n_bins: int = 10,
):
    """
    Plot the calibration curve for a binary classifier.

    Args:
        y_test (np.array): Ground truth binary labels (0 or 1).
        prob (np.array): Predicted probabilities for the positive class.
        output_path (str): Path to save the calibration curve plot.
        outcome (str): Name of the outcome being evaluated.
        n_bins (int): Number of bins to use for calibration.

    Returns:
        None
    """
expect_f1
def expect_f1(y_prob: np.array, thres: float) -> float:
    """
    Calculate expected F1 score for a given threshold.

    Args:
        y_prob (np.array): Predicted probabilities.
        thres (float): Threshold for binary classification.

    Returns:
        float: Expected F1 score.
    """
optimal_threshold
def optimal_threshold(y_prob: np.array) -> float:
    """
    Calculate the optimal threshold for binary classification based on expected F1 score.

    Args:
        y_prob (np.array): Predicted probabilities.

    Returns:
        float: Optimal threshold.
    """
get_roc_performance
def get_roc_performance(y_test: np.array, prob: np.array, verbose: bool = False):
    """
    Compute ROC performance summary based on Youden's J statistic for a binary classifier.

    Args:
        y_test (np.array): Ground truth binary labels.
        prob (np.array): Predicted probabilities for the positive class.
        verbose (bool): If True, print detailed performance metrics.

    Returns:
        tuple: (bin_labels, res_dict_roc)
            bin_labels (np.array): Binary predictions using Youden's J threshold.
            res_dict_roc (dict): ROC statistics and confidence intervals.
    """
get_pr_performance
def get_pr_performance(
    y_test: np.array,
    prob: np.array,
    bin_labels: np.array,
    opt_f1: bool = True,
    verbose: bool = False,
):
    """
    Compute Precision-Recall performance metrics and confidence intervals.

    Args:
        y_test (np.array): Ground truth binary labels.
        prob (np.array): Predicted probabilities for the positive class.
        bin_labels (np.array): Binary predictions.
        opt_f1 (bool): If True, use optimal F1 threshold.
        verbose (bool): If True, print detailed performance metrics.

    Returns:
        dict: PR statistics and confidence intervals.
    """
get_all_roc_pr_summary
def get_all_roc_pr_summary(
    res_dicts: list,
    models: list,
    colors: list,
    output_roc_path: str = "roc_summary.png",
    output_pr_path: str = "pr_summary.png",
):
    """
    Plot summary ROC and PR curves for multiple models.

    Args:
        res_dicts (list): List of dictionaries with model results.
        models (list): List of model names.
        colors (list): List of colors for plotting.
        output_roc_path (str): Path to save ROC summary plot.
        output_pr_path (str): Path to save PR summary plot.

    Returns:
        None
    """
rank_prediction_quantiles
def rank_prediction_quantiles(
    y_test: np.array,
    prob: np.array,
    attrs: list,
    attr_disp: list,
    test_ids: list,
    n_bins: int = 10,
    outcome: str = "In-hospital Death",
    output_path: str = "risk_strat.png",
    by_attribute: bool = False,
    attr_features: pd.DataFrame = None,
    verbose: bool = False,
):
    """
    Rank predictions into quantiles and plot risk stratification, optionally stratified by attribute.

    Args:
        y_test (np.array): Ground truth binary labels.
        prob (np.array): Predicted probabilities for the positive class.
        attrs (list): List of attribute names for stratification.
        attr_disp (list): List of display names for attributes.
        test_ids (list): List of subject IDs for test set.
        n_bins (int): Number of quantiles.
        outcome (str): Name of the outcome being evaluated.
        output_path (str): Path to save the risk stratification plot.
        by_attribute (bool): If True, stratify by attribute.
        attr_features (pd.DataFrame): DataFrame with attribute features.
        verbose (bool): If True, print detailed information.

    Returns:
        dict: Appended patient risk quantiles.
    """