MetaTransformer

The metatransformer is responsible for transforming the input dataset into a format that can be used by the model module, and for transforming this module's output back to the original format of the input dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dataset` | `DataFrame` | The raw input DataFrame. | *required* |
| `metadata` | `Optional[MetaData]` | Optionally, a `MetaData` object containing the metadata for the dataset. If this is not provided it will be inferred from the dataset. | `None` |
| `missingness_strategy` | `Optional[str]` | The missingness strategy to use. Defaults to augmenting missing values in the data; see the missingness strategies for more information. | `'augment'` |
| `impute_value` | `Optional[Any]` | Only used when `missingness_strategy` is set to `'impute'`. The value to use when imputing missing values in the data. | `None` |
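
For intuition, the `'impute'` strategy amounts to a constant fill over missing values. A plain-pandas sketch (the column name and fill value are illustrative; this is not the library's internal code):

```python
import pandas as pd

# Hypothetical column with a missing entry
df = pd.DataFrame({"score": [0.5, None, 0.8]})

# With missingness_strategy="impute", a user-supplied impute_value
# replaces every missing entry in the column
impute_value = 0.0
df["score"] = df["score"].fillna(impute_value)
```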

After calling MetaTransformer.apply(), the following attributes and methods will be available:

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `typed_dataset` | `DataFrame` | The dataset with the dtypes applied. |
| `post_missingness_strategy_dataset` | `DataFrame` | The dataset with the missingness strategies applied. |
| `transformed_dataset` | `DataFrame` | The transformed dataset. |
| `single_column_indices` | `list[int]` | The indices of the columns that were transformed into a single column. |
| `multi_column_indices` | `list[list[int]]` | The indices of the columns that were transformed into multiple columns. |

Methods:

  • get_typed_dataset(): Returns the typed dataset.
  • get_prepared_dataset(): Returns the dataset with the missingness strategies applied.
  • get_transformed_dataset(): Returns the transformed dataset.
  • get_multi_and_single_column_indices(): Returns the indices of the columns that were transformed into multiple columns and those transformed into a single column.
  • get_sdv_metadata(): Returns the metadata in the correct format for SDMetrics.
  • save_metadata(): Saves the metadata to a file.
  • save_constraint_graphs(): Saves the constraint graphs to a file.

Note that mt.apply is a helper function that runs mt.apply_dtypes, mt.apply_missingness_strategy and mt.transform in sequence. This is the recommended way to use the MetaTransformer to ensure that it is fully instantiated for use downstream.
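
The default 'augment' strategy preserves missingness information instead of discarding it. A plain-pandas sketch of the idea (the indicator column name and the fill value are illustrative, not the library's exact scheme):

```python
import pandas as pd

df = pd.DataFrame({"age": [34.0, None, 58.0]})

# Record where values were missing before filling them in, so a
# downstream model can learn (and later reproduce) the missingness pattern
df["age_missing"] = df["age"].isna().astype(int)
df["age"] = df["age"].fillna(df["age"].mean())  # illustrative fill value
```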

Source code in src/nhssynth/modules/dataloader/metatransformer.py
class MetaTransformer:
    """
    The metatransformer is responsible for transforming the input dataset into a format that can be used by the `model` module, and for transforming
    this module's output back to the original format of the input dataset.

    Args:
        dataset: The raw input DataFrame.
        metadata: Optionally, a [`MetaData`][nhssynth.modules.dataloader.metadata.MetaData] object containing the metadata for the dataset. If this is not provided it will be inferred from the dataset.
        missingness_strategy: The missingness strategy to use. Defaults to augmenting missing values in the data, see [the missingness strategies][nhssynth.modules.dataloader.missingness] for more information.
        impute_value: Only used when `missingness_strategy` is set to 'impute'. The value to use when imputing missing values in the data.

    After calling `MetaTransformer.apply()`, the following attributes and methods will be available:

    Attributes:
        typed_dataset (pd.DataFrame): The dataset with the dtypes applied.
        post_missingness_strategy_dataset (pd.DataFrame): The dataset with the missingness strategies applied.
        transformed_dataset (pd.DataFrame): The transformed dataset.
        single_column_indices (list[int]): The indices of the columns that were transformed into a single column.
        multi_column_indices (list[list[int]]): The indices of the columns that were transformed into multiple columns.

    **Methods:**

    - `get_typed_dataset()`: Returns the typed dataset.
    - `get_prepared_dataset()`: Returns the dataset with the missingness strategies applied.
    - `get_transformed_dataset()`: Returns the transformed dataset.
    - `get_multi_and_single_column_indices()`: Returns the indices of the columns that were transformed into multiple columns and those transformed into a single column.
    - `get_sdv_metadata()`: Returns the metadata in the correct format for SDMetrics.
    - `save_metadata()`: Saves the metadata to a file.
    - `save_constraint_graphs()`: Saves the constraint graphs to a file.

    Note that `mt.apply` is a helper function that runs `mt.apply_dtypes`, `mt.apply_missingness_strategy` and `mt.transform` in sequence.
    This is the recommended way to use the MetaTransformer to ensure that it is fully instantiated for use downstream.
    """

    def __init__(
        self,
        dataset: pd.DataFrame,
        metadata: Optional[MetaData] = None,
        missingness_strategy: Optional[str] = "augment",
        impute_value: Optional[Any] = None,
    ):
        self._raw_dataset: pd.DataFrame = dataset
        self._metadata: MetaData = metadata or MetaData(dataset)
        if missingness_strategy == "impute":
            if impute_value is None:
                raise ValueError(
                    "`impute_value` of the `MetaTransformer` must be specified (via the `--impute` flag) when using the imputation missingness strategy"
                )
            self._impute_value = impute_value
        self._missingness_strategy = MISSINGNESS_STRATEGIES[missingness_strategy]

    @classmethod
    def from_path(cls, dataset: pd.DataFrame, metadata_path: str, **kwargs) -> Self:
        """
        Instantiates a MetaTransformer from a metadata file via a provided path.

        Args:
            dataset: The raw input DataFrame.
            metadata_path: The path to the metadata file.

        Returns:
            A MetaTransformer object.
        """
        return cls(dataset, MetaData.from_path(dataset, metadata_path), **kwargs)

    @classmethod
    def from_dict(cls, dataset: pd.DataFrame, metadata: dict, **kwargs) -> Self:
        """
        Instantiates a MetaTransformer from a metadata dictionary.

        Args:
            dataset: The raw input DataFrame.
            metadata: A dictionary of raw metadata.

        Returns:
            A MetaTransformer object.
        """
        return cls(dataset, MetaData(dataset, metadata), **kwargs)

    def drop_columns(self) -> None:
        """
        Drops columns from the dataset that are not in the `MetaData`.
        """
        self._raw_dataset = self._raw_dataset[self._metadata.columns]

    def _apply_rounding_scheme(self, working_column: pd.Series, rounding_scheme: float) -> pd.Series:
        """
        A rounding scheme takes the form of the smallest value that should be rounded to 0, i.e. 0.01 for 2dp.
        We first round to the nearest multiple in the standard way, through dividing, rounding and then multiplying.
        However, this can lead to floating point errors, so we then round to the number of decimal places required by the rounding scheme.

        e.g. `np.round(0.15 / 0.1) * 0.1` will erroneously return 0.1.

        Args:
            working_column: The column to apply the rounding scheme to.
            rounding_scheme: The rounding scheme to apply.

        Returns:
            The column with the rounding scheme applied.
        """
        # Check if column is numeric before applying rounding
        if not pd.api.types.is_numeric_dtype(working_column):
            # Try to coerce to numeric first
            working_column = pd.to_numeric(working_column, errors="coerce")
            if not pd.api.types.is_numeric_dtype(working_column):
                # If still not numeric, skip rounding (likely datetime or categorical)
                tqdm.write(f"Warning: skipping rounding for non-numeric column '{working_column.name}'")
                return working_column

        working_column = np.round(working_column / rounding_scheme) * rounding_scheme
        return working_column.round(max(0, int(np.ceil(np.log10(1 / rounding_scheme)))))

    def _apply_dtype(
        self,
        working_column: pd.Series,
        column_metadata: MetaData.ColumnMetaData,
    ) -> pd.Series:
        """
        Given a `working_column`, the dtype specified in the `column_metadata` is applied to it.
         - Datetime columns are floored, and their format is inferred.
         - Rounding schemes are applied to numeric columns if specified.
         - Columns with missing values have their dtype converted to the pandas equivalent to allow for NA values.

        Args:
            working_column: The column to apply the dtype to.
            column_metadata: The metadata for the column.

        Returns:
            The column with the dtype applied.
        """
        dtype = column_metadata.dtype
        try:
            if dtype.kind == "M":
                working_column = pd.to_datetime(
                    working_column,
                    format=column_metadata.datetime_config.get("format"),
                    errors="coerce",
                )
                if column_metadata.datetime_config.get("floor"):
                    working_column = working_column.dt.floor(column_metadata.datetime_config.get("floor"))
                    column_metadata.datetime_config["format"] = column_metadata._infer_datetime_format(working_column)
                return working_column
            else:
                if hasattr(column_metadata, "rounding_scheme") and column_metadata.rounding_scheme is not None:
                    working_column = self._apply_rounding_scheme(working_column, column_metadata.rounding_scheme)
                # If there are missing values in the column, we need to use the pandas equivalent of the dtype to allow for NA values
                if working_column.isnull().any() and dtype.kind in ["i", "u", "f"]:
                    return working_column.astype(dtype.name.capitalize())
                else:
                    return working_column.astype(dtype)
        except ValueError as e:
            raise ValueError(f"Error applying dtype '{dtype}' to column '{working_column.name}'") from e

    def apply_dtypes(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Applies the dtypes from the metadata to `data`.

        Args:
            data: The dataset to apply the dtypes to.

        Returns:
            The dataset with the dtypes applied.
        """
        working_data = data.copy()
        for column_metadata in self._metadata:
            working_data[column_metadata.name] = self._apply_dtype(working_data[column_metadata.name], column_metadata)
        return working_data

    def apply_missingness_strategy(self) -> pd.DataFrame:
        """
        Resolves missingness in the dataset via the `MetaTransformer`'s global missingness strategy or
        column-wise missingness strategies. In the case of the `AugmentMissingnessStrategy`, the missingness
        is not resolved, instead a new column / value is added for later transformation.

        Returns:
            The dataset with the missingness strategies applied.
        """
        working_data = self.typed_dataset.copy()
        for column_metadata in self._metadata:
            if not column_metadata.missingness_strategy:
                column_metadata.missingness_strategy = (
                    self._missingness_strategy(self._impute_value)
                    if hasattr(self, "_impute_value")
                    else self._missingness_strategy()
                )
            if not working_data[column_metadata.name].isnull().any():
                continue
            working_data = column_metadata.missingness_strategy.remove(working_data, column_metadata)
        return working_data

    def apply_constraints(self) -> pd.DataFrame:
        """
        Applies the constraints from the metadata to the dataset.

        Returns:
            The dataset with the constraints applied.
        """
        working_data = self.post_missingness_strategy_dataset.copy()
        for constraint in self._metadata.constraints:
            working_data = constraint.transform(working_data)
        return working_data

    def repair_constraints(
        self, df: pd.DataFrame, *, mode: str = "reflect", rng=None, n_retries: int = 0
    ) -> pd.DataFrame:
        """
        Enforce constraints on a *decoded* DataFrame using
        self._metadata.constraints.minimal_constraints.

        Supports:
        - Numeric constant constraints (e.g., x > 0, x in (0, 100))
        - Column reference constraints (e.g., x8 > x10)

        Args:
            df: DataFrame to repair
            mode: Repair strategy ('reflect', 'resample', 'clamp')
            rng: Random number generator
            n_retries: Number of retries (unused for now)

        Returns:
            DataFrame with constraints enforced
        """
        import numpy as np
        import pandas as pd

        if rng is None:
            rng = np.random.default_rng()

        constraints = getattr(getattr(self, "_metadata", None), "constraints", None)
        if constraints is None:
            return df
        constraints_iterable = getattr(constraints, "minimal_constraints", None)
        if constraints_iterable is None:
            return df

        repaired = df.copy()
        violations_fixed = 0

        # ComboConstraint is defined on ConstraintGraph; import it to check instances
        from nhssynth.modules.dataloader.constraints import ConstraintGraph

        # Convert to list to avoid consuming iterator
        constraints_list = list(constraints_iterable)

        # Iterate through minimal constraints and repair violations
        for constraint in constraints_list:
            # Skip combo constraints for now (they require special handling)
            if isinstance(constraint, ConstraintGraph.ComboConstraint):
                continue

            base = constraint.base
            operator = constraint.operator
            reference = constraint.reference
            is_column = constraint.reference_is_column

            # Skip if base column doesn't exist in dataframe
            if base not in repaired.columns:
                continue

            # Get base values
            base_vals = pd.to_numeric(repaired[base], errors="coerce")

            # Handle column reference constraints (e.g., x8 > x10)
            if is_column:
                if reference not in repaired.columns:
                    continue
                ref_vals = pd.to_numeric(repaired[reference], errors="coerce")
            else:
                # Handle numeric constant constraints (e.g., x > 0)
                try:
                    ref_vals = float(reference)
                except (TypeError, ValueError):
                    # Could be a datetime or other non-numeric reference; skip for now
                    continue

            # Identify violations
            if operator == ">":
                violations = base_vals <= ref_vals
            elif operator == ">=":
                violations = base_vals < ref_vals
            elif operator == "<":
                violations = base_vals >= ref_vals
            elif operator == "<=":
                violations = base_vals > ref_vals
            else:
                continue  # Unknown operator

            # Only repair finite violations (skip NaN)
            violations = violations & np.isfinite(base_vals)

            n_viol = violations.sum()
            if n_viol == 0:
                continue

            violations_fixed += n_viol

            # Repair strategy
            if operator in [">", ">="]:
                # Base must be greater than reference
                if is_column:
                    # Add small margin above reference (5% of reference value or 0.1, whichever is larger)
                    margin = np.maximum(0.1, 0.05 * np.abs(ref_vals[violations]))
                    if operator == ">":
                        repaired.loc[violations, base] = ref_vals[violations] + margin
                    else:  # >=
                        repaired.loc[violations, base] = ref_vals[violations]
                else:
                    # Numeric constant
                    margin = max(0.1, 0.05 * abs(ref_vals))
                    if operator == ">":
                        repaired.loc[violations, base] = ref_vals + margin
                    else:  # >=
                        repaired.loc[violations, base] = ref_vals

            elif operator in ["<", "<="]:
                # Base must be less than reference
                if is_column:
                    margin = np.maximum(0.1, 0.05 * np.abs(ref_vals[violations]))
                    if operator == "<":
                        repaired.loc[violations, base] = ref_vals[violations] - margin
                    else:  # <=
                        repaired.loc[violations, base] = ref_vals[violations]
                else:
                    margin = max(0.1, 0.05 * abs(ref_vals))
                    if operator == "<":
                        repaired.loc[violations, base] = ref_vals - margin
                    else:  # <=
                        repaired.loc[violations, base] = ref_vals

        if DEBUG_VERBOSE and violations_fixed > 0:
            tqdm.write(f"[repair_constraints] Fixed {violations_fixed} constraint violations")

        return repaired

    def _get_missingness_carrier(self, column_metadata: MetaData.ColumnMetaData) -> Union[pd.Series, Any]:
        """
        In the case of the `AugmentMissingnessStrategy`, a `missingness_carrier` has been determined for each column.
        For continuous columns this is an indicator column for the presence of NaN values.
        For categorical columns this is the value to be used to represent missingness as a category.

        Args:
            column_metadata: The metadata for the column.

        Returns:
            The missingness carrier for the column.
        """
        missingness_carrier = getattr(column_metadata.missingness_strategy, "missingness_carrier", None)
        if missingness_carrier in self.post_missingness_strategy_dataset.columns:
            return self.post_missingness_strategy_dataset[missingness_carrier]
        else:
            return missingness_carrier

    def _get_adherence_constraint(self, df: pd.DataFrame) -> pd.Series:
        """Combines the per-column `*_adherence` indicator columns into a single row-wise adherence flag."""
        adherence_columns = [col for col in df.columns if col.endswith("_adherence")]
        return df[adherence_columns].prod(axis=1).astype(int)

    def _call_transformer_apply(
        self,
        transformer,
        *,
        series,
        constraint_adherence=None,
        missingness_column=None,
    ):
        """
        Call transformer.apply with only the kwargs it supports, using **keywords only**.
        - Binds the input series to 'data' if present, otherwise to the first non-self parameter.
        - Only passes constraint_adherence / missingness_column if the transformer declares them.
        """
        fn = transformer.apply
        sig = inspect.signature(fn)
        params = sig.parameters

        # Decide which parameter name to bind the series to
        if "data" in params:
            data_param = "data"
        else:
            data_param = next((n for n, p in params.items() if n != "self"), None)
            if data_param is None:
                # Last resort: the method takes no args beyond self; try calling without any
                return fn()

        kwargs = {data_param: series}

        if "constraint_adherence" in params:
            kwargs["constraint_adherence"] = constraint_adherence
        if "missingness_column" in params:
            kwargs["missingness_column"] = missingness_column

        return fn(**kwargs)

    def transform(self) -> pd.DataFrame:
        """
        Apply each column transformer to its *raw* Series, then concatenate results.
        Ensures each transformer receives a single Series (not the whole/mutated DataFrame),
        which fixes DateTime ('dob') KeyErrors.
        """

        # Prefer the dataset that already has missingness flags computed,
        # but still contains the original raw columns.
        if hasattr(self, "post_missingness_strategy_dataset") and self.post_missingness_strategy_dataset is not None:
            source_df = self.post_missingness_strategy_dataset
        elif hasattr(self, "typed_dataset") and self.typed_dataset is not None:
            source_df = self.typed_dataset
        else:
            # Fallback: raw dataset as last resort
            source_df = self._raw_dataset

        parts = []

        # Helper: get a column if it exists, else None
        def _maybe_col(df: pd.DataFrame, name: str):
            return df[name] if (df is not None and name in df.columns) else None

        for col_meta in self._metadata:
            # Work out the original column name this transformer handles
            col = (
                getattr(col_meta, "name", None)
                or getattr(col_meta, "column", None)
                or getattr(col_meta, "feature", None)
            )
            if col is None:
                raise ValueError(f"Metadata entry missing column name: {col_meta}")

            # Always hand the transformer a *Series* from the original (pre-transform) frame
            if col not in source_df.columns:
                raise KeyError(
                    f"[MetaTransformer.transform] Expected raw column '{col}' in source_df; "
                    f"available={list(source_df.columns)[:15]}..."
                )
            series = source_df[col]

            # Optional per-row flags
            # Missingness: prefer exact "{col}_missing"; if not present, try to find any "<col>_missing*"
            miss = _maybe_col(source_df, f"{col}_missing")
            if miss is None:
                # Fall back to a prefix match in case the indicator column has a variant name
                candidates = [c for c in source_df.columns if c.startswith(f"{col}_missing")]
                miss = source_df[candidates[0]] if candidates else None

            # Constraint adherence: use the flags kept on `constrained_dataset` when available; otherwise default to all-ones
            if hasattr(self, "constrained_dataset") and self.constrained_dataset is not None:
                adh = _maybe_col(self.constrained_dataset, f"{col}_adherence")
            else:
                adh = None

            if adh is None:
                # Default to all-ones (i.e., include all rows during transform)
                adh = pd.Series(1, index=series.index, name=f"{col}_adherence", dtype=int)

            # Apply the per-column transformer
            part = self._call_transformer_apply(
                col_meta.transformer,
                series=series,
                constraint_adherence=adh,
                missingness_column=miss,
            )

            # Normalise to DataFrame
            if isinstance(part, pd.Series):
                part = part.to_frame()

            value_idx_in_part = None
            for cand in (f"{col}_value", f"{col}_normalized", f"{col}_normalised"):
                if cand in part.columns:
                    value_idx_in_part = part.columns.get_loc(cand)
                    break

            # record absolute index for this value column
            if not hasattr(self, "continuous_value_indices"):
                self.continuous_value_indices = []

            abs_offset = sum(m.shape[1] if hasattr(m, "shape") else 1 for m in parts)  # cols before this part
            if value_idx_in_part is not None:
                self.continuous_value_indices.append(abs_offset + value_idx_in_part)

            parts.append(part)

        # Concatenate all transformed parts
        transformed = pd.concat(parts, axis=1)

        multi_groups: list[list[int]] = []
        single_list: list[int] = []

        import re  # used below to detect GMM component suffixes like "_c3"

        col_offset = 0
        for part in parts:
            m = part.shape[1] if hasattr(part, "shape") else 1
            if m > 1:
                part_cols = list(part.columns)
                component_indices = []
                non_component_indices = []

                for i, col_name in enumerate(part_cols):
                    col_idx = col_offset + i
                    # Check if this is a single-column type (z-score or missingness)
                    is_single_col = isinstance(col_name, str) and any(
                        suffix in col_name for suffix in ["_value", "_normalized", "_normalised", "_missing"]
                    )
                    if is_single_col:
                        # Z-scores and missingness go to single_column_indices
                        non_component_indices.append(col_idx)
                    elif isinstance(col_name, str) and re.search(r"_c\d+$", col_name):
                        # GMM component columns (e.g., x7_c1, x7_c10) go to multi_column_indices
                        component_indices.append(col_idx)
                    elif isinstance(col_name, str):
                        # OHE categorical columns (e.g., x1_0.0, x3_A) go to multi_column_indices
                        component_indices.append(col_idx)
                    else:
                        # Fallback for non-string column names
                        non_component_indices.append(col_idx)

                if component_indices:
                    multi_groups.append(component_indices)
                single_list.extend(non_component_indices)
            else:
                single_list.append(col_offset)
            col_offset += m

        # Store on self for the model
        self.multi_column_indices = multi_groups
        self.single_column_indices = single_list
        self.output_columns = list(transformed.columns)
        self.ncols = transformed.shape[1]
        self.continuous_value_indices = list(getattr(self, "continuous_value_indices", []))

        # Make sure downstream code (like VAE.generate) has the correct column names
        self.columns = list(transformed.columns)

        return transformed

    def apply(self) -> pd.DataFrame:
        """
        Applies the various steps of the MetaTransformer to the raw dataset provided at instantiation.

        Returns:
            The transformed dataset.
        """
        self.drop_columns()
        self.typed_dataset = self.apply_dtypes(self._raw_dataset)
        self.post_missingness_strategy_dataset = self.apply_missingness_strategy()
        self.constrained_dataset = self.apply_constraints()
        self.transformed_dataset = self.transform()
        return self.transformed_dataset

    def inverse_apply(self, dataset: pd.DataFrame) -> pd.DataFrame:
        """
        Reverses the transformation applied by the MetaTransformer.

        Args:
            dataset: The transformed dataset.

        Returns:
            The original dataset.
        """

        # binarize generated missingness indicators: >0.5 -> 1, else 0
        for col in list(dataset.columns):
            if col.endswith("_missing"):
                v = pd.to_numeric(dataset[col], errors="coerce").fillna(0.0).to_numpy()
                dataset[col] = (v > 0.5).astype(int)

        for column_metadata in self._metadata:
            dataset = column_metadata.transformer.revert(dataset)

        # Add Gaussian smoothing to continuous variables to blur GMM component peaks
        try:
            smoothing_std = 0.03  # 3% of column std as smoothing noise
            continuous_cols = []
            for col_meta in self._metadata:
                # Skip datetime columns, categorical columns, and missingness indicators
                if (
                    hasattr(col_meta, "transformer")
                    and col_meta.name in dataset.columns
                    and not col_meta.name.endswith("_missing")
                    and dataset[col_meta.name].dtype in ["float64", "float32", "int64", "int32"]
                ):
                    continuous_cols.append(col_meta.name)

            if continuous_cols:
                for col in continuous_cols:
                    col_std = dataset[col].std()
                    if col_std > 0:
                        noise = np.random.normal(0, smoothing_std * col_std, size=len(dataset))
                        dataset[col] = dataset[col] + noise
        except Exception as e:
            if DEBUG_VERBOSE:
                tqdm.write(f"[inverse_apply] WARNING: Gaussian smoothing failed: {e}")

        # Enforce constraints on decoded data if available
        try:
            dataset = self.repair_constraints(dataset, mode="resample")
        except Exception as e:
            if DEBUG_VERBOSE:
                tqdm.write(f"[inverse_apply] ERROR calling repair_constraints: {type(e).__name__}: {e}")
                import traceback

                tqdm.write(traceback.format_exc())

        out = self.apply_dtypes(dataset)
        return out

    def get_typed_dataset(self) -> pd.DataFrame:
        if not hasattr(self, "typed_dataset"):
            raise ValueError(
                "The typed dataset has not yet been created. Call `mt.apply()` (or `mt.apply_dtypes()`) first."
            )
        return self.typed_dataset

    def get_prepared_dataset(self) -> pd.DataFrame:
        if not hasattr(self, "post_missingness_strategy_dataset"):
            raise ValueError(
                "The prepared dataset has not yet been created. Call `mt.apply()` (or `mt.apply_missingness_strategy()`) first."
            )
        return self.post_missingness_strategy_dataset

    def get_transformed_dataset(self) -> pd.DataFrame:
        if not hasattr(self, "transformed_dataset"):
            raise ValueError(
                "The prepared dataset has not yet been created. Call `mt.apply()` (or `mt.transform()`) first."
            )
        return self.transformed_dataset

    def get_multi_and_single_column_indices(self) -> tuple[list[int], list[int]]:
        """
        Returns the indices of the columns that were transformed into one or multiple column(s).

        Returns:
            A tuple containing the multi column index groups and the single column indices.
        """
        if not hasattr(self, "multi_column_indices") or not hasattr(self, "single_column_indices"):
            raise ValueError(
                "The single and multi column indices have not yet been created. Call `mt.apply()` (or `mt.transform()`) first."
            )
        return self.multi_column_indices, self.single_column_indices

    def get_sdv_metadata(self) -> dict[str, dict[str, Any]]:
        """
        Calls the `MetaData` method to reformat its contents into the correct format for use with SDMetrics.

        Returns:
            The metadata in the correct format for SDMetrics.
        """
        return self._metadata.get_sdv_metadata()

    def save_metadata(self, path: pathlib.Path, collapse_yaml: bool = False) -> None:
        return self._metadata.save(path, collapse_yaml)

    def save_constraint_graphs(self, path: pathlib.Path) -> None:
        return self._metadata.constraints._output_graphs_html(path)

apply()

Applies the various steps of the MetaTransformer to the DataFrame passed at instantiation.

Returns:

Type Description
DataFrame

The transformed dataset.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def apply(self) -> pd.DataFrame:
    """
    Applies the various steps of the MetaTransformer to the DataFrame passed at instantiation.

    Returns:
        The transformed dataset.
    """
    self.drop_columns()
    self.typed_dataset = self.apply_dtypes(self._raw_dataset)
    self.post_missingness_strategy_dataset = self.apply_missingness_strategy()
    self.constrained_dataset = self.apply_constraints()
    self.transformed_dataset = self.transform()
    return self.transformed_dataset

apply_dtypes(data)

Applies dtypes from the metadata to the passed data.

Returns:

Type Description
DataFrame

The dataset with the dtypes applied.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def apply_dtypes(self, data: pd.DataFrame) -> pd.DataFrame:
    """
    Applies dtypes from the metadata to `data`.

    Returns:
        The dataset with the dtypes applied.
    """
    working_data = data.copy()
    for column_metadata in self._metadata:
        working_data[column_metadata.name] = self._apply_dtype(working_data[column_metadata.name], column_metadata)
    return working_data
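The per-column loop above reduces to a plain pandas cast. A minimal sketch, assuming hypothetical `(name, dtype)` pairs in place of the real `MetaData` entries:

```python
import pandas as pd

# Hypothetical (name, dtype) pairs standing in for the MetaData entries
metadata = [("age", "int64"), ("score", "float64")]

df = pd.DataFrame({"age": ["34", "58"], "score": ["1.5", "2.0"]})
typed = df.copy()
for name, dtype in metadata:
    # cast each column to the dtype recorded in the metadata
    typed[name] = typed[name].astype(dtype)
```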

apply_missingness_strategy()

Resolves missingness in the dataset via the MetaTransformer's global missingness strategy or column-wise missingness strategies. In the case of the AugmentMissingnessStrategy, the missingness is not resolved; instead, a new column or value is added for later transformation.

Returns:

Type Description
DataFrame

The dataset with the missingness strategies applied.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def apply_missingness_strategy(self) -> pd.DataFrame:
    """
    Resolves missingness in the dataset via the `MetaTransformer`'s global missingness strategy or
    column-wise missingness strategies. In the case of the `AugmentMissingnessStrategy`, the missingness
    is not resolved; instead, a new column or value is added for later transformation.

    Returns:
        The dataset with the missingness strategies applied.
    """
    working_data = self.typed_dataset.copy()
    for column_metadata in self._metadata:
        if not column_metadata.missingness_strategy:
            column_metadata.missingness_strategy = (
                self._missingness_strategy(self._impute_value)
                if hasattr(self, "_impute_value")
                else self._missingness_strategy()
            )
        if not working_data[column_metadata.name].isnull().any():
            continue
        working_data = column_metadata.missingness_strategy.remove(working_data, column_metadata)
    return working_data
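The augment strategy described above can be sketched in isolation. This is a simplified stand-in (the indicator naming follows the `"<column>_missing"` convention used elsewhere on this page; the helper name is hypothetical):

```python
import pandas as pd

def augment_missingness(df: pd.DataFrame, column: str, fill_value=0.0) -> pd.DataFrame:
    # Record missingness in a companion indicator column, then fill the
    # original column so downstream transformers never see NaNs.
    out = df.copy()
    out[f"{column}_missing"] = out[column].isnull().astype(int)
    out[column] = out[column].fillna(fill_value)
    return out

df = pd.DataFrame({"age": [34.0, None, 58.0]})
augmented = augment_missingness(df, "age")
```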

drop_columns()

Drops columns from the dataset that are not in the MetaData.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def drop_columns(self) -> None:
    """
    Drops columns from the dataset that are not in the `MetaData`.
    """
    self._raw_dataset = self._raw_dataset[self._metadata.columns]

from_dict(dataset, metadata, **kwargs) classmethod

Instantiates a MetaTransformer from a metadata dictionary.

Parameters:

Name Type Description Default
dataset DataFrame

The raw input DataFrame.

required
metadata dict

A dictionary of raw metadata.

required

Returns:

Type Description
Self

A MetaTransformer object.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
@classmethod
def from_dict(cls, dataset: pd.DataFrame, metadata: dict, **kwargs) -> Self:
    """
    Instantiates a MetaTransformer from a metadata dictionary.

    Args:
        dataset: The raw input DataFrame.
        metadata: A dictionary of raw metadata.

    Returns:
        A MetaTransformer object.
    """
    return cls(dataset, MetaData(dataset, metadata), **kwargs)

from_path(dataset, metadata_path, **kwargs) classmethod

Instantiates a MetaTransformer from a metadata file via a provided path.

Parameters:

Name Type Description Default
dataset DataFrame

The raw input DataFrame.

required
metadata_path str

The path to the metadata file.

required

Returns:

Type Description
Self

A MetaTransformer object.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
@classmethod
def from_path(cls, dataset: pd.DataFrame, metadata_path: str, **kwargs) -> Self:
    """
    Instantiates a MetaTransformer from a metadata file via a provided path.

    Args:
        dataset: The raw input DataFrame.
        metadata_path: The path to the metadata file.

    Returns:
        A MetaTransformer object.
    """
    return cls(dataset, MetaData.from_path(dataset, metadata_path), **kwargs)

get_multi_and_single_column_indices()

Returns the indices of the columns that were transformed into one or multiple column(s).

Returns:

Type Description
tuple[list[int], list[int]]

A tuple containing the multi column index groups and the single column indices.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def get_multi_and_single_column_indices(self) -> tuple[list[int], list[int]]:
    """
    Returns the indices of the columns that were transformed into one or multiple column(s).

    Returns:
        A tuple containing the multi column index groups and the single column indices.
    """
    if not hasattr(self, "multi_column_indices") or not hasattr(self, "single_column_indices"):
        raise ValueError(
            "The single and multi column indices have not yet been created. Call `mt.apply()` (or `mt.transform()`) first."
        )
    return self.multi_column_indices, self.single_column_indices

get_sdv_metadata()

Calls the MetaData method to reformat its contents into the correct format for use with SDMetrics.

Returns:

Type Description
dict[str, dict[str, Any]]

The metadata in the correct format for SDMetrics.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def get_sdv_metadata(self) -> dict[str, dict[str, Any]]:
    """
    Calls the `MetaData` method to reformat its contents into the correct format for use with SDMetrics.

    Returns:
        The metadata in the correct format for SDMetrics.
    """
    return self._metadata.get_sdv_metadata()

inverse_apply(dataset)

Reverses the transformation applied by the MetaTransformer.

Parameters:

Name Type Description Default
dataset DataFrame

The transformed dataset.

required

Returns:

Type Description
DataFrame

The original dataset.

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def inverse_apply(self, dataset: pd.DataFrame) -> pd.DataFrame:
    """
    Reverses the transformation applied by the MetaTransformer.

    Args:
        dataset: The transformed dataset.

    Returns:
        The original dataset.
    """
    import numpy as np

    # binarize generated missingness indicators: >0.5 -> 1, else 0
    for col in list(dataset.columns):
        if col.endswith("_missing"):
            v = pd.to_numeric(dataset[col], errors="coerce").fillna(0.0).to_numpy()
            dataset[col] = (v > 0.5).astype(int)

    for column_metadata in self._metadata:
        dataset = column_metadata.transformer.revert(dataset)

    # Add Gaussian smoothing to continuous variables to blur GMM component peaks
    try:
        smoothing_std = 0.03  # 3% of column std as smoothing noise
        continuous_cols = []
        for col_meta in self._metadata:
            # Skip datetime columns, categorical columns, and missingness indicators
            if (
                hasattr(col_meta, "transformer")
                and col_meta.name in dataset.columns
                and not col_meta.name.endswith("_missing")
                and dataset[col_meta.name].dtype in ["float64", "float32", "int64", "int32"]
            ):
                continuous_cols.append(col_meta.name)

        if continuous_cols:
            for col in continuous_cols:
                col_std = dataset[col].std()
                if col_std > 0:
                    noise = np.random.normal(0, smoothing_std * col_std, size=len(dataset))
                    dataset[col] = dataset[col] + noise
    except Exception as e:
        if DEBUG_VERBOSE:
            tqdm.write(f"[inverse_apply] WARNING: Gaussian smoothing failed: {e}")

    # Enforce constraints on decoded data if available
    try:
        dataset = self.repair_constraints(dataset, mode="resample")
    except Exception as e:
        if DEBUG_VERBOSE:
            tqdm.write(f"[inverse_apply] ERROR calling repair_constraints: {type(e).__name__}: {e}")
            import traceback

            tqdm.write(traceback.format_exc())

    out = self.apply_dtypes(dataset)
    return out
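The Gaussian smoothing step in `inverse_apply` can be isolated as follows; `smooth_continuous` is a hypothetical helper mirroring the noise-injection logic (3% of each column's std by default):

```python
import numpy as np
import pandas as pd

def smooth_continuous(df, columns, smoothing_std=0.03, rng=None):
    # Add zero-mean Gaussian noise, scaled to a fraction of each column's
    # std, to blur the peaks left by discrete GMM components.
    if rng is None:
        rng = np.random.default_rng(0)
    out = df.copy()
    for col in columns:
        col_std = out[col].std()
        if col_std > 0:  # skip constant columns
            out[col] = out[col] + rng.normal(0, smoothing_std * col_std, size=len(out))
    return out

df = pd.DataFrame({"x": [1.0, 2.0, 3.0, 4.0]})
smoothed = smooth_continuous(df, ["x"])
```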

repair_constraints(df, *, mode='reflect', rng=None, n_retries=0)

Enforce constraints on a decoded DataFrame using self._metadata.constraints.minimal_constraints.

Supports:

- Numeric constant constraints (e.g., x > 0, x in (0, 100))
- Column reference constraints (e.g., x8 > x10)

Parameters:

Name Type Description Default
df DataFrame

DataFrame to repair

required
mode str

Repair strategy ('reflect', 'resample', 'clamp')

'reflect'
rng

Random number generator

None
n_retries int

Number of retries (unused for now)

0

Returns:

Type Description
DataFrame

DataFrame with constraints enforced

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def repair_constraints(
    self, df: pd.DataFrame, *, mode: str = "reflect", rng=None, n_retries: int = 0
) -> pd.DataFrame:
    """
    Enforce constraints on a *decoded* DataFrame using
    self._metadata.constraints.minimal_constraints.

    Supports:
    - Numeric constant constraints (e.g., x > 0, x in (0, 100))
    - Column reference constraints (e.g., x8 > x10)

    Args:
        df: DataFrame to repair
        mode: Repair strategy ('reflect', 'resample', 'clamp')
        rng: Random number generator
        n_retries: Number of retries (unused for now)

    Returns:
        DataFrame with constraints enforced
    """
    import numpy as np
    import pandas as pd

    if rng is None:
        rng = np.random.default_rng()

    constraints = getattr(getattr(self, "_metadata", None), "constraints", None)
    if constraints is None:
        return df
    constraints_iterable = getattr(constraints, "minimal_constraints", None)
    if constraints_iterable is None:
        return df

    repaired = df.copy()
    violations_fixed = 0

    # Import ConstraintGraph so we can detect its nested ComboConstraint type
    from nhssynth.modules.dataloader.constraints import ConstraintGraph

    # Convert to list to avoid consuming iterator
    constraints_list = list(constraints_iterable)

    # Iterate through minimal constraints and repair violations
    for constraint in constraints_list:
        # Skip combo constraints for now (they require special handling)
        if isinstance(constraint, ConstraintGraph.ComboConstraint):
            continue

        base = constraint.base
        operator = constraint.operator
        reference = constraint.reference
        is_column = constraint.reference_is_column

        # Skip if base column doesn't exist in dataframe
        if base not in repaired.columns:
            continue

        # Get base values
        base_vals = pd.to_numeric(repaired[base], errors="coerce")

        # Handle column reference constraints (e.g., x8 > x10)
        if is_column:
            if reference not in repaired.columns:
                continue
            ref_vals = pd.to_numeric(repaired[reference], errors="coerce")
        else:
            # Handle numeric constant constraints (e.g., x > 0)
            try:
                ref_vals = float(reference)
            except (TypeError, ValueError):
                # Could be datetime, skip for now
                continue

        # Identify violations
        if operator == ">":
            violations = base_vals <= ref_vals
        elif operator == ">=":
            violations = base_vals < ref_vals
        elif operator == "<":
            violations = base_vals >= ref_vals
        elif operator == "<=":
            violations = base_vals > ref_vals
        else:
            continue  # Unknown operator

        # Only repair finite violations (skip NaN)
        violations = violations & np.isfinite(base_vals)

        n_viol = violations.sum()
        if n_viol == 0:
            continue

        violations_fixed += n_viol

        # Repair strategy
        if operator in [">", ">="]:
            # Base must be greater than reference
            if is_column:
                # Add small margin above reference (5% of reference value or 0.1, whichever is larger)
                margin = np.maximum(0.1, 0.05 * np.abs(ref_vals[violations]))
                if operator == ">":
                    repaired.loc[violations, base] = ref_vals[violations] + margin
                else:  # >=
                    repaired.loc[violations, base] = ref_vals[violations]
            else:
                # Numeric constant
                margin = max(0.1, 0.05 * abs(ref_vals))
                if operator == ">":
                    repaired.loc[violations, base] = ref_vals + margin
                else:  # >=
                    repaired.loc[violations, base] = ref_vals

        elif operator in ["<", "<="]:
            # Base must be less than reference
            if is_column:
                margin = np.maximum(0.1, 0.05 * np.abs(ref_vals[violations]))
                if operator == "<":
                    repaired.loc[violations, base] = ref_vals[violations] - margin
                else:  # <=
                    repaired.loc[violations, base] = ref_vals[violations]
            else:
                margin = max(0.1, 0.05 * abs(ref_vals))
                if operator == "<":
                    repaired.loc[violations, base] = ref_vals - margin
                else:  # <=
                    repaired.loc[violations, base] = ref_vals

    if DEBUG_VERBOSE and violations_fixed > 0:
        tqdm.write(f"[repair_constraints] Fixed {violations_fixed} constraint violations")

    return repaired
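The constant-reference branch of the repair strategy can be demonstrated on a toy frame. `repair_greater_than` is a hypothetical reduction of the logic above to a single `x > reference` constraint:

```python
import numpy as np
import pandas as pd

def repair_greater_than(df, base, reference):
    # Rows violating base > reference are moved just above the bound with a
    # 5% margin (floored at 0.1), mirroring repair_constraints above.
    out = df.copy()
    vals = pd.to_numeric(out[base], errors="coerce")
    violations = (vals <= reference) & np.isfinite(vals)
    margin = max(0.1, 0.05 * abs(reference))
    out.loc[violations, base] = reference + margin
    return out

df = pd.DataFrame({"x": [-1.0, 0.0, 2.5]})
repaired = repair_greater_than(df, "x", 0.0)
```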

transform()

Apply each column transformer to its raw Series, then concatenate results. Ensures each transformer receives a single Series (not the whole/mutated DataFrame), which avoids KeyErrors on datetime columns (e.g. 'dob').

Source code in src/nhssynth/modules/dataloader/metatransformer.py
def transform(self) -> pd.DataFrame:
    """
    Apply each column transformer to its *raw* Series, then concatenate results.
    Ensures each transformer receives a single Series (not the whole/mutated DataFrame),
    which avoids KeyErrors on datetime columns (e.g. 'dob').
    """

    # Prefer the dataset that already has missingness flags computed,
    # but still contains the original raw columns.
    if hasattr(self, "post_missingness_strategy_dataset") and self.post_missingness_strategy_dataset is not None:
        source_df = self.post_missingness_strategy_dataset
    elif hasattr(self, "typed_dataset") and self.typed_dataset is not None:
        source_df = self.typed_dataset
    else:
        # Fallback: raw dataset as last resort
        source_df = self._raw_dataset

    parts = []

    # Helper: get a column if it exists, else None
    def _maybe_col(df: pd.DataFrame, name: str):
        return df[name] if (df is not None and name in df.columns) else None

    for col_meta in self._metadata:
        # Work out the original column name this transformer handles
        col = (
            getattr(col_meta, "name", None)
            or getattr(col_meta, "column", None)
            or getattr(col_meta, "feature", None)
        )
        if col is None:
            raise ValueError(f"Metadata entry missing column name: {col_meta}")

        # Always hand the transformer a *Series* from the original (pre-transform) frame
        if col not in source_df.columns:
            raise KeyError(
                f"[MetaTransformer.transform] Expected raw column '{col}' in source_df; "
                f"available={list(source_df.columns)[:15]}..."
            )
        series = source_df[col]

        # Optional per-row flags
        # Missingness: prefer exact "{col}_missing"; if not present, try to find any "<col>_missing*"
        miss = _maybe_col(source_df, f"{col}_missing")
        if miss is None:
            # Fall back to a prefix match for variant indicator names
            candidates = [c for c in source_df.columns if c.startswith(f"{col}_missing")]
            miss = source_df[candidates[0]] if candidates else None

        # Constraint adherence: use flags from constrained_dataset if present; otherwise default to ones
        if hasattr(self, "constrained_dataset") and self.constrained_dataset is not None:
            adh = _maybe_col(self.constrained_dataset, f"{col}_adherence")
        else:
            adh = None

        if adh is None:
            # Default to all-ones (i.e., include all rows during transform)
            adh = pd.Series(1, index=series.index, name=f"{col}_adherence", dtype=int)

        # Apply the per-column transformer
        part = self._call_transformer_apply(
            col_meta.transformer,
            series=series,
            constraint_adherence=adh,
            missingness_column=miss,
        )

        # Normalise to DataFrame
        if isinstance(part, pd.Series):
            part = part.to_frame()

        value_idx_in_part = None
        for cand in (f"{col}_value", f"{col}_normalized", f"{col}_normalised"):
            if cand in part.columns:
                value_idx_in_part = part.columns.get_loc(cand)
                break

        # record absolute index for this value column
        if not hasattr(self, "continuous_value_indices"):
            self.continuous_value_indices = []

        abs_offset = sum(m.shape[1] if hasattr(m, "shape") else 1 for m in parts)  # cols before this part
        if value_idx_in_part is not None:
            self.continuous_value_indices.append(abs_offset + value_idx_in_part)

        parts.append(part)

    # Concatenate all transformed parts
    transformed = pd.concat(parts, axis=1)

    multi_groups: list[list[int]] = []
    single_list: list[int] = []

    col_offset = 0
    for part in parts:
        m = part.shape[1] if hasattr(part, "shape") else 1
        if m > 1:
            part_cols = list(part.columns)
            component_indices = []
            non_component_indices = []

            import re

            for i, col_name in enumerate(part_cols):
                col_idx = col_offset + i
                # Check if this is a single-column type (z-score or missingness)
                is_single_col = isinstance(col_name, str) and any(
                    suffix in col_name for suffix in ["_value", "_normalized", "_normalised", "_missing"]
                )
                if is_single_col:
                    # Z-scores and missingness go to single_column_indices
                    non_component_indices.append(col_idx)
                elif isinstance(col_name, str) and re.search(r"_c\d+$", col_name):
                    # GMM component columns (e.g., x7_c1, x7_c10) go to multi_column_indices
                    component_indices.append(col_idx)
                elif isinstance(col_name, str):
                    # OHE categorical columns (e.g., x1_0.0, x3_A) go to multi_column_indices
                    component_indices.append(col_idx)
                else:
                    # Fallback for non-string column names
                    non_component_indices.append(col_idx)

            if component_indices:
                multi_groups.append(component_indices)
            single_list.extend(non_component_indices)
        else:
            single_list.append(col_offset)
        col_offset += m

    # Store on self for the model
    self.multi_column_indices = multi_groups
    self.single_column_indices = single_list
    self.output_columns = list(transformed.columns)
    self.ncols = transformed.shape[1]
    self.continuous_value_indices = list(getattr(self, "continuous_value_indices", []))

    # Make sure downstream code (like VAE.generate) has the correct column names
    self.columns = list(transformed.columns)

    return transformed
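The index bookkeeping at the end of `transform` can be illustrated with two hypothetical parts: a z-score pair (value plus missingness flag) and a one-hot group. The column names here are invented for illustration:

```python
import pandas as pd

parts = [
    pd.DataFrame({"x_value": [0.1], "x_missing": [0]}),        # z-score + flag
    pd.DataFrame({"cat_A": [1], "cat_B": [0], "cat_C": [0]}),  # one-hot group
]

multi_groups, single_list, offset = [], [], 0
for part in parts:
    cols = list(part.columns)
    if len(cols) > 1:
        group = []
        for i, name in enumerate(cols):
            if any(s in name for s in ("_value", "_missing")):
                single_list.append(offset + i)  # standalone columns
            else:
                group.append(offset + i)        # grouped components (OHE / GMM)
        if group:
            multi_groups.append(group)
    else:
        single_list.append(offset)
    offset += len(cols)
```

Each one-hot group ends up as one list in `multi_groups`, while value and missingness columns land in `single_list`.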