Skip to content

PyTorch pipeline for multimodal learning

This document provides an overview of the classes and functions defined in src.models. Each class or function is listed with its signature and docstring.


LSTM

class LSTM(nn.Module):
    """
    General-purpose LSTM module for sequence embeddings.

    Args:
        input_dim (int): Input feature dimension.
        embed_dim (int): Output embedding dimension.
        num_layers (int): Number of LSTM layers.
        hidden_dim (int): Hidden state dimension.
        dropout (float): Dropout rate.

    Returns:
        torch.Tensor: Embedded output.
    """

Gate

class Gate(nn.Module):
    """
    Gated fusion module for combining multiple input embeddings.
    Adapted from https://github.com/emnlp-mimic/mimic/blob/main/base.py#L136; inspired by https://arxiv.org/pdf/1908.05787.

    Args:
        inp1_size (int): Size of first input.
        inp2_size (int): Size of second input.
        inp3_size (int): Size of third input (optional).
        dropout (float): Dropout rate.

    Returns:
        torch.Tensor: Fused output.
    """

GradientReversalFunction

class GradientReversalFunction(torch.autograd.Function):
    """
    Gradient reversal layer for adversarial training.

    Used to reverse gradients during backpropagation for adversarial debiasing.
    """

grad_reverse

def grad_reverse(x, lambda_=1.0):
    """
    Apply gradient reversal to the input tensor to enable maximisation of the adversarial objective function.

    Args:
        x (torch.Tensor): Input tensor.
        lambda_ (float): Scaling factor for gradient reversal.

    Returns:
        torch.Tensor: Output tensor with reversed gradients.
    """

MMModel

class MMModel(L.LightningModule):
    """
    Multimodal model object for fusion of static, timeseries, and notes data.

    Args:
        st_input_dim (int): Static input dimension.
        st_embed_dim (int): Static embedding dimension.
        ts_input_dim (tuple): Tuple of timeseries input dimensions.
        ts_embed_dim (int): Timeseries embedding dimension.
        nt_input_dim (int): Notes input dimension.
        nt_embed_dim (int): Notes embedding dimension.
        num_layers (int): Number of LSTM layers.
        dropout (float): Dropout rate.
        num_ts (int): Number of timeseries modalities.
        target_size (int): Output size.
        lr (float): Learning rate.
        fusion_method (str): Fusion method ("concat" or "mag").
        st_first (bool): If True, static features are fused first.
        modalities (list): List of modalities to use.
        with_packed_sequences (bool): If True, use packed sequences for timeseries.
        dataset (MIMIC4Dataset): Training dataset, required only when class weighting is used; otherwise None.
        sensitive_attr_ids (list): Indices of sensitive attribute features from static data for adversarial debiasing.
        adv_lambda (float): Strength of the adversarial penalty: 0 applies no penalty, 0.1-0.2 is a slight penalty, and >=1 is a strong penalty.

    Returns:
        torch.Tensor: Model output.
    """

LitLSTM

class LitLSTM(L.LightningModule):
    """
    LSTM model for timeseries data only.

    Args:
        ts_input_dim (int): Timeseries input dimension.
        lstm_embed_dim (int): LSTM embedding dimension.
        target_size (int): Output size.
        lr (float): Learning rate.
        with_packed_sequences (bool): If True, use packed sequences.

    Returns:
        torch.Tensor: Model output.
    """

SaveLossesCallback

class SaveLossesCallback(Callback):
    """
    Learner callback to save train/validation losses to a CSV file.
    """