Functions for preprocessing and cleaning extracted data, as well as generating multimodal features.
This document provides an overview of the functions defined in src.utils.preprocessing. Each function is listed with its signature and docstring.
preproc_icd_module
def preproc_icd_module(
diagnoses: pl.DataFrame | pl.LazyFrame,
icd_map_path: str = "../config/icd9to10.txt",
map_code_colname: str = "diagnosis_code",
only_icd10: bool = True,
ltc_dict_path: str = "../outputs/icd10_codes.json",
verbose: bool = True,
use_lazy: bool = False,
) -> pl.DataFrame:
"""
Process a diagnoses dataset with ICD codes, mapping ICD-9 to ICD-10 and generating features for long-term conditions.
Implementation is taken from the MIMIC-IV preprocessing pipeline provided by Gupta et al. (https://github.com/healthylaife/MIMIC-IV-Data-Pipeline/tree/main).
Args:
diagnoses (pl.DataFrame | pl.LazyFrame): Diagnoses data.
icd_map_path (str): Path to ICD-9 to ICD-10 mapping file.
map_code_colname (str): Column name for ICD code in mapping.
only_icd10 (bool): If True, only keep ICD-10 codes.
ltc_dict_path (str): Path to JSON with LTC code groups.
verbose (bool): If True, print summary statistics.
use_lazy (bool): If True, return a LazyFrame.
Returns:
pl.DataFrame or pl.LazyFrame: Processed diagnoses data.
"""
get_ltc_features
def get_ltc_features(
admits_last: pl.DataFrame | pl.LazyFrame,
diagnoses: pl.DataFrame | pl.LazyFrame,
ltc_dict_path: str = "../outputs/icd10_codes.json",
mm_cutoff: int = 1,
cmm_cutoff: int = 3,
verbose: bool = True,
use_lazy: bool = False,
) -> pl.DataFrame:
"""
Generate features for long-term conditions and multimorbidity from ICD-10 diagnoses and custom LTC dictionary.
Args:
admits_last (pl.DataFrame | pl.LazyFrame): Admissions data.
diagnoses (pl.DataFrame | pl.LazyFrame): ICD-10 diagnoses data.
ltc_dict_path (str): Path to JSON with LTC code groups.
mm_cutoff (int): Threshold for multimorbidity.
cmm_cutoff (int): Threshold for complex multimorbidity.
verbose (bool): If True, print summary statistics.
use_lazy (bool): If True, return a LazyFrame.
Returns:
pl.DataFrame or pl.LazyFrame: Admissions data with long-term condition count features.
"""
transform_sensitive_attributes
def transform_sensitive_attributes(ed_pts: pl.DataFrame) -> pl.DataFrame:
"""
Map sensitive attributes (race, marital status) to predefined categories and types.
Args:
ed_pts (pl.DataFrame): Patient data.
Returns:
pl.DataFrame: Updated patient data.
"""
prepare_medication_features
def prepare_medication_features(
medications: pl.DataFrame | pl.LazyFrame,
admits_last: pl.DataFrame | pl.LazyFrame,
top_n: int = 50,
use_lazy: bool = False,
) -> pl.DataFrame:
"""
Generate count and temporal (days since prescription) features for drug-level medication history.
Args:
medications (pl.DataFrame | pl.LazyFrame): Medication data.
admits_last (pl.DataFrame | pl.LazyFrame): Final hospitalisations data.
top_n (int): Number of top medications to include.
use_lazy (bool): If True, return a LazyFrame.
Returns:
pl.DataFrame or pl.LazyFrame: Admissions data with medication count features.
"""
encode_categorical_features
def encode_categorical_features(ehr_data: pl.DataFrame) -> pl.DataFrame:
"""
Apply one-hot encoding to categorical features in EHR data.
Args:
ehr_data (pl.DataFrame): Static EHR dataset.
Returns:
pl.DataFrame: Transformed EHR data.
"""
extract_lookup_fields
def extract_lookup_fields(
ehr_data: pl.DataFrame,
lookup_list: list = None,
lookup_output_path: str = "../outputs/reference",
) -> pl.DataFrame:
"""
Extract date and summary fields not suitable for training into a separate DataFrame.
Args:
ehr_data (pl.DataFrame): Static EHR dataset.
lookup_list (list): List of columns to extract.
lookup_output_path (str): Directory to save lookup fields.
Returns:
pl.DataFrame: EHR data with lookup fields removed.
"""
remove_correlated_features
def remove_correlated_features(
ehr_data: pl.DataFrame,
feats_to_save: list = None,
threshold: float = 0.9,
method: str = "pearson",
verbose: bool = True,
) -> pl.DataFrame:
"""
Drop highly correlated features from EHR data, keeping specified features.
Args:
ehr_data (pl.DataFrame): Static EHR dataset.
feats_to_save (list): Features to keep.
threshold (float): Correlation threshold.
method (str): Correlation method. Defaults to Pearson's r.
verbose (bool): If True, print summary.
Returns:
pl.DataFrame: EHR data with correlated features removed.
"""
generate_train_val_test_set
def generate_train_val_test_set(
ehr_data: pl.DataFrame,
output_path: str = "../outputs/processed_data",
outcome_col: str = "in_hosp_death",
output_summary_path: str = "../outputs/exp_data",
seed: int = 0,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
test_ratio: float = 0.1,
cont_cols: list = None,
nn_cols: list = None,
disp_dict: dict = None,
stratify: bool = True,
verbose: bool = True,
) -> dict:
"""
Create train/val/test split from static EHR data and save patient IDs across each split.
Args:
ehr_data (pl.DataFrame): Static EHR dataset.
output_path (str): Directory to save split IDs.
outcome_col (str): Outcome column name.
output_summary_path (str): Directory to save summary.
seed (int): Random seed.
train_ratio (float): Proportion for training set.
val_ratio (float): Proportion for validation set.
test_ratio (float): Proportion for test set.
cont_cols (list): Continuous columns.
nn_cols (list): Non-normal columns.
disp_dict (dict): Display name mapping.
stratify (bool): If True, stratify the splits, balancing them by outcome prevalence, gender, and ethnicity.
verbose (bool): If True, print summary.
Returns:
dict: Dictionary with train, val, and test DataFrames.
"""
clean_notes
def clean_notes(notes: pl.DataFrame | pl.LazyFrame) -> pl.DataFrame | pl.LazyFrame:
"""
Clean notes data by removing special characters and extra whitespace.
Args:
notes (pl.DataFrame | pl.LazyFrame): Notes data.
Returns:
pl.DataFrame or pl.LazyFrame: Cleaned notes data.
"""
process_text_to_embeddings
def process_text_to_embeddings(notes: pl.DataFrame) -> dict:
"""
Generate embeddings using the Bio+Discharge ClinicalBERT model pre-trained on MIMIC-III discharge summaries.
The current setup uses a spaCy tokenizer mapped to a PyTorch object for GPU support.
Text length is limited to 128 tokens per clinical note, with padding and truncation applied where appropriate.
The pre-trained model is provided by Alsentzer et al. (https://huggingface.co/emilyalsentzer/Bio_Discharge_Summary_BERT).
Args:
notes (pl.DataFrame): DataFrame containing notes data.
Returns:
dict: Mapping from subject_id to list of (sentence, embedding) pairs.
"""
clean_labevents
def clean_labevents(labs_data: pl.LazyFrame) -> pl.LazyFrame:
"""
Clean lab events by removing non-integer values and outliers.
Args:
labs_data (pl.LazyFrame): Lab events data.
Returns:
pl.LazyFrame: Cleaned lab events.
"""
add_time_elapsed_to_events
def add_time_elapsed_to_events(
events: pl.DataFrame, starttime: pl.Datetime, remove_charttime: bool = False
) -> pl.DataFrame:
"""
Add a column for time elapsed since a reference start time.
Args:
events (pl.DataFrame): Events table.
starttime (pl.Datetime): Reference start time.
remove_charttime (bool): If True, remove charttime column.
Returns:
pl.DataFrame: Updated events table.
"""
convert_events_to_timeseries
def convert_events_to_timeseries(events: pl.DataFrame) -> pl.DataFrame:
"""
Convert long-form events to wide-form time-series.
Args:
events (pl.DataFrame): Long-form events.
Returns:
pl.DataFrame: Wide-form time-series.
"""
generate_interval_dataset
def generate_interval_dataset(
ehr_static: pl.DataFrame,
ts_data: pl.DataFrame,
ehr_regtime: pl.DataFrame,
vitals_freq: str = "5h",
lab_freq: str = "1h",
min_events: int = None,
max_events: int = None,
impute: str = "value",
include_dyn_mean: bool = False,
no_resample: bool = False,
standardize: bool = False,
max_elapsed: int = None,
vitals_lkup: list = None,
outcomes: list = None,
verbose: bool = True,
) -> dict:
"""
Generate a time-series dataset with set intervals for each event source (vital signs and lab measurements).
Args:
ehr_static (pl.DataFrame): Static EHR data.
ts_data (pl.DataFrame): Time-series data.
ehr_regtime (pl.DataFrame): Lookup dataframe for ED arrival times.
vitals_freq (str): Frequency for vitals resampling.
lab_freq (str): Frequency for labs resampling.
min_events (int): Include only patients with a minimum number of events.
max_events (int): Include only patients with a maximum number of events.
impute (str): Imputation method. Options are "value" (fill with -1), "forward" (forward fill), "backward" (backward fill), or "mask" (create a string indicator for missingness).
include_dyn_mean (bool): If True, add dynamic mean features to static dataset.
no_resample (bool): If True, skip resampling.
standardize (bool): If True, standardize data using min-max scaling.
max_elapsed (int): Restrict collected measurements within the set hours from ED arrival.
vitals_lkup (list): List of vital sign features.
outcomes (list): List of outcome columns.
verbose (bool): If True, print summary.
Returns:
dict: Data dictionary and column dictionary.
"""
_prepare_feature_map_and_freq
def _prepare_feature_map_and_freq(
ts_data: pl.DataFrame, vitals_freq: str = "5h", lab_freq: str = "1h"
) -> tuple[dict, dict]:
"""
Prepare a mapping of feature names and frequency for each time-series source.
Args:
ts_data (pl.DataFrame): Time-series data containing a 'linksto' column.
vitals_freq (str): Frequency for vital signs.
lab_freq (str): Frequency for lab measurements.
Returns:
tuple: (feature_map, freq) where feature_map is a dict mapping data source to features,
and freq is a dict mapping data source to frequency string.
"""
_process_patient_events
def _process_patient_events(
pt_events: pl.DataFrame,
feature_map: dict,
freq: dict,
ehr_static: pl.DataFrame,
edregtime: pl.Datetime,
min_events: int = 1,
max_events: int = None,
impute: str = "value",
include_dyn_mean: bool = False,
no_resample: bool = False,
max_elapsed: int = None,
) -> tuple[bool, list[pl.DataFrame]]:
"""
Process time-series events for a single patient, handling missing features, imputation, resampling, and filtering.
Args:
pt_events (pl.DataFrame): Patient's time-series events.
feature_map (dict): Mapping from source to feature names.
freq (dict): Mapping from source to frequency string.
ehr_static (pl.DataFrame): Static EHR data for the patient.
edregtime (pl.Datetime): ED registration time for the patient.
min_events (int): Minimum number of measurements required.
max_events (int): Maximum number of measurements allowed.
impute (str): Imputation method. Options are "value" (fill with -1), "forward" (forward fill), "backward" (backward fill), or "mask" (create a string indicator for missingness).
include_dyn_mean (bool): If True, add dynamic mean features.
no_resample (bool): If True, skip resampling.
max_elapsed (int): Restrict collected measurements within the set hours from ED arrival.
Returns:
tuple: (write_data, ts_data_list, skipped_due_to_event_count, skipped_due_to_elapsed_time)
"""
_validate_event_count
def _validate_event_count(
timeseries: pl.DataFrame, min_events: int = 1, max_events: int = 1e6
) -> bool:
"""
Check if the number of events in the timeseries is within the specified range.
Args:
timeseries (pl.DataFrame): Time-series data.
min_events (int): Minimum number of events.
max_events (int): Maximum number of events.
Returns:
bool: True if within range, False otherwise.
"""
_handle_missing_features
def _handle_missing_features(
timeseries: pl.DataFrame, features: list[str] = None
) -> pl.DataFrame:
"""
Add missing columns to the timeseries DataFrame as nulls.
Args:
timeseries (pl.DataFrame): Time-series data.
features (list): List of required feature names.
Returns:
pl.DataFrame: Time-series data with missing columns added as nulls.
"""
_impute_missing_values
def _impute_missing_values(
timeseries: pl.DataFrame, ehr_static: pl.DataFrame, impute: str = "value"
) -> tuple[pl.DataFrame, pl.DataFrame]:
"""
Impute missing values in time-series and static EHR data.
Args:
timeseries (pl.DataFrame): Time-series data.
ehr_static (pl.DataFrame): Static EHR data.
impute (str): Imputation method ("mask", "forward", "backward", "value").
Returns:
tuple: (imputed_timeseries, imputed_ehr_static)
"""
_add_dynamic_mean
def _add_dynamic_mean(
timeseries: pl.DataFrame, ehr_static: pl.DataFrame
) -> pl.DataFrame:
"""
Add mean of dynamic features to the static EHR data.
Args:
timeseries (pl.DataFrame): Time-series data.
ehr_static (pl.DataFrame): Static EHR data.
Returns:
pl.DataFrame: Static EHR data with dynamic means appended.
"""
_resample_timeseries
def _resample_timeseries(timeseries: pl.DataFrame, freq: str = "1h") -> pl.DataFrame:
"""
Resample the time-series data to a specified frequency.
Args:
timeseries (pl.DataFrame): The input time-series data.
freq (str): The frequency for resampling (e.g., "1h").
Returns:
pl.DataFrame: The resampled time-series data.
"""
_standardize_data
def _standardize_data(ts_data: pl.DataFrame) -> pl.DataFrame:
"""
Standardize the 'value' column in the time-series data using min-max scaling.
Args:
ts_data (pl.DataFrame): The input time-series data.
Returns:
pl.DataFrame: Standardized time-series data.
"""
_print_summary
def _print_summary(
n: int = 0,
filter_by_nb_events: int = 0,
missing_event_src: int = 0,
filter_by_elapsed_time: int = 0,
) -> None:
"""
Print a summary of the time-series interval generation process.
Args:
n (int): Number of successfully processed patients.
filter_by_nb_events (int): Number of patients skipped due to event count.
missing_event_src (int): Number of patients skipped due to missing sources.
filter_by_elapsed_time (int): Number of patients skipped due to elapsed time.
Returns:
None
"""