ctgan_sampler
CTGANConditionalSampler
Manages conditional vector construction and conditioned data sampling for CTGAN.
At each training step, CTGAN:

1. Randomly selects one categorical column.
2. Samples a category from that column's empirical distribution.
3. Builds a one-hot condition vector spanning all categorical columns.
4. Resamples real training rows that have the selected category active.
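The four steps can be sketched for a single sample with NumPy. This is a minimal illustration, not the nhssynth implementation. The toy `data` array, the function name `one_training_step_condition`, and the use of NumPy arrays instead of torch tensors are all assumptions. Category frequencies are taken as the column sums of each OHE block, which assumes an exact 0/1 encoding.

```python
import numpy as np

# Hypothetical toy transformed data with two OHE categorical columns:
# column 0 occupies indices 0-1, column 1 occupies indices 2-4, and
# index 5 is a continuous feature outside every categorical group.
data = np.array([
    [1, 0, 0, 1, 0, 0.3],
    [1, 0, 0, 0, 1, -1.2],
    [0, 1, 1, 0, 0, 0.8],
    [1, 0, 0, 1, 0, 0.1],
], dtype=float)
categorical_groups = [[0, 1], [2, 3, 4]]


def one_training_step_condition(data, groups, rng):
    """Hypothetical sketch: return (cond, row) for one sample via steps 1-4."""
    # 1. Pick one categorical column uniformly at random.
    col = rng.integers(len(groups))
    group = groups[col]
    # 2. Sample a category in proportion to its frequency in the data
    #    (frequencies are the column sums of the column's OHE block).
    counts = data[:, group].sum(axis=0)
    cat = rng.choice(len(group), p=counts / counts.sum())
    # 3. One-hot condition vector over all categories of all categorical
    #    columns; offsets mark where each column's block starts.
    offsets = np.cumsum([0] + [len(g) for g in groups])
    cond = np.zeros(offsets[-1])
    cond[offsets[col] + cat] = 1.0
    # 4. Resample a real row in which the chosen category is active.
    #    (Categories with zero count are never drawn, so rows is non-empty.)
    rows = np.flatnonzero(data[:, group[cat]] == 1)
    row = data[rng.choice(rows)]
    return cond, row


rng = np.random.default_rng(0)
cond, row = one_training_step_condition(data, categorical_groups, rng)
```

The sketch samples categories by raw frequency, as the description above states. The original CTGAN paper samples by log-frequency instead, and this page does not say which rule nhssynth follows.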
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Tensor | Full training tensor of shape | required |
| categorical_groups | list[list[int]] | List of index groups corresponding to OHE categorical columns in the transformed space (one group per original categorical column). | required |
Source code in src/nhssynth/modules/model/common/ctgan_sampler.py
sample_condvec(batch_size)
Sample a batch of condition vectors.
Returns:
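| Name | Type | Description |
|---|---|---|
| cond | Tensor | Condition vectors: for each sample, a one-hot vector spanning all categorical columns, with the sampled category set. |
| mask | Tensor | |
| col_idxs | ndarray | The categorical column selected for each sample. |
| cat_idxs | ndarray | The category sampled from the selected column, for each sample. |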
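A batched sampler returning the same four outputs might look like the following. It is a hypothetical sketch (`sample_condvec_sketch` is not part of nhssynth) and uses NumPy arrays where the real method returns tensors for `cond` and `mask`. The layout of `mask` is an assumption borrowed from the reference CTGAN implementation, where it is a one-hot over categorical columns marking the selected column; the table above does not define it. `cat_idxs` is likewise assumed to index categories within the selected column.

```python
import numpy as np


def sample_condvec_sketch(data, groups, batch_size, rng):
    """Hypothetical batched sampler mirroring the four documented outputs."""
    n_cols = len(groups)
    sizes = [len(g) for g in groups]
    offsets = np.cumsum([0] + sizes)  # start index of each column's block in cond
    # Empirical category probabilities per categorical column (OHE column sums).
    probs = [data[:, g].sum(axis=0) / data[:, g].sum() for g in groups]
    # One categorical column per sample, chosen uniformly.
    col_idxs = rng.integers(n_cols, size=batch_size)
    # One category per sample, indexed within its own column (assumption).
    cat_idxs = np.array(
        [rng.choice(sizes[c], p=probs[c]) for c in col_idxs], dtype=np.int64
    )
    rows = np.arange(batch_size)
    # One-hot condition vectors spanning all categorical columns.
    cond = np.zeros((batch_size, offsets[-1]), dtype=np.float32)
    cond[rows, offsets[col_idxs] + cat_idxs] = 1.0
    # Assumed mask layout: one-hot over categorical columns marking the chosen one.
    mask = np.zeros((batch_size, n_cols), dtype=np.float32)
    mask[rows, col_idxs] = 1.0
    return cond, mask, col_idxs, cat_idxs


# Hypothetical toy data: column 0 -> indices [0, 1], column 1 -> indices [2, 3, 4].
data = np.array([
    [1, 0, 0, 1, 0],
    [1, 0, 0, 0, 1],
    [0, 1, 1, 0, 0],
    [1, 0, 0, 1, 0],
], dtype=float)
categorical_groups = [[0, 1], [2, 3, 4]]

rng = np.random.default_rng(1)
cond, mask, col_idxs, cat_idxs = sample_condvec_sketch(data, categorical_groups, 8, rng)
```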
sample_data_conditioned(batch_size, col_idxs, cat_idxs)
Sample real training rows conditioned on each (col, category) pair.
For each sample, picks a random row from the training data where the specified category is active in the specified column. Falls back to a random row if no such rows exist.
Returns:
| Type | Description |
|---|---|
| Tensor | A batch of real training rows, one for each (col, category) pair. |
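Conditioned row sampling with the documented fallback can be sketched as below. Both `build_row_index` and `sample_data_conditioned_sketch` are hypothetical. The helper precomputes, once, the active rows for every (col, category) pair so that each batch does not rescan the data; whether nhssynth does this is not stated here. As with the other sketches, it works on NumPy arrays rather than tensors.

```python
import numpy as np


def build_row_index(data, groups):
    """Hypothetical helper: active row indices for every (col, category) pair."""
    return [[np.flatnonzero(data[:, idx] == 1) for idx in g] for g in groups]


def sample_data_conditioned_sketch(data, row_index, batch_size, col_idxs, cat_idxs, rng):
    """Hypothetical sketch: one real row per requested (col, category) pair."""
    picked = np.empty(batch_size, dtype=np.int64)
    for i in range(batch_size):
        candidates = row_index[col_idxs[i]][cat_idxs[i]]
        if len(candidates) > 0:
            # A random row in which the requested category is active.
            picked[i] = rng.choice(candidates)
        else:
            # Documented fallback: any random training row.
            picked[i] = rng.integers(len(data))
    return data[picked]


# Toy data: column 0 -> indices [0, 1], column 1 -> indices [2, 3, 4].
# Index 4 is never active, so requesting (col 1, category 2) uses the fallback.
data = np.array([
    [1, 0, 1, 0, 0],
    [0, 1, 0, 1, 0],
    [1, 0, 0, 1, 0],
], dtype=float)
categorical_groups = [[0, 1], [2, 3, 4]]

rng = np.random.default_rng(0)
row_index = build_row_index(data, categorical_groups)
batch = sample_data_conditioned_sketch(
    data, row_index, 3, np.array([0, 1, 1]), np.array([1, 0, 2]), rng
)
```

Here the first two requests each match exactly one row, so their results are deterministic; only the third depends on the random fallback.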
extract_categorical_groups(multi_column_indices, columns)
Identify which groups in multi_column_indices are OHE categorical columns
(as opposed to GMM component columns, whose names end in _c<digit>).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| multi_column_indices | list[list[int]] | From | required |
| columns | | Column names ( | required |
Returns:
| Type | Description |
|---|---|
| list[list[int]] | Subset of |
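A name-based filter of this kind can be sketched with a regular expression. The pattern `_c\d+$`, the rule that a group is dropped if any of its column names matches, and the example column names are assumptions for illustration, not the nhssynth code.

```python
import re

# Assumed pattern for GMM component columns, e.g. "age_c0" or "age_c12".
GMM_COMPONENT = re.compile(r"_c\d+$")


def extract_categorical_groups_sketch(multi_column_indices, columns):
    """Hypothetical sketch: keep groups with no GMM-component-style names."""
    return [
        group
        for group in multi_column_indices
        if not any(GMM_COMPONENT.search(columns[i]) for i in group)
    ]


# Hypothetical transformed column names and index groups.
columns = ["age_normalised", "age_c0", "age_c1", "sex_F", "sex_M",
           "region_North", "region_South"]
multi_column_indices = [[1, 2], [3, 4], [5, 6]]
cat_groups = extract_categorical_groups_sketch(multi_column_indices, columns)
# → [[3, 4], [5, 6]]
```

One consequence of any purely name-based rule is that a categorical column whose OHE names happen to end in `_c<digit>` (for example, a category labelled `c1`) would be misclassified as a GMM component group.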