-
Notifications
You must be signed in to change notification settings - Fork 6
SURE - Discovering causal structure #385
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
Jake248Newman
wants to merge
8
commits into
CITCOM-project:main
Choose a base branch
from
Jake248Newman:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 2 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
aaf5505
Initial integration of hill climbing
Jake248Newman 1661dcf
Discovery entry point
Jake248Newman 078e93c
Update causal_testing/__main__.py
Jake248Newman 6f9626b
rustworkx for deterministic remove_cycles
Jake248Newman 638c656
Merge branch 'main' of github.com:Jake248Newman/CausalTestingFramework
Jake248Newman cb0f5e6
Update causal_testing/main.py
Jake248Newman 8538083
Initial testing of discovery and CLI
Jake248Newman 05c9e9e
Score and tier based fitness
Jake248Newman File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,236 @@ | ||
| """ | ||
| This module implements a hill climbing algorithm to optimise causal DAGs based on the tests that pass/fail. | ||
| """ | ||
|
|
||
| import json | ||
| import os | ||
| import random | ||
| import sys | ||
| import time | ||
| import warnings | ||
| from collections import Counter | ||
| from itertools import permutations | ||
|
|
||
| import networkx as nx | ||
| import rustworkx as rx | ||
| import numpy as np | ||
| import pandas as pd | ||
| from causal_testing.main import CausalTestingFramework | ||
| from causal_testing.specification.causal_dag import CausalDAG | ||
| from causal_testing.specification.scenario import Scenario | ||
| from causal_testing.testing.causal_test_result import CausalTestResult | ||
| from causal_testing.testing.metamorphic_relation import generate_metamorphic_relations | ||
|
|
||
| warnings.simplefilter("ignore") | ||
|
|
||
| # lexicographical order (max pass, minimise failure, minimise unknown) | ||
| # e.g. (X pass, Y fail, Z+1 unknown) is better than (X pass, Y+1 fail, Z unknown) | ||
|
|
||
| def remove_cycles(causal_dag: CausalDAG, included_edges: set[tuple[str, str]] = set()): | ||
| """ | ||
| Remove cycles from individuals by iteratively deleting a random edge from each cycle until there are no more cycles. | ||
| JN Note: This may be updated using one of the methods discussed (i.e. rustworkx) | ||
|
|
||
| :param causal_dag: The CausalDAG to be repaired. | ||
| :param included_edges: A set of edges that must be included in the repaired DAG. | ||
| """ | ||
| nodes = causal_dag.nodes | ||
| cycle = next(nx.algorithms.simple_cycles(causal_dag), False) | ||
| while cycle: | ||
| inx1 = random.choice(range(len(cycle))) | ||
| inx2 = (inx1 + 1) % len(cycle) | ||
| while (cycle[inx1], cycle[inx2]) in included_edges: | ||
| inx1 = inx2 | ||
| inx2 = (inx1 + 1) % len(cycle) | ||
| causal_dag.remove_edge(cycle[inx1], cycle[inx2]) | ||
| cycle = next(nx.algorithms.simple_cycles(causal_dag), False) | ||
| causal_dag.add_nodes_from(nodes) | ||
|
|
||
|
|
||
| def estimated_effect(result: CausalTestResult, treatment_variable: str): | ||
| """ | ||
| Check whether the estimated causal effect is negative or positive. | ||
| :param result: The causal test result object. | ||
| :param treatment_variable: The name of the treatment variable of the causal test. | ||
| :returns: Whether the estimated causal test is positive or negative (or no effect). | ||
| """ | ||
| if result.ci_low[treatment_variable] < result.ci_high[treatment_variable] < 1: | ||
| return "negative" | ||
| if 1 < result.ci_low[treatment_variable] < result.ci_high[treatment_variable]: | ||
| return "positive" | ||
| return "no effect" | ||
|
|
||
|
|
||
| def evaluate_tests(causal_dag: CausalDAG, df: pd.DataFrame) -> list[tuple[str, str]]: | ||
| """ | ||
| Generate and evaluate causal test cases from the supplied CausalDAG and return a list of edges for which the | ||
| corresponding causal test case failed. | ||
|
|
||
| :param causal_dag: The CausalDAG to evaluate. | ||
| :param df: The data with which to evaluate the causal test cases. | ||
| """ | ||
|
|
||
| ctf = CausalTestingFramework(None) | ||
| ctf.dag = causal_dag | ||
| ctf.data = df | ||
| ctf.create_variables() | ||
| ctf.scenario = Scenario(list(ctf.variables["inputs"].values()) + list(ctf.variables["outputs"].values())) | ||
| ctf.test_cases = ctf.create_test_cases( | ||
| { | ||
| "tests": [ | ||
| relation.to_json_stub( | ||
| estimator="LogisticRegressionEstimator", estimate_type="unit_odds_ratio", alpha=0.01 | ||
| ) | ||
| for relation in generate_metamorphic_relations(causal_dag) | ||
| ] | ||
| } | ||
| ) | ||
| results = [] | ||
|
|
||
| for test_case, result in zip(ctf.test_cases, ctf.run_tests(silent=True)): | ||
| if result.effect_estimate is None: | ||
| results.append( | ||
| { | ||
| "result": "error", | ||
| "expected_effect": test_case.expected_causal_effect.__class__.__name__, | ||
| "treatment": test_case.base_test_case.treatment_variable.name, | ||
| "outcome": test_case.base_test_case.outcome_variable.name, | ||
| } | ||
| ) | ||
| else: | ||
| results.append( | ||
| { | ||
| "result": "pass" if test_case.expected_causal_effect.apply(result) else "failure", | ||
| "expected_effect": test_case.expected_causal_effect.__class__.__name__, | ||
| "treatment": test_case.base_test_case.treatment_variable.name, | ||
| "outcome": test_case.base_test_case.outcome_variable.name, | ||
| "effect": estimated_effect( | ||
| result.effect_estimate, test_case.base_test_case.treatment_variable.name | ||
| ), | ||
| } | ||
| ) | ||
|
|
||
| return pd.DataFrame(results) | ||
|
|
||
|
|
||
| # TODO: Double check whether this method is actually necessary. | ||
| def normalised_counts(test_results: pd.DataFrame) -> dict: | ||
| """ | ||
| Normalise the absolute numbers of pass/fail/error test outcomes. | ||
| MF Note 2026-06-15: I can't actually remember what this method was supposed to do. I need to double check it. | ||
| :param test_results: Dataframe containing the raw pass/fail/error outcome of each test case. | ||
| :returns: Dictionary containing the number of pass/fail/error outcomes, normalised by dividing by the total number | ||
| of each. | ||
| """ | ||
| counts = pd.concat( | ||
| [ | ||
| pd.DataFrame(np.sort(test_results[["treatment", "outcome"]], axis=1), columns=["treatment", "outcome"]), | ||
| pd.get_dummies(test_results["result"]).astype(int), | ||
| ], | ||
| axis=1, | ||
| ) | ||
| for col in ["pass", "failure", "error"]: | ||
| if col not in counts.columns: | ||
| counts[col] = 0 | ||
| counts = counts.groupby(["treatment", "outcome"]).sum().reset_index()[["pass", "failure", "error"]] | ||
| counts = counts.apply(lambda col: col / counts.sum(axis=1)) | ||
| return counts.sum(axis=0).to_dict() | ||
|
|
||
|
|
||
| def evaluate_fitness( | ||
| individual: CausalDAG, df: pd.DataFrame | ||
| ) -> tuple[tuple[float, float, float], list[tuple[str, str]]]: | ||
| """ | ||
| Evaluate the fitness of a given causal DAG by evaluating the corresponding test cases. | ||
| :param individual: The candidate individual to evaluate. | ||
| :param df: The data with which to evaluate the causal tests. | ||
| :returns: Tuple of the form (X, Y), where X is a triple containing the number of passing, inestimable, and failing | ||
| tests respectively, and Y is a list of failing edges. | ||
| """ | ||
| test_results = evaluate_tests(individual, df) | ||
| counts = normalised_counts(test_results) | ||
| problem_tests = test_results.query("result != 'pass'") | ||
| problem_edges = problem_tests[["treatment", "outcome"]].apply(tuple, axis=1).tolist() | ||
| problem_edges.extend( | ||
| problem_tests.query("expected_effect == 'NoEffect'")[["outcome", "treatment"]].apply(tuple, axis=1).tolist() | ||
| ) | ||
|
|
||
| return (counts.get("pass", 0), counts.get("error", 0), counts.get("failure", 0)), problem_edges | ||
|
|
||
|
|
||
| def evolve_dag( | ||
| df: pd.DataFrame, | ||
| random_seed: int = 0, | ||
| output_file: str = None, | ||
| include_edges_file: str = None, | ||
| exclude_edges_file: str = None, | ||
| ) -> CausalDAG: | ||
| """ | ||
| Evolve a causal DAG for a given dataset. | ||
| :param df: The data for which to fit a causal DAG. | ||
| :param random_seed: The random seed to use for genetic computation. | ||
| :param output_file: Where to save the inferred causal DAG (if supplied). | ||
| :param include_edges_file: Path to file containing edges to include. | ||
| :param exclude_edges_file: Path to file containing edges to exclude. | ||
| :returns: The inferred causal DAG. | ||
| """ | ||
| random.seed(random_seed) | ||
|
|
||
| included_edges = set(nx.nx_pydot.read_dot(include_edges_file).edges()) if include_edges_file is not None else set() | ||
| excluded_edges = set(nx.nx_pydot.read_dot(exclude_edges_file).edges()) if exclude_edges_file is not None else set() | ||
| possible_edges = sorted(list((u, v) for u, v in permutations(df.columns, 2) if (u, v) not in excluded_edges)) | ||
|
|
||
| start_time = time.time() | ||
| individual = CausalDAG() | ||
| individual.add_nodes_from(df.columns) | ||
| individual.add_edges_from(possible_edges) | ||
| remove_cycles(individual, included_edges) | ||
| fitness_values, problem_edges = evaluate_fitness(individual, df) | ||
|
|
||
| iterations = 100 | ||
| iterations_without_improvement = 0 | ||
|
|
||
| while problem_edges and iterations: | ||
| iterations -= 1 | ||
| print(iterations, fitness_values, iterations_without_improvement) | ||
|
|
||
| new_individual = individual.copy() | ||
| for origin, dest in random.sample( | ||
| problem_edges + (possible_edges if iterations_without_improvement > 10 else []), | ||
| random.randint(1, len(problem_edges)), | ||
| ): | ||
| if new_individual.has_edge(origin, dest) and (origin, dest) not in included_edges: | ||
| new_individual.remove_edge(origin, dest) | ||
| elif not new_individual.has_edge(origin, dest) and (origin, dest) not in excluded_edges: | ||
| # Want to bypass the cycle check of CausalDAG as we remove the cycles afterwards | ||
| super(CausalDAG, new_individual).add_edge(origin, dest) | ||
| remove_cycles(new_individual, included_edges) | ||
| new_fitness_values, new_problem_edges = evaluate_fitness(new_individual, df) | ||
| # assert sum(new_fitness_values) == sum(fitness_values) | ||
| print(" ", new_fitness_values) | ||
|
|
||
| if new_fitness_values > fitness_values: | ||
| fitness_values = new_fitness_values | ||
| problem_edges = new_problem_edges | ||
| individual = new_individual | ||
| iterations_without_improvement = 0 | ||
| else: | ||
| iterations_without_improvement += 1 | ||
|
|
||
| end_time = time.time() | ||
| individual.graph["fitness"] = fitness_values | ||
| individual.graph["time"] = round(end_time - start_time) | ||
| if output_file is not None: | ||
| nx.drawing.nx_agraph.write_dot( | ||
| individual, | ||
| output_file, | ||
| ) | ||
| return individual | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
Jake248Newman marked this conversation as resolved.
Outdated
|
||
| evolve_dag( | ||
| df=pd.read_csv("test_data.csv"), | ||
| random_seed=1, | ||
| output_file=f"tmp/test_dag.dot", | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.