Functions for the fairness evaluation pipeline, with added bootstrapping

This document provides an overview of the functions defined in src.utils.fairness_utils. Each function is listed with its signature and docstring.


plot_bar_metric_frame

def plot_bar_metric_frame(
    metrics: dict,
    y_test: np.ndarray,
    y_hat: np.ndarray,
    attr_df: pl.DataFrame,
    attribute: str,
    save_path: str,
    figsize=None,
    nrows=2,
    ncols=2,
    seed=0,
):
    """
    Plot an error bar chart for the given metric frame using Fairlearn.

    Args:
        metrics (dict): Dictionary of metric functions to compute.
        y_test (np.ndarray): Ground truth labels.
        y_hat (np.ndarray): Predicted labels.
        attr_df (pl.DataFrame): Sensitive attribute DataFrame.
        attribute (str): Name of the sensitive attribute.
        save_path (str): Path to save the plot.
        figsize (tuple, optional): Figure size.
        nrows (int): Number of subplot rows.
        ncols (int): Number of subplot columns.
        seed (int): Random seed.

    Returns:
        None
    """

bias_corrected_ci

def bias_corrected_ci(bootstrap_samples, observed_value):
    """
    Calculate bias-corrected and accelerated (BCa) confidence intervals for bootstrap samples.

    Args:
        bootstrap_samples (array-like): Bootstrap sample values.
        observed_value (float): Observed value for bias correction.

    Returns:
        tuple: (lower_bound, upper_bound) confidence interval.
    """

get_bootstrapped_fairness_measures

def get_bootstrapped_fairness_measures(
    y_test: np.ndarray,
    y_hat: np.ndarray,
    attr_pf: pl.DataFrame,
    n_boot: int = 1000,
    seed: int = 0,
    skip_ci: bool = False,
    verbose: bool = False,
) -> tuple:
    """
    Compute bootstrapped fairness measures (Demographic Parity, Equalized Odds, Equal Opportunity)
    and their confidence intervals.

    Args:
        y_test (np.ndarray): Ground truth labels.
        y_hat (np.ndarray): Predicted labels.
        attr_pf (pl.DataFrame): Sensitive attribute DataFrame.
        n_boot (int): Number of bootstrap samples.
        seed (int): Random seed.
        skip_ci (bool): If True, skip CI calculation.
        verbose (bool): If True, print additional information.

    Returns:
        tuple: (dpr_full, eor_full, eop_full) where each contains (mean, lower_CI, upper_CI).
    """

plot_fairness_by_age

def plot_fairness_by_age(
    aq_dict: dict,
    age_labels: list,
    out_path: str,
    attributes: list,
    attribute_labels: list,
    figsize: tuple = (11, 8),
    measure: str = "DPR",
    measure_label: str = "Demographic Parity",
):
    """
    Plot fairness measures by age group for multiple sensitive attributes.

    Args:
        aq_dict (dict): Dictionary containing fairness metrics by age group.
        age_labels (list): List of age group labels.
        out_path (str): Path to save the plot.
        attributes (list): List of sensitive attribute names.
        attribute_labels (list): List of attribute display names.
        figsize (tuple): Figure size.
        measure (str): Key for the fairness measure to plot.
        measure_label (str): Display label for the fairness measure.

    Returns:
        None
    """

get_fairness_summary

def get_fairness_summary(
    res_all: dict,
    models: list,
    colors: list,
    attribute_labels: list,
    figsize: tuple = (13, 8),
    nrows: int = 2,
    ncols: int = 2,
    outcome: str = "Extended Stay",
    output_path: str = "fair_full_across_models.png",
):
    """
    Plot grouped barplots for fairness metrics (DPR, EOR, EOP) with 95% CI for each model.

    Args:
        res_all (dict): Dictionary containing fairness metrics for each model.
        models (list): List of model names.
        colors (list): List of colors for each model.
        attribute_labels (list): List of attribute display names.
        figsize (tuple): Figure size.
        nrows (int): Number of subplot rows.
        ncols (int): Number of subplot columns.
        outcome (str): Name of the outcome.
        output_path (str): Path to save the plot.

    Returns:
        None
    """