Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Generalized Adjustment Criterion #1292

Merged
merged 17 commits into from
Jan 21, 2025
Prev Previous commit
Next Next commit
adding default case
Signed-off-by: Nicholas Parente <[email protected]>
nparent1 committed Dec 30, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 5f3bc5bb7662dddc11997f53983cb41054d04539
2 changes: 1 addition & 1 deletion dowhy/causal_identifier/adjustment_set.py
Original file line number Diff line number Diff line change
@@ -23,5 +23,5 @@ def get_variables(self):
return self.variables

def get_num_paths_blocked_by_observed_nodes(self):
"""Return the number of paths blocked by the observed nodes (optional)"""
"""Return the number of paths blocked by observed nodes (optional)"""
return self.num_paths_blocked_by_observed_nodes
22 changes: 21 additions & 1 deletion dowhy/causal_identifier/auto_identifier.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
import logging
from enum import Enum
from typing import Dict, List, Optional, Union
import copy

import networkx as nx
import sympy as sp
@@ -21,6 +22,8 @@
get_descendants,
get_instruments,
has_directed_path,
get_proper_causal_path_nodes,
get_proper_backdoor_graph
)
from dowhy.utils.api import parse_state

@@ -884,7 +887,24 @@ def identify_complete_adjustment_set(
observed_nodes: List[str],
covariate_adjustment: CovariateAdjustment = CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT
) -> List[AdjustmentSet]:
# TODO: Implement this. Must return a list of AdjustmentSet objects.

graph_pbd = get_proper_backdoor_graph(graph, action_nodes, outcome_nodes)
pcp_nodes = get_proper_causal_path_nodes(graph, action_nodes, outcome_nodes)

if covariate_adjustment == CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT:
# In default case, we don't find all exhaustive adjustment sets
adjustment_set = nx.algorithms.find_minimal_d_separator(
graph_pbd,
action_nodes,
outcome_nodes,
# Require the adjustment set to consist only of observed nodes
restricted=((set(graph.nodes) - set(pcp_nodes)) & set(observed_nodes))
)
if adjustment_set is None:
logger.info("No adjustment sets found.")
return []
return [AdjustmentSet(AdjustmentSet.GENERAL, adjustment_set)]

return [AdjustmentSet(AdjustmentSet.GENERAL, [])]


41 changes: 41 additions & 0 deletions dowhy/graph.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
import re
from abc import abstractmethod
from typing import Any, List, Protocol
import copy

import networkx as nx
from networkx.algorithms.dag import has_cycle
@@ -187,13 +188,53 @@ def is_blocked(graph: nx.DiGraph, path, conditioned_nodes):
return False


def get_ancestors(graph: nx.DiGraph, nodes):
ancestors = set()
for node_name in nodes:
ancestors = ancestors.union(set(nx.ancestors(graph, node_name)))
return ancestors


def get_descendants(graph: nx.DiGraph, nodes):
descendants = set()
for node_name in nodes:
descendants = descendants.union(set(nx.descendants(graph, node_name)))
return descendants


def get_proper_causal_path_nodes(graph: nx.DiGraph, action_nodes, outcome_nodes):
# Process is described in Van Der Zander et al. "Constructing Separators and
# Adjustment Sets in Ancestral Graphs", Section 4.1

# 1) Create modified graphs removing inbound and outbound arrows from the action nodes, respectively.
graph_post_interv = copy.deepcopy(graph) # remove incoming arrows to our action nodes
edges_to_remove = [(u, v) for u, v in graph_post_interv.in_edges(action_nodes)]
graph_post_interv.remove_edges_from(edges_to_remove)
graph_with_action_nodes_as_sinks = copy.deepcopy(graph) # remove outbound arrows from our action nodes
edges_to_remove = [(u, v) for u, v in graph_with_action_nodes_as_sinks.out_edges(action_nodes)]
graph_with_action_nodes_as_sinks.remove_edges_from(edges_to_remove)

# 2) Use the modified graphs to identify the nodes which lie on proper causal paths from the
# action nodes to the outcome nodes.
de_x = get_descendants(graph_post_interv, action_nodes)
an_y = get_ancestors(graph_with_action_nodes_as_sinks, outcome_nodes)
return (set(de_x) - set(action_nodes)) & an_y


def get_proper_backdoor_graph(graph: nx.DiGraph, action_nodes, outcome_nodes):
# Process is described in Van Der Zander et al. "Constructing Separators and
# Adjustment Sets in Ancestral Graphs", Section 4.1

# First we can just call get_proper_causal_path_nodes, then
# we remove edges from the action_nodes to the proper causal path nodes
graph_pbd = copy.deepcopy(graph)
graph_pbd.remove_edges_from(
[(u, v) for u in action_nodes for v in get_proper_causal_path_nodes(graph, action_nodes, outcome_nodes)]
)
return graph_pbd



def check_dseparation(graph: nx.DiGraph, nodes1, nodes2, nodes3, new_graph=None, dseparation_algo="default"):
if dseparation_algo == "default":
if new_graph is None: