Skip to content

constraints

ConstraintGraph

Source code in src/nhssynth/modules/dataloader/constraints.py
class ConstraintGraph:
    """
    Validate user-supplied constraint strings and reduce them to a minimal set using a directed graph.

    Every constraint string takes one of three whitespace-separated forms:

    - ``"fixcombo <col> <col> ..."``: links categorical columns together,
    - ``"<col> <positivity>"``: shorthand that compares the column against ``0``
      (see `POSITIVITY_TO_OPERATOR`),
    - ``"<col> <operator> <reference>"``: `reference` is another column, a constant, or,
      when the operator is ``in``, a range such as ``[a,b)`` whose bounds are columns or constants.

    In the graph, an edge points from the greater item to the lesser one. Constants that are
    not column names become nodes coloured ``"red"``; non-strict comparisons produce
    ``"green"`` edges, strict ones ``"black"`` edges and `fixcombo` links ``"purple"`` edges.

    Args:
        constraint_strings: The raw constraint strings, or `None` when there are no constraints.
        columns: The column names of the dataset the constraints refer to.
        metadata: Mapping from column name to an object exposing `dtype` and `categorical`.

    Attributes:
        raw_constraint_strings: The constraint strings exactly as passed in.
        validated_constraint_strings: The validated constraints as tuples.
        graph: Graph built from every validated constraint.
        minimal_constraints: `ComboConstraint`s followed by the reduced list of `Constraint`s.
        minimal_graph: Graph rebuilt from `minimal_constraints` only.

    Raises:
        ValueError: If a constraint is malformed, refers to an unknown column, mixes
            incompatible dtypes, or if the constraints taken together are cyclic or inconsistent.
    """

    VALID_OPERATORS: Final = [">", ">=", "<", "<=", "in"]
    POSITIVITY_TO_OPERATOR: Final = {
        "positive": ">",
        "nonnegative": ">=",
        "negative": "<",
        "nonpositive": "<=",
    }
    BRACKET_TO_OPERATOR: Final = {"[": ">=", "]": "<=", "(": ">", ")": "<"}
    OPERATOR_TO_PANDAS: Final = {
        "<": pd.Series.lt,
        "<=": pd.Series.le,
        ">": pd.Series.gt,
        ">=": pd.Series.ge,
    }

    class Constraint:
        """
        A single comparison ``base <operator> reference``.

        Args:
            base: Name of the column being constrained.
            operator: One of the keys of `OPERATOR_TO_PANDAS`.
            reference: Either a column name or a constant (number or datetime string).
            reference_is_column: Whether `reference` names a column rather than a constant.
        """

        # These tables are repeated here because a nested class body cannot see the
        # attributes of its enclosing class.
        VALID_OPERATORS: Final = [">", ">=", "<", "<=", "in"]
        POSITIVITY_TO_OPERATOR: Final = {
            "positive": ">",
            "nonnegative": ">=",
            "negative": "<",
            "nonpositive": "<=",
        }
        BRACKET_TO_OPERATOR: Final = {"[": ">=", "]": "<=", "(": ">", ")": "<"}
        OPERATOR_TO_PANDAS: Final = {
            "<": pd.Series.lt,
            "<=": pd.Series.le,
            ">": pd.Series.gt,
            ">=": pd.Series.ge,
        }

        def __init__(
            self,
            base: str,
            operator: str,
            reference: Union[str, float],
            reference_is_column: bool = False,
        ):
            self.base = base
            self.operator = operator
            self.reference = reference
            self.reference_is_column = reference_is_column

        def __str__(self) -> str:
            return f"{self.base} {self.operator} {self.reference}"

        def __repr__(self) -> str:
            return str(self)

        def __eq__(self, other) -> bool:
            # Let Python fall back to its default handling for foreign types rather than
            # failing with an AttributeError on a missing attribute.
            if not isinstance(other, ConstraintGraph.Constraint):
                return NotImplemented
            return (
                self.base == other.base
                and self.operator == other.operator
                and self.reference == other.reference
                and self.reference_is_column == other.reference_is_column
            )

        def transform(self, df):
            """
            Add a ``<base>_adherence`` column flagging which rows of `df` satisfy this constraint.

            The column holds ``1`` where the constraint holds and ``0`` otherwise; comparisons
            that evaluate to a missing value count as ``0``. `df` is modified in place and
            also returned.

            NOTE: the output column is named after `base` only, so a later constraint on the
            same base column overwrites the result of an earlier one.

            Args:
                df: The DataFrame to evaluate the constraint on.

            Returns:
                The same DataFrame with the adherence column added.

            Raises:
                ValueError: If the base column is missing from `df`, or a constant reference
                    is neither a number nor a parseable datetime.
            """
            if self.base not in df.columns:
                raise ValueError(f"Column '{self.base}' not found in DataFrame.")

            if self.reference_is_column:
                # Column-to-column constraint (e.g., columnB <= columnC)
                reference = df[self.reference]
            else:
                # Constant constraint (e.g., columnA > 10); validation also admits datetime
                # constants, so fall back to datetime parsing when the value is not numeric
                try:
                    reference = float(self.reference)
                except ValueError:
                    reference = pd.to_datetime(self.reference)

            compare = self.OPERATOR_TO_PANDAS[self.operator]
            adherence = compare(df[self.base], reference).fillna(False)
            # Store adherence as 0 (False) or 1 (True)
            df[self.base + "_adherence"] = adherence.astype(int)
            return df

    class ComboConstraint:
        """
        A `fixcombo` constraint tying a group of columns together.

        Args:
            columns: Names of the columns in the group.
        """

        def __init__(self, columns: list[str]):
            self.columns = columns

        def __str__(self) -> str:
            return f"fixcombo {' '.join(self.columns)}"

        def __repr__(self) -> str:
            return str(self)

        def __eq__(self, other) -> bool:
            if not isinstance(other, ConstraintGraph.ComboConstraint):
                return NotImplemented
            return self.columns == other.columns

        def transform(self, df):
            """Return `df` unchanged; combo constraints add no adherence column."""
            return df

    def __init__(self, constraint_strings: Optional[list[str]], columns: pd.Index, metadata: dict):
        self._columns = columns
        self._metadata = metadata
        self.raw_constraint_strings = constraint_strings
        self.validated_constraint_strings = self.validate_constraint_strings()
        self.graph = self.build_graph(self.validated_constraint_strings)
        self.minimal_constraints = self.determine_minimal_constraints()
        print("Minimal constraints resolved to:")
        print(self.minimal_constraints)
        print()
        # Rebuild the graph from the constraint objects' own fields (comparisons first, then
        # combos) instead of re-parsing their string representations.
        comparison_tuples = [
            (c.base, c.operator, str(c.reference)) for c in self.minimal_constraints if isinstance(c, self.Constraint)
        ]
        combo_tuples = [
            ("fixcombo", list(c.columns)) for c in self.minimal_constraints if isinstance(c, self.ComboConstraint)
        ]
        self.minimal_graph = self.build_graph(comparison_tuples + combo_tuples)

    def _validate_fixcombo_constraint(self, elements: list[str]) -> tuple[str, list[str]]:
        """
        Check that every column named after the `fixcombo` keyword exists and is categorical.

        Args:
            elements: The split constraint string, with ``"fixcombo"`` as its first element.

        Returns:
            The tuple ``("fixcombo", [columns...])``.

        Raises:
            ValueError: If a column does not exist or is not categorical.
        """
        for column in elements[1:]:
            self._column_exists(column)
            if not self._metadata[column].categorical:
                raise ValueError(f"'{column}' must be categorical to use the 'fixcombo' operator.")
        return ("fixcombo", elements[1:])

    def _column_exists(self, column: str) -> None:
        """Raise a `ValueError` if `column` is not one of the known columns."""
        if column not in self._columns:
            raise ValueError(f"Constraint refers to a column that does not exist ('{column}').")

    def _validate_positivity(self, positivity: str) -> None:
        """Raise a `ValueError` if `positivity` is not a key of `POSITIVITY_TO_OPERATOR`."""
        if positivity not in self.POSITIVITY_TO_OPERATOR:
            raise ValueError(f"Constraint has an invalid positivity specification ('{positivity}').")

    def _validate_simple_constraint(self, base: str, positivity: str) -> tuple[str, str, str]:
        """
        Validate a two-element constraint such as ``"age positive"``.

        Returns:
            The equivalent ``(base, operator, "0")`` tuple.

        Raises:
            ValueError: If the column does not exist or the positivity keyword is unknown.
        """
        self._column_exists(base)
        self._validate_positivity(positivity)
        return (base, self.POSITIVITY_TO_OPERATOR[positivity], "0")

    def _validate_operator(self, base: str, operator: str) -> None:
        """
        Check that `operator` is supported and that `base` does not have object dtype.

        Raises:
            ValueError: If either check fails.
        """
        if operator not in self.VALID_OPERATORS:
            raise ValueError(f"Constraint has an invalid operator ('{operator}').")
        if self._metadata[base].dtype.kind == "O":
            raise ValueError(
                f"Constraint's base column ('{base}') must be numeric or datetime when '{operator}' is used."
            )

    def _validate_matching_dtypes(self, base: str, reference: str) -> None:
        """
        Check that two columns share a dtype and that neither of them is categorical.

        Raises:
            ValueError: If the dtypes differ or either column is categorical.
        """
        if self._metadata[base].dtype != self._metadata[reference].dtype:
            raise ValueError(
                f"Constraint's base column ('{base}') has a different dtype ({self._metadata[base].dtype.name}) to the reference column's ('{reference}': {self._metadata[reference].dtype.name}), which is not allowed."
            )
        if self._metadata[base].categorical or self._metadata[reference].categorical:
            raise ValueError(
                f"Constraint's base column ('{base}') and reference column ('{reference}') must both be numeric or datetime when using any operator other than 'fixcombo'."
            )

    def _get_range_operators(self, reference: str) -> tuple[str, str]:
        """
        Translate the brackets of a range such as ``[a,b)`` into a pair of operators.

        Returns:
            The operators for the lower and upper bound respectively.

        Raises:
            ValueError: If `reference` does not start and end with a valid bracket.
        """
        # Slicing (rather than indexing) keeps an empty reference on the ValueError path
        if reference[:1] not in ["[", "("] or reference[-1:] not in ["]", ")"]:
            raise ValueError(
                f"Constraint's reference is not a valid range specification ('{reference}'), it must be of the form '[' or '(' + 'a,b' + ']' or ')'."
            )
        return self.BRACKET_TO_OPERATOR[reference[0]], self.BRACKET_TO_OPERATOR[reference[-1]]

    def _validate_constant_dtype(self, base: str, reference: str) -> None:
        """
        Check that the constant `reference` can be interpreted using the dtype of `base`.

        Datetime columns need a reference that `pd.to_datetime` can parse, any other
        non-object column needs one that `float` can parse.

        Raises:
            ValueError: If `base` has object dtype or `reference` cannot be parsed.
        """
        if self._metadata[base].dtype.kind == "O":
            raise ValueError(
                f"The reference ('{reference}') is not a valid dtype for the constraint's base column ('{base}': '{self._metadata[base].dtype}')."
            )
        elif self._metadata[base].dtype.kind == "M":
            try:
                pd.to_datetime(reference)
            except ValueError as err:
                # pandas' date parsing errors derive from ValueError, so this catches them all
                raise ValueError(
                    f"The reference ('{reference}') is not a valid datetime to match the dtype of the constraint's base column ('{base}': '{self._metadata[base].dtype}')."
                ) from err
        else:
            try:
                float(reference)
            except ValueError as err:
                raise ValueError(
                    f"The reference ('{reference}') is not a valid float to match the dtype of the constraint's base column ('{base}': '{self._metadata[base].dtype}')."
                ) from err

    def _validate_range_component(self, base: str, component: str) -> str:
        """
        Validate one bound of a range, which may be a column name or a constant.

        Returns:
            The bound with surrounding whitespace removed.
        """
        component = component.strip()
        if component in self._columns:
            self._validate_matching_dtypes(base, component)
        else:
            self._validate_constant_dtype(base, component)
        return component

    def _validate_reference_constraint(self, base: str, operator: str, reference: str) -> list[tuple[str, str, str]]:
        """
        Validate a three-element constraint ``base operator reference``.

        Returns:
            One ``(base, operator, reference)`` tuple, or two of them (lower then upper
            bound) when the operator is ``in``.

        Raises:
            ValueError: If the column, operator, dtypes or range specification are invalid.
        """
        self._column_exists(base)
        self._validate_operator(base, operator)
        if operator == "in":
            # 'in' only makes sense with a range, so it is handled before the column check;
            # a bare column name as the reference is rejected by the bracket validation.
            low_op, high_op = self._get_range_operators(reference)
            bounds = reference[1:-1].split(",")
            if len(bounds) != 2:
                raise ValueError(
                    f"Constraint's reference is not a valid range specification ('{reference}'), it must be of the form '[' or '(' + 'a,b' + ']' or ')'."
                )
            low = self._validate_range_component(base, bounds[0])
            high = self._validate_range_component(base, bounds[1])
            if low not in self._columns and high not in self._columns:
                # Compare the bounds in the domain of the base column (datetime or numeric)
                parse = pd.to_datetime if self._metadata[base].dtype.kind == "M" else float
                if parse(low) >= parse(high):
                    raise ValueError(
                        f"Constraint's reference is not a valid range specification ('{reference}'), the lower bound must be strictly less than the upper bound."
                    )
            return [(base, low_op, low), (base, high_op, high)]
        if reference in self._columns:
            self._validate_matching_dtypes(base, reference)
        else:
            self._validate_constant_dtype(base, reference)
        return [(base, operator, reference)]

    def validate_constraint_strings(self) -> list[tuple]:
        """
        Validate every raw constraint string and convert it into tuple form.

        Returns:
            A list of ``(base, operator, reference)`` tuples for comparisons and
            ``("fixcombo", [columns...])`` tuples for combos. Empty when no constraints
            were supplied.

        Raises:
            ValueError: If any constraint string is invalid.
        """
        valid_constraints = []
        # `raw_constraint_strings` is Optional, treat None as "no constraints"
        for constraint_string in self.raw_constraint_strings or []:
            # Split on any run of whitespace so stray extra spaces do not invalidate a constraint
            elements = constraint_string.split()
            if not elements:
                raise ValueError(f"Constraint '{constraint_string}' is invalid.")
            if elements[0] == "fixcombo":
                valid_constraints.append(self._validate_fixcombo_constraint(elements))
            elif len(elements) == 2:
                valid_constraints.append(self._validate_simple_constraint(*elements))
            elif len(elements) == 3:
                valid_constraints.extend(self._validate_reference_constraint(*elements))
            else:
                raise ValueError(f"Constraint '{constraint_string}' is invalid.")
        return valid_constraints

    def build_graph(self, constraint_string_tuples) -> nx.DiGraph:
        """
        Build the directed constraint graph from constraint tuples.

        Every column becomes a node ("purple" if categorical, "blue" otherwise). Consecutive
        `fixcombo` columns are chained with "purple" edges. A comparison adds an edge from the
        greater item to the lesser one ("black" for one-character operators, "green"
        otherwise); items that are not already nodes, i.e. constants, are added in "red".

        Args:
            constraint_string_tuples: Tuples as produced by `validate_constraint_strings`.

        Returns:
            The constructed graph.

        Raises:
            ValueError: If the resulting graph contains a cycle.
        """
        graph = nx.DiGraph()
        for col in self._columns:
            graph.add_node(col, color="purple" if self._metadata[col].categorical else "blue")
        for cst in constraint_string_tuples:
            if cst[0] == "fixcombo":
                cols = cst[1]
                for left, right in zip(cols, cols[1:]):
                    graph.add_edge(left, right, color="purple")
            else:
                item1, operator, item2 = cst
                if "<" in operator:
                    # Normalise direction so that edges always run greater -> lesser
                    item1, item2 = item2, item1
                for item in (item1, item2):
                    if item not in graph.nodes:
                        graph.add_node(item, color="red")
                graph.add_edge(item1, item2, color="black" if len(operator) == 1 else "green")
        if not nx.is_directed_acyclic_graph(graph):
            raise ValueError(
                f"Constraint graph is not acyclic as required; some constraints involving {[c for c in nx.simple_cycles(graph)]} are invalid."
            )
        return graph

    def _check_constants_are_monotonic(self, longest_path: list[str], subgraph: nx.DiGraph) -> None:
        """
        Check that the constants met along a path are consistent with each other.

        Paths run from greater to lesser items, so a constant must never be smaller than a
        constant that appears later in the same path.

        Raises:
            ValueError: If an earlier constant is less than a later one.
        """
        prev = None
        for node in longest_path:
            if subgraph.nodes[node]["color"] == "red":
                try:
                    new = float(node)
                except ValueError:
                    new = pd.to_datetime(node)
                if prev is not None and prev < new:
                    raise ValueError(
                        f"The constraints are inconsistent, '{prev}' is less than '{new}' but the sequence of constants in a chain of constraints must be monotonically increasing."
                    )
                prev = new

    def _traverse_longest_path(
        self,
        longest_path: list[str],
        subgraph: nx.DiGraph,
        constraints: list[Union[Constraint, ComboConstraint]],
    ) -> list[Union[Constraint, ComboConstraint]]:
        """
        Convert each edge along `longest_path` into a `Constraint` and collect the new ones.

        Args:
            longest_path: Node names ordered from the greatest item to the least.
            subgraph: The graph component the path belongs to.
            constraints: Constraints gathered so far; extended in place.

        Returns:
            The `constraints` list, with any constraints not already present appended.

        Raises:
            ValueError: If the constants along the path are inconsistent.
        """
        self._check_constants_are_monotonic(longest_path, subgraph)
        for item1, item2 in zip(longest_path, longest_path[1:]):
            # Green edges encode the non-strict form of the comparison
            operator = ">=" if subgraph.edges[item1, item2]["color"] == "green" else ">"
            if subgraph.nodes[item1]["color"] == "red":  # Note: this breaks if two none col constraints are the same!
                # The constant goes on the right-hand side, so the comparison is mirrored
                item1, item2 = item2, item1
                operator = operator.replace(">", "<")
            ref_is_col = subgraph.nodes[item2]["color"] != "red"
            constraint = self.Constraint(item1, operator, item2, reference_is_column=ref_is_col)
            if constraint not in constraints:
                constraints.append(constraint)
        return constraints

    def _determine_minimal_subgraph_constraints(
        self,
        subgraph: nx.DiGraph,
        constraints: list[Union[Constraint, ComboConstraint]],
    ) -> list[Union[Constraint, ComboConstraint]]:
        """
        Derive constraints from one connected component by walking its longest paths.

        For each pair of end points, the longest remaining path is repeatedly converted into
        constraints until every node lying on a path between the two has been covered.

        Args:
            subgraph: A weakly connected component of the constraint graph.
            constraints: Constraints gathered so far.

        Returns:
            The updated list of constraints.
        """
        # Nodes with only incoming edges end a path, nodes with only outgoing edges start one
        path_ends = [n for n in subgraph.nodes if subgraph.out_degree(n) == 0 and subgraph.in_degree(n) > 0]
        path_starts = [n for n in subgraph.nodes if subgraph.in_degree(n) == 0 and subgraph.out_degree(n) > 0]
        for end in path_ends:
            for start in path_starts:
                paths = list(nx.all_simple_paths(subgraph, start, end))
                uncovered = {n for p in paths for n in p}
                while uncovered:
                    longest_path = max(paths, key=len)
                    covered = set(longest_path)
                    uncovered -= covered
                    # Paths made up solely of nodes on the chosen path add nothing new
                    paths = [p for p in paths if not set(p).issubset(covered)]
                    constraints = self._traverse_longest_path(longest_path, subgraph, constraints)
        return constraints

    def determine_minimal_constraints(self) -> list[Union[Constraint, ComboConstraint]]:
        """
        Reduce the full constraint graph to a minimal list of constraint objects.

        Each weakly connected component with more than one node is processed separately:
        components made up only of "purple" edges become a `ComboConstraint`, all others are
        reduced to `Constraint`s.

        Returns:
            The combo constraints followed by the comparison constraints.
        """
        combo_constraints = []
        constraints = []
        all_subgraphs = [self.graph.subgraph(g) for g in nx.weakly_connected_components(self.graph) if len(g) > 1]
        for subgraph in all_subgraphs:
            if all(subgraph.edges[e]["color"] == "purple" for e in subgraph.edges):
                # Materialise the node view so the constraint holds a plain list of names
                combo_constraints.append(self.ComboConstraint(list(subgraph.nodes)))
            else:
                constraints = self._determine_minimal_subgraph_constraints(subgraph, constraints)
        return combo_constraints + constraints

    def _output_graphs_html(self, name: str) -> None:
        """
        Write the full and the minimal constraint graph to HTML files.

        The full graph is written to `name`; the minimal graph goes to the same path with
        ``".html"`` replaced by ``"_minimal.html"``.

        Raises:
            ValueError: If the graphs have not been built yet.
        """
        if not hasattr(self, "graph") or not hasattr(self, "minimal_graph"):
            raise ValueError("Constraint graphs have not been built yet.")
        outputs = (
            (self.graph, name),
            (self.minimal_graph, str(name).replace(".html", "_minimal.html")),
        )
        for graph, path in outputs:
            net = Network(directed=True, notebook=False, height="100%", width="100%")
            net.from_nx(graph)
            html = net.generate_html(notebook=False)
            with open(path, "w") as f:
                f.write(html)

    def __iter__(self):
        """
        Make the ConstraintGraph iterable over the minimal_constraints.
        """
        # Assuming self.minimal_constraints is a list of Constraint or ComboConstraint objects
        return iter(self.minimal_constraints)

__iter__()

Make the ConstraintGraph iterable over the minimal_constraints.

Source code in src/nhssynth/modules/dataloader/constraints.py
def __iter__(self):
    """
    Return an iterator over `self.minimal_constraints`, so that the object can be
    looped over directly.
    """
    # Presumably a list of Constraint / ComboConstraint objects; any iterable works here
    minimal = self.minimal_constraints
    return iter(minimal)