PyTorch pipeline for multimodal learning
This document provides an overview of the classes and functions defined in src.models. Each class or function is listed with its signature and docstring.
LSTM
class LSTM(nn.Module):
"""
General-purpose LSTM module for sequence embeddings.
Args:
input_dim (int): Input feature dimension.
embed_dim (int): Output embedding dimension.
num_layers (int): Number of LSTM layers.
hidden_dim (int): Hidden state dimension.
dropout (float): Dropout rate.
Returns:
torch.Tensor: Embedded output.
"""
Gate
class Gate(nn.Module):
"""
Gated fusion module for combining multiple input embeddings.
Adapted from https://github.com/emnlp-mimic/mimic/blob/main/base.py#L136, inspired by https://arxiv.org/pdf/1908.05787.
Args:
inp1_size (int): Size of first input.
inp2_size (int): Size of second input.
inp3_size (int): Size of third input (optional).
dropout (float): Dropout rate.
Returns:
torch.Tensor: Fused output.
"""
GradientReversalFunction
class GradientReversalFunction(torch.autograd.Function):
"""
Gradient reversal layer for adversarial training.
Used to reverse gradients during backpropagation for adversarial debiasing.
"""
grad_reverse
def grad_reverse(x, lambda_=1.0):
"""
Apply gradient reversal to the input tensor to enable maximisation of the adversarial objective function.
Args:
x (torch.Tensor): Input tensor.
lambda_ (float): Scaling factor for gradient reversal.
Returns:
torch.Tensor: Output tensor with reversed gradients.
"""
MMModel
class MMModel(L.LightningModule):
"""
Multimodal model object for fusion of static, timeseries, and notes data.
Args:
st_input_dim (int): Static input dimension.
st_embed_dim (int): Static embedding dimension.
ts_input_dim (tuple): Tuple of timeseries input dimensions.
ts_embed_dim (int): Timeseries embedding dimension.
nt_input_dim (int): Notes input dimension.
nt_embed_dim (int): Notes embedding dimension.
num_layers (int): Number of LSTM layers.
dropout (float): Dropout rate.
num_ts (int): Number of timeseries modalities.
target_size (int): Output size.
lr (float): Learning rate.
fusion_method (str): Fusion method ("concat" or "mag").
st_first (bool): If True, static features are fused first.
modalities (list): List of modalities to use.
with_packed_sequences (bool): If True, use packed sequences for timeseries.
dataset (MIMIC4Dataset): Pass training dataset if using class weighting, else None.
sensitive_attr_ids (list): Indices of sensitive attribute features from static data for adversarial debiasing.
adv_lambda (float): Strength of the adversarial penalty: 0 disables it, 0.1-0.2 applies a slight penalty, and values >= 1 apply a strong penalty.
Returns:
torch.Tensor: Model output.
"""
LitLSTM
class LitLSTM(L.LightningModule):
"""
LSTM model for time-series data only.
Args:
ts_input_dim (int): Timeseries input dimension.
lstm_embed_dim (int): LSTM embedding dimension.
target_size (int): Output size.
lr (float): Learning rate.
with_packed_sequences (bool): If True, use packed sequences.
Returns:
torch.Tensor: Model output.
"""
SaveLossesCallback
class SaveLossesCallback(Callback):
"""
Lightning callback to save train/validation losses to a CSV file.
"""