Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions causal_testing/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import logging
import os
import tempfile
import pandas as pd
from pathlib import Path

from causal_testing.testing.metamorphic_relation import generate_causal_tests
from causal_testing.discovery.hill_climber import evolve_dag

from .main import CausalTestingFramework, CausalTestingPaths, Command, parse_args, setup_logging

Expand Down Expand Up @@ -39,6 +41,17 @@ def main() -> None:
# Setup logging
setup_logging(args.verbose)

if args.command == Command.DISCOVER:
logging.info("Discovering causal structures")
Comment thread
Jake248Newman marked this conversation as resolved.
Outdated
evolve_dag(
df=pd.read_csv(args.data_path),
output_file=args.output_dag_path,
include_edges_file=args.include_edges,
exclude_edges_file=args.exclude_edges,
)
logging.info("Causal structure discovery completed successfully")
return

# Create paths object
paths = CausalTestingPaths(
dag_path=args.dag_path,
Expand Down
Empty file.
236 changes: 236 additions & 0 deletions causal_testing/discovery/hill_climber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
"""
This module implements a hill climbing algorithm to optimise causal DAGs based on the tests that pass/fail.
"""

import json
import os
import random
import sys
import time
import warnings
from collections import Counter
from itertools import permutations

import networkx as nx
import rustworkx as rx
import numpy as np
import pandas as pd
from causal_testing.main import CausalTestingFramework
from causal_testing.specification.causal_dag import CausalDAG
from causal_testing.specification.scenario import Scenario
from causal_testing.testing.causal_test_result import CausalTestResult
from causal_testing.testing.metamorphic_relation import generate_metamorphic_relations

warnings.simplefilter("ignore")

# lexicographical order (max pass, minimise failure, minimise unknown)
# e.g. (X pass, Y fail, Z+1 unknown) is better than (X pass, Y+1 fail, Z unknown)

def remove_cycles(causal_dag: CausalDAG, included_edges: set[tuple[str, str]] = set()):
"""
Remove cycles from individuals by iteratively deleting a random edge from each cycle until there are no more cycles.
JN Note: This may be updated using one of the methods discussed (i.e. rustworkx)

:param causal_dag: The CausalDAG to be repaired.
:param included_edges: A set of edges that must be included in the repaired DAG.
"""
nodes = causal_dag.nodes
cycle = next(nx.algorithms.simple_cycles(causal_dag), False)
while cycle:
inx1 = random.choice(range(len(cycle)))
inx2 = (inx1 + 1) % len(cycle)
while (cycle[inx1], cycle[inx2]) in included_edges:
inx1 = inx2
inx2 = (inx1 + 1) % len(cycle)
causal_dag.remove_edge(cycle[inx1], cycle[inx2])
cycle = next(nx.algorithms.simple_cycles(causal_dag), False)
causal_dag.add_nodes_from(nodes)


def estimated_effect(result: CausalTestResult, treatment_variable: str):
"""
Check whether the estimated causal effect is negative or positive.
:param result: The causal test result object.
:param treatment_variable: The name of the treatment variable of the causal test.
:returns: Whether the estimated causal test is positive or negative (or no effect).
"""
if result.ci_low[treatment_variable] < result.ci_high[treatment_variable] < 1:
return "negative"
if 1 < result.ci_low[treatment_variable] < result.ci_high[treatment_variable]:
return "positive"
return "no effect"


def evaluate_tests(causal_dag: CausalDAG, df: pd.DataFrame) -> list[tuple[str, str]]:
"""
Generate and evaluate causal test cases from the supplied CausalDAG and return a list of edges for which the
corresponding causal test case failed.

:param causal_dag: The CausalDAG to evaluate.
:param df: The data with which to evaluate the causal test cases.
"""

ctf = CausalTestingFramework(None)
ctf.dag = causal_dag
ctf.data = df
ctf.create_variables()
ctf.scenario = Scenario(list(ctf.variables["inputs"].values()) + list(ctf.variables["outputs"].values()))
ctf.test_cases = ctf.create_test_cases(
{
"tests": [
relation.to_json_stub(
estimator="LogisticRegressionEstimator", estimate_type="unit_odds_ratio", alpha=0.01
)
for relation in generate_metamorphic_relations(causal_dag)
]
}
)
results = []

for test_case, result in zip(ctf.test_cases, ctf.run_tests(silent=True)):
if result.effect_estimate is None:
results.append(
{
"result": "error",
"expected_effect": test_case.expected_causal_effect.__class__.__name__,
"treatment": test_case.base_test_case.treatment_variable.name,
"outcome": test_case.base_test_case.outcome_variable.name,
}
)
else:
results.append(
{
"result": "pass" if test_case.expected_causal_effect.apply(result) else "failure",
"expected_effect": test_case.expected_causal_effect.__class__.__name__,
"treatment": test_case.base_test_case.treatment_variable.name,
"outcome": test_case.base_test_case.outcome_variable.name,
"effect": estimated_effect(
result.effect_estimate, test_case.base_test_case.treatment_variable.name
),
}
)

return pd.DataFrame(results)


# TODO: Double check whether this method is actually necessary.
def normalised_counts(test_results: pd.DataFrame) -> dict:
"""
Normalise the absolute numbers of pass/fail/error test outcomes.
MF Note 2026-06-15: I can't actually remember what this method was supposed to do. I need to double check it.
:param test_results: Dataframe containing the raw pass/fail/error outcome of each test case.
:returns: Dictionary containing the number of pass/fail/error outcomes, normalised by dividing by the total number
of each.
"""
counts = pd.concat(
[
pd.DataFrame(np.sort(test_results[["treatment", "outcome"]], axis=1), columns=["treatment", "outcome"]),
pd.get_dummies(test_results["result"]).astype(int),
],
axis=1,
)
for col in ["pass", "failure", "error"]:
if col not in counts.columns:
counts[col] = 0
counts = counts.groupby(["treatment", "outcome"]).sum().reset_index()[["pass", "failure", "error"]]
counts = counts.apply(lambda col: col / counts.sum(axis=1))
return counts.sum(axis=0).to_dict()


def evaluate_fitness(
individual: CausalDAG, df: pd.DataFrame
) -> tuple[tuple[float, float, float], list[tuple[str, str]]]:
"""
Evaluate the fitness of a given causal DAG by evaluating the corresponding test cases.
:param individual: The candidate individual to evaluate.
:param df: The data with which to evaluate the causal tests.
:returns: Tuple of the form (X, Y), where X is a triple containing the number of passing, inestimable, and failing
tests respectively, and Y is a list of failing edges.
"""
test_results = evaluate_tests(individual, df)
counts = normalised_counts(test_results)
problem_tests = test_results.query("result != 'pass'")
problem_edges = problem_tests[["treatment", "outcome"]].apply(tuple, axis=1).tolist()
problem_edges.extend(
problem_tests.query("expected_effect == 'NoEffect'")[["outcome", "treatment"]].apply(tuple, axis=1).tolist()
)

return (counts.get("pass", 0), counts.get("error", 0), counts.get("failure", 0)), problem_edges


def evolve_dag(
df: pd.DataFrame,
random_seed: int = 0,
output_file: str = None,
include_edges_file: str = None,
exclude_edges_file: str = None,
) -> CausalDAG:
"""
Evolve a causal DAG for a given dataset.
:param df: The data for which to fit a causal DAG.
:param random_seed: The random seed to use for genetic computation.
:param output_file: Where to save the inferred causal DAG (if supplied).
:param include_edges_file: Path to file containing edges to include.
:param exclude_edges_file: Path to file containing edges to exclude.
:returns: The inferred causal DAG.
"""
random.seed(random_seed)

included_edges = set(nx.nx_pydot.read_dot(include_edges_file).edges()) if include_edges_file is not None else set()
excluded_edges = set(nx.nx_pydot.read_dot(exclude_edges_file).edges()) if exclude_edges_file is not None else set()
possible_edges = sorted(list((u, v) for u, v in permutations(df.columns, 2) if (u, v) not in excluded_edges))

start_time = time.time()
individual = CausalDAG()
individual.add_nodes_from(df.columns)
individual.add_edges_from(possible_edges)
remove_cycles(individual, included_edges)
fitness_values, problem_edges = evaluate_fitness(individual, df)

iterations = 100
iterations_without_improvement = 0

while problem_edges and iterations:
iterations -= 1
print(iterations, fitness_values, iterations_without_improvement)

new_individual = individual.copy()
for origin, dest in random.sample(
problem_edges + (possible_edges if iterations_without_improvement > 10 else []),
random.randint(1, len(problem_edges)),
):
if new_individual.has_edge(origin, dest) and (origin, dest) not in included_edges:
new_individual.remove_edge(origin, dest)
elif not new_individual.has_edge(origin, dest) and (origin, dest) not in excluded_edges:
# Want to bypass the cycle check of CausalDAG as we remove the cycles afterwards
super(CausalDAG, new_individual).add_edge(origin, dest)
remove_cycles(new_individual, included_edges)
new_fitness_values, new_problem_edges = evaluate_fitness(new_individual, df)
# assert sum(new_fitness_values) == sum(fitness_values)
print(" ", new_fitness_values)

if new_fitness_values > fitness_values:
fitness_values = new_fitness_values
problem_edges = new_problem_edges
individual = new_individual
iterations_without_improvement = 0
else:
iterations_without_improvement += 1

end_time = time.time()
individual.graph["fitness"] = fitness_values
individual.graph["time"] = round(end_time - start_time)
if output_file is not None:
nx.drawing.nx_agraph.write_dot(
individual,
output_file,
)
return individual


if __name__ == "__main__":
Comment thread
Jake248Newman marked this conversation as resolved.
Outdated
evolve_dag(
df=pd.read_csv("test_data.csv"),
random_seed=1,
output_file=f"tmp/test_dag.dot",
)
34 changes: 34 additions & 0 deletions causal_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Command(Enum):

TEST = "test"
GENERATE = "generate"
DISCOVER = "discover"


@dataclass
Expand All @@ -48,18 +49,31 @@ class CausalTestingPaths:
data_paths: List[Path]
test_config_path: Path
output_path: Path
discovery_data_path: Path
discovery_output_dag_path: Path
discovery_include_edges_path: Optional[Path] = None
discovery_exclude_edges_path: Optional[Path] = None

def __init__(
self,
dag_path: Union[str, Path],
data_paths: List[Union[str, Path]],
test_config_path: Union[str, Path],
output_path: Union[str, Path],
discovery_data_path: Union[str, Path],
output_dag_path: Union[str, Path],
include_edges_path: Optional[Union[str, Path]] = None,
exclude_edges_path: Optional[Union[str, Path]] = None,
):
self.dag_path = Path(dag_path)
self.data_paths = [Path(p) for p in data_paths]
self.test_config_path = Path(test_config_path)
self.output_path = Path(output_path)
self.discovery_data_path = Path(discovery_data_path)
self.output_dag_path = Path(output_dag_path)
self.include_edges_path = Path(include_edges_path) if include_edges_path else None
self.exclude_edges_path = Path(exclude_edges_path) if exclude_edges_path else None


def validate_paths(self) -> None:
"""
Expand All @@ -81,6 +95,18 @@ def validate_paths(self) -> None:
if not self.output_path.parent.exists():
self.output_path.parent.mkdir(parents=True)

if not self.discovery_data_path.exists():
raise FileNotFoundError(f"Data file not found: {self.discovery_data_path}")

if not self.output_dag_path.parent.exists():
self.output_dag_path.parent.mkdir(parents=True)

if self.include_edges_path and not self.include_edges_path.exists():
raise FileNotFoundError(f"Data file not found: {self.include_edges_path}")

if self.exclude_edges_path and not self.exclude_edges_path.exists():
raise FileNotFoundError(f"Data file not found: {self.exclude_edges_path}")


class CausalTestingFramework:
# pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -567,6 +593,14 @@ def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
help="Run tests in batches of the specified size (default: 0, which means no batching)",
)

#Discovery
parser_discover = subparsers.add_parser(Command.DISCOVER.value, help="Discover causal structures from data")
parser_discover.add_argument("-d", "--data-path", help="Path to data file (.csv)", required=True)
Comment thread
Jake248Newman marked this conversation as resolved.
Outdated
parser_discover.add_argument("-o", "--output-dag-path", help="Path for output DAG file (.dot)", required=True)
parser_discover.add_argument("-i", "--include-edges", help="Path to file containing edges to include", required=False)
parser_discover.add_argument("-e", "--exclude-edges", help="Path to file containing edges to exclude", required=False)
parser_discover.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False)

args = main_parser.parse_args(args)

# Assume the user wants test adequacy if they're setting bootstrap_size
Expand Down