Skip to content

datetime

DatetimeTransformer

Bases: TransformerWrapper

A transformer that converts datetime features to numeric features before applying an underlying (wrapped) transformer. The datetime features are converted to fractional days since the epoch (derived from their nanosecond timestamps), and missing values are assigned to 0.0 under the AugmentMissingnessStrategy.

Parameters:

Name Type Description Default
transformer ColumnTransformer

The ColumnTransformer to wrap.

required

After applying the transformer, the following attributes will be populated:

Attributes:

Name Type Description
original_column_name

The name of the original column.

Source code in src/nhssynth/modules/dataloader/transformers/datetime.py
class DatetimeTransformer(TransformerWrapper):
    """
    A transformer that converts datetime features to numeric features before applying an underlying (wrapped) transformer.
    The datetime features are converted to fractional *days* since the epoch; missing/unparseable values become NaN
    before being handed to the wrapped transformer.

    Args:
        transformer: The [`ColumnTransformer`][nhssynth.modules.dataloader.transformers.base.ColumnTransformer] to wrap.

    After applying the transformer, the following attributes will be populated:

    Attributes:
        original_column_name: The name of the original column.
    """

    def __init__(self, transformer: ColumnTransformer) -> None:
        super().__init__(transformer)

    def apply(
        self,
        data: pd.Series,
        constraint_adherence: Optional[pd.Series],
        missingness_column: Optional[pd.Series] = None,
    ) -> pd.DataFrame:
        """
        Encode datetimes as numeric *days* since the epoch and delegate to the wrapped transformer.

        Besides encoding, this caches on `self` the state that `revert` reads back: the 1st/99th
        percentile window of the observed values (`_days_min`/`_days_max`, `_ns_min`/`_ns_max`),
        a sample of at most 5000 in-window values (`_days_pool`/`_ns_pool`), whether repair is
        enabled (`_days_clamp_enabled`/`_ns_clamp_enabled`) and the unit scale (`_unit_scale`).

        Args:
            data: The datetime-like column; anything `pd.to_datetime` cannot parse becomes NaN.
            constraint_adherence: Passed through unchanged to the wrapped transformer's `apply`.
            missingness_column: Optional indicator; when given, only rows where it equals 0 are used
                to compute the window and pool. Also passed through to the wrapped transformer.

        Returns:
            Whatever the wrapped transformer's `apply` returns for the days-encoded series.
        """
        import pandas as pd
        from tqdm import tqdm

        # Keep original name for revert
        self.original_column_name = data.name

        # Coerce to datetimes; unparseable entries become NaT.
        dt = pd.to_datetime(data, errors="coerce")

        # Convert to DAYS (keeps sigma in a reasonable range for the mixture).
        # Dividing a timedelta by one day is independent of the datetime resolution; casting the
        # datetimes to int64 would silently assume nanoseconds, which pandas >= 2 no longer
        # guarantees (s/ms/us resolutions exist). NaT propagates to NaN automatically.
        NS_PER_DAY = float(24 * 60 * 60 * 1e9)
        self._unit_scale = NS_PER_DAY  # used later in revert
        is_tz_aware = getattr(dt.dt, "tz", None) is not None
        epoch = pd.Timestamp(0, tz="UTC") if is_tz_aware else pd.Timestamp(0)
        days = ((dt - epoch) / pd.Timedelta(days=1)).astype("float64")

        # ---- Cache train-window bounds & a small pool (in *days*) ----
        # Work only on non-missing rows for bounds
        days_train = days[missingness_column == 0] if missingness_column is not None else days
        vals = pd.to_numeric(days_train, errors="coerce").to_numpy(dtype="float64", copy=False)
        vals = vals[np.isfinite(vals)]

        if vals.size >= 10:
            p1 = float(np.nanpercentile(vals, 1))
            p99 = float(np.nanpercentile(vals, 99))
            window = p99 - p1
            has_window = window > 0

            # Cache bounds in *days* AND *ns* (revert reads the ns versions)
            self._days_min = p1 if has_window else None
            self._days_max = p99 if has_window else None
            self._ns_min = p1 * NS_PER_DAY if has_window else None
            self._ns_max = p99 * NS_PER_DAY if has_window else None

            # Small reservoir of in-window values for repairing out-of-window samples.
            # The fixed seed keeps the cached pool reproducible between runs.
            pool = vals[(vals >= p1) & (vals <= p99)]
            if pool.size > 5000:
                rng = np.random.default_rng(0)
                pool = rng.choice(pool, size=5000, replace=False)
            self._days_pool = pool.astype("float64", copy=False)
            self._ns_pool = (pool * NS_PER_DAY).astype("float64", copy=False)

            # Enable repair only for sensible windows (>= 30 days)
            clamp_enabled = bool(window >= 30.0)
            self._days_clamp_enabled = clamp_enabled
            self._ns_clamp_enabled = clamp_enabled

            # Debug
            p1_ts = pd.to_datetime(int(round(p1 * NS_PER_DAY)), unit="ns", errors="coerce")
            p99_ts = pd.to_datetime(int(round(p99 * NS_PER_DAY)), unit="ns", errors="coerce")
            tqdm.write(
                f"[datetime.apply] {self.original_column_name}: p1={p1_ts}  "
                f"p99={p99_ts}  Δ≈{window:.1f} days  "
                f"clamp={'ON' if self._days_clamp_enabled else 'OFF'}  pool={self._days_pool.size}"
            )
        else:
            # Fewer than 10 usable values: no window, no pool, repair disabled.
            self._days_min = self._days_max = None
            self._ns_min = self._ns_max = None
            self._days_pool = np.array([], dtype="float64")
            self._ns_pool = np.array([], dtype="float64")
            self._days_clamp_enabled = False
            self._ns_clamp_enabled = False
            tqdm.write(f"[datetime.apply] {self.original_column_name}: insufficient data; clamp OFF, empty pool")

        # Hand off DAYS series to the wrapped transformer
        days_series = pd.Series(days, index=data.index, name=self.original_column_name)
        return super().apply(
            data=days_series,
            constraint_adherence=constraint_adherence,
            missingness_column=missingness_column,
        )

    def revert(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """
        Decode the wrapped transformer's output (numeric *days*) back to datetimes.

        The days are scaled to nanoseconds. If `apply` enabled repair, values outside the cached
        window are replaced by draws from the cached pool (or uniform draws from the window when
        the pool is empty) and every finite value receives a small random jitter; all of this
        happens in nanoseconds. Note that the jitter can move a value slightly outside the window,
        and that it is drawn from an unseeded generator, so repeated calls differ.
        Rows flagged in `<base>_missing`, when that column is present, are set to NaT.

        Args:
            data: The frame to decode; passed to the wrapped transformer's `revert` first.
            **kwargs: Forwarded unchanged to the wrapped transformer's `revert`.

        Returns:
            The reverted frame with the original column replaced by datetimes. It is returned
            untouched when the original column name is unknown or absent from the reverted frame.
        """
        import pandas as pd
        from tqdm import tqdm

        # 1) Let the wrapped transformer put the numeric DAYS back
        reverted = super().revert(data, **kwargs)

        base = getattr(self, "original_column_name", None) or getattr(self, "name", None)
        if not base or base not in reverted.columns:
            return reverted

        # 2) DAYS as float, then scaled to nanoseconds (own writable array, safe to mutate)
        days = pd.to_numeric(reverted[base], errors="coerce").astype("float64")
        NS_PER_DAY = float(getattr(self, "_unit_scale", 24 * 60 * 60 * 1e9))
        ns_vals = (days * NS_PER_DAY).to_numpy(dtype="float64", copy=True)

        # 3) Optional pool-based repair + light jitter, in *ns*
        ns_min = getattr(self, "_ns_min", None)
        ns_max = getattr(self, "_ns_max", None)
        clamp_ok = bool(getattr(self, "_ns_clamp_enabled", False)) and (
            ns_min is not None and ns_max is not None and ns_max > ns_min
        )

        if clamp_ok:
            rng = np.random.default_rng()

            # out-of-bounds mask (NaN/inf are left alone and end up as NaT)
            bad = np.isfinite(ns_vals) & ((ns_vals < ns_min) | (ns_vals > ns_max))
            n_bad = int(bad.sum())
            tqdm.write(
                f"[datetime.revert] window={pd.to_datetime(ns_min)}..{pd.to_datetime(ns_max)} "
                f"oob={n_bad}/{len(ns_vals)}"
            )

            if n_bad:
                pool = getattr(self, "_ns_pool", None)
                if isinstance(pool, np.ndarray) and pool.size:
                    choice = rng.choice(pool, size=n_bad, replace=True).astype("float64")
                else:
                    choice = rng.uniform(ns_min, ns_max, size=n_bad)
                # small jitter so resampled values don't quantize to identical ticks
                J = int(max(1e8, 0.0005 * (ns_max - ns_min)))  # >= 0.1s or 0.05% of the window
                choice = choice + rng.integers(-J, J + 1, size=choice.shape)
                ns_vals[bad] = choice

            # light jitter for the originally in-bounds values too, to avoid edge pile-ups
            good = np.isfinite(ns_vals) & ~bad
            if good.any():
                J_small = int(max(1e7, 0.0002 * (ns_max - ns_min)))  # >= 0.01s or 0.02% of the window
                ns_vals[good] = ns_vals[good] + rng.integers(-J_small, J_small + 1, size=int(good.sum()))

        # 4) Prevent int64 overflow before rounding.
        # float(int64 max) rounds UP to 2**63, which itself overflows int64, so clip to the
        # largest float64 strictly below 2**63 (and its negative, which also avoids the
        # int64 minimum that pandas reserves as the NaT sentinel).
        finite = np.isfinite(ns_vals)
        f64_hi = float(np.nextafter(np.float64(2.0**63), np.float64(0.0)))
        ns_clipped = np.clip(ns_vals, -f64_hi, f64_hi)
        ns_round = np.where(finite, np.rint(ns_clipped), np.nan)

        # Nullable Int64 via pd.array (NaN -> <NA>)
        ns_int = pd.array(ns_round, dtype="Int64")

        # 5) To datetime64[ns]
        dt = pd.to_datetime(ns_int, unit="ns", errors="coerce")
        dt = pd.Series(dt, index=reverted.index, name=base)

        # 6) Respect missingness flag if present
        miss_col = f"{base}_missing"
        if miss_col in reverted.columns:
            m = pd.to_numeric(reverted[miss_col], errors="coerce").fillna(0).astype(bool).to_numpy()
            if m.any():
                dt.loc[m] = pd.NaT

        # 7) Write back
        reverted[base] = dt
        return reverted

apply(data, constraint_adherence, missingness_column=None)

Encode datetimes by converting to numeric days since epoch, then delegate to the continuous mixture transformer. Also caches train-window bounds and a small reservoir for later repair/jitter.

Source code in src/nhssynth/modules/dataloader/transformers/datetime.py
def apply(
    self,
    data: pd.Series,
    constraint_adherence: Optional[pd.Series],
    missingness_column: Optional[pd.Series] = None,
) -> pd.DataFrame:
    """
    Encode datetimes as numeric *days* since the epoch and delegate to the wrapped transformer.

    Besides encoding, this caches on `self` the state that `revert` reads back: the 1st/99th
    percentile window of the observed values (`_days_min`/`_days_max`, `_ns_min`/`_ns_max`),
    a sample of at most 5000 in-window values (`_days_pool`/`_ns_pool`), whether repair is
    enabled (`_days_clamp_enabled`/`_ns_clamp_enabled`) and the unit scale (`_unit_scale`).

    Args:
        data: The datetime-like column; anything `pd.to_datetime` cannot parse becomes NaN.
        constraint_adherence: Passed through unchanged to the wrapped transformer's `apply`.
        missingness_column: Optional indicator; when given, only rows where it equals 0 are used
            to compute the window and pool. Also passed through to the wrapped transformer.

    Returns:
        Whatever the wrapped transformer's `apply` returns for the days-encoded series.
    """
    import pandas as pd
    from tqdm import tqdm

    # Keep original name for revert
    self.original_column_name = data.name

    # Coerce to datetimes; unparseable entries become NaT.
    dt = pd.to_datetime(data, errors="coerce")

    # Convert to DAYS (keeps sigma in a reasonable range for the mixture).
    # Dividing a timedelta by one day is independent of the datetime resolution; casting the
    # datetimes to int64 would silently assume nanoseconds, which pandas >= 2 no longer
    # guarantees (s/ms/us resolutions exist). NaT propagates to NaN automatically.
    NS_PER_DAY = float(24 * 60 * 60 * 1e9)
    self._unit_scale = NS_PER_DAY  # used later in revert
    is_tz_aware = getattr(dt.dt, "tz", None) is not None
    epoch = pd.Timestamp(0, tz="UTC") if is_tz_aware else pd.Timestamp(0)
    days = ((dt - epoch) / pd.Timedelta(days=1)).astype("float64")

    # ---- Cache train-window bounds & a small pool (in *days*) ----
    # Work only on non-missing rows for bounds
    days_train = days[missingness_column == 0] if missingness_column is not None else days
    vals = pd.to_numeric(days_train, errors="coerce").to_numpy(dtype="float64", copy=False)
    vals = vals[np.isfinite(vals)]

    if vals.size >= 10:
        p1 = float(np.nanpercentile(vals, 1))
        p99 = float(np.nanpercentile(vals, 99))
        window = p99 - p1
        has_window = window > 0

        # Cache bounds in *days* AND *ns* (revert reads the ns versions)
        self._days_min = p1 if has_window else None
        self._days_max = p99 if has_window else None
        self._ns_min = p1 * NS_PER_DAY if has_window else None
        self._ns_max = p99 * NS_PER_DAY if has_window else None

        # Small reservoir of in-window values for repairing out-of-window samples.
        # The fixed seed keeps the cached pool reproducible between runs.
        pool = vals[(vals >= p1) & (vals <= p99)]
        if pool.size > 5000:
            rng = np.random.default_rng(0)
            pool = rng.choice(pool, size=5000, replace=False)
        self._days_pool = pool.astype("float64", copy=False)
        self._ns_pool = (pool * NS_PER_DAY).astype("float64", copy=False)

        # Enable repair only for sensible windows (>= 30 days)
        clamp_enabled = bool(window >= 30.0)
        self._days_clamp_enabled = clamp_enabled
        self._ns_clamp_enabled = clamp_enabled

        # Debug
        p1_ts = pd.to_datetime(int(round(p1 * NS_PER_DAY)), unit="ns", errors="coerce")
        p99_ts = pd.to_datetime(int(round(p99 * NS_PER_DAY)), unit="ns", errors="coerce")
        tqdm.write(
            f"[datetime.apply] {self.original_column_name}: p1={p1_ts}  "
            f"p99={p99_ts}  Δ≈{window:.1f} days  "
            f"clamp={'ON' if self._days_clamp_enabled else 'OFF'}  pool={self._days_pool.size}"
        )
    else:
        # Fewer than 10 usable values: no window, no pool, repair disabled.
        self._days_min = self._days_max = None
        self._ns_min = self._ns_max = None
        self._days_pool = np.array([], dtype="float64")
        self._ns_pool = np.array([], dtype="float64")
        self._days_clamp_enabled = False
        self._ns_clamp_enabled = False
        tqdm.write(f"[datetime.apply] {self.original_column_name}: insufficient data; clamp OFF, empty pool")

    # Hand off DAYS series to the wrapped transformer
    days_series = pd.Series(days, index=data.index, name=self.original_column_name)
    return super().apply(
        data=days_series,
        constraint_adherence=constraint_adherence,
        missingness_column=missingness_column,
    )

revert(data, **kwargs)

Decode from mixture space back to days, convert to ns, repair out-of-window values and apply jitter in ns (if enabled), and then convert to datetime64[ns]. Respects the `<base>_missing` column when present.

Source code in src/nhssynth/modules/dataloader/transformers/datetime.py
def revert(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:
    """
    Decode the wrapped transformer's output (numeric *days*) back to datetimes.

    The days are scaled to nanoseconds. If `apply` enabled repair, values outside the cached
    window are replaced by draws from the cached pool (or uniform draws from the window when
    the pool is empty) and every finite value receives a small random jitter; all of this
    happens in nanoseconds. Note that the jitter can move a value slightly outside the window,
    and that it is drawn from an unseeded generator, so repeated calls differ.
    Rows flagged in `<base>_missing`, when that column is present, are set to NaT.

    Args:
        data: The frame to decode; passed to the wrapped transformer's `revert` first.
        **kwargs: Forwarded unchanged to the wrapped transformer's `revert`.

    Returns:
        The reverted frame with the original column replaced by datetimes. It is returned
        untouched when the original column name is unknown or absent from the reverted frame.
    """
    import pandas as pd
    from tqdm import tqdm

    # 1) Let the wrapped transformer put the numeric DAYS back
    reverted = super().revert(data, **kwargs)

    base = getattr(self, "original_column_name", None) or getattr(self, "name", None)
    if not base or base not in reverted.columns:
        return reverted

    # 2) DAYS as float, then scaled to nanoseconds (own writable array, safe to mutate)
    days = pd.to_numeric(reverted[base], errors="coerce").astype("float64")
    NS_PER_DAY = float(getattr(self, "_unit_scale", 24 * 60 * 60 * 1e9))
    ns_vals = (days * NS_PER_DAY).to_numpy(dtype="float64", copy=True)

    # 3) Optional pool-based repair + light jitter, in *ns*
    ns_min = getattr(self, "_ns_min", None)
    ns_max = getattr(self, "_ns_max", None)
    clamp_ok = bool(getattr(self, "_ns_clamp_enabled", False)) and (
        ns_min is not None and ns_max is not None and ns_max > ns_min
    )

    if clamp_ok:
        rng = np.random.default_rng()

        # out-of-bounds mask (NaN/inf are left alone and end up as NaT)
        bad = np.isfinite(ns_vals) & ((ns_vals < ns_min) | (ns_vals > ns_max))
        n_bad = int(bad.sum())
        tqdm.write(
            f"[datetime.revert] window={pd.to_datetime(ns_min)}..{pd.to_datetime(ns_max)} "
            f"oob={n_bad}/{len(ns_vals)}"
        )

        if n_bad:
            pool = getattr(self, "_ns_pool", None)
            if isinstance(pool, np.ndarray) and pool.size:
                choice = rng.choice(pool, size=n_bad, replace=True).astype("float64")
            else:
                choice = rng.uniform(ns_min, ns_max, size=n_bad)
            # small jitter so resampled values don't quantize to identical ticks
            J = int(max(1e8, 0.0005 * (ns_max - ns_min)))  # >= 0.1s or 0.05% of the window
            choice = choice + rng.integers(-J, J + 1, size=choice.shape)
            ns_vals[bad] = choice

        # light jitter for the originally in-bounds values too, to avoid edge pile-ups
        good = np.isfinite(ns_vals) & ~bad
        if good.any():
            J_small = int(max(1e7, 0.0002 * (ns_max - ns_min)))  # >= 0.01s or 0.02% of the window
            ns_vals[good] = ns_vals[good] + rng.integers(-J_small, J_small + 1, size=int(good.sum()))

    # 4) Prevent int64 overflow before rounding.
    # float(int64 max) rounds UP to 2**63, which itself overflows int64, so clip to the
    # largest float64 strictly below 2**63 (and its negative, which also avoids the
    # int64 minimum that pandas reserves as the NaT sentinel).
    finite = np.isfinite(ns_vals)
    f64_hi = float(np.nextafter(np.float64(2.0**63), np.float64(0.0)))
    ns_clipped = np.clip(ns_vals, -f64_hi, f64_hi)
    ns_round = np.where(finite, np.rint(ns_clipped), np.nan)

    # Nullable Int64 via pd.array (NaN -> <NA>)
    ns_int = pd.array(ns_round, dtype="Int64")

    # 5) To datetime64[ns]
    dt = pd.to_datetime(ns_int, unit="ns", errors="coerce")
    dt = pd.Series(dt, index=reverted.index, name=base)

    # 6) Respect missingness flag if present
    miss_col = f"{base}_missing"
    if miss_col in reverted.columns:
        m = pd.to_numeric(reverted[miss_col], errors="coerce").fillna(0).astype(bool).to_numpy()
        if m.any():
            dt.loc[m] = pd.NaT

    # 7) Write back
    reverted[base] = dt
    return reverted