[FEA] Support Generalized Adjustment Criterion for Estimation plus Add Example Notebook #1297

Open · wants to merge 27 commits into base: main
Changes from 1 commit
fixes to true ate estimation
Signed-off-by: Nicholas Parente <[email protected]>
nparent1 committed Jan 20, 2025

Verified: this commit was created on GitHub.com and signed with GitHub's verified signature (the key has expired).
commit cd3ba3879272afa1633bd7c7e48cd32b9863280c
58 changes: 43 additions & 15 deletions dowhy/datasets.py
@@ -626,6 +626,20 @@ def check_all_node_types_are_specified(graph, variable_type_dict):
raise ValueError(f"Graph node '{node}' is not present in variable_type_dict keys")


def ate_from_direct_graph_weights(graph, treatments, outcome):
    # Only valid when all non-treatment variables in the graph are continuous
ate = 0
for treatment in treatments:
for path in nx.all_simple_paths(graph, source=treatment, target=outcome):
path_multiplier = 1
if set(treatments).intersection(path[1:-1]):
continue
for a, b in zip(path, path[1:]):
path_multiplier *= graph[a][b]["weight"]
ate += path_multiplier
return ate


def linear_dataset_from_graph(
graph,
treatments,
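
For context, a minimal sketch of what the new helper computes, assuming it is importable from dowhy.datasets once this change lands (the toy graph and numbers below are illustrative, not taken from the PR): in a linear SEM whose edge weights are the structural coefficients, the total effect of the treatments on the outcome is the sum, over causal paths, of the product of edge weights along each path.

import networkx as nx
from dowhy.datasets import ate_from_direct_graph_weights  # helper added in this commit

# Toy linear SEM: T -> M -> Y (weights 2 and 3) plus a direct edge T -> Y (weight 1).
g = nx.DiGraph()
g.add_edge("T", "M", weight=2.0)
g.add_edge("M", "Y", weight=3.0)
g.add_edge("T", "Y", weight=1.0)

# Sum over causal paths of the product of edge weights: 2 * 3 + 1 = 7
print(ate_from_direct_graph_weights(g, ["T"], "Y"))  # 7.0
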
@@ -673,6 +687,7 @@ def linear_dataset_from_graph(
for node in all_nodes:
changed[node] = False
df = pd.DataFrame()
# Creating datasets to store the intervention data, for estimating the true ATE
currset = list()
counter = 0

@@ -696,6 +711,9 @@ def linear_dataset_from_graph(
currset.extend(successors)
changed[node] = True

df_treated = df.copy(deep=True)
df_untreated = df.copy(deep=True)

# "currset" variable currently has all the successors of the nodes which had no incoming edges
while len(currset) > 0:
cs = list() # Variable to store immediate children of nodes present in "currset"
@@ -710,36 +728,46 @@ def linear_dataset_from_graph(
successors.sort()
cs.extend(successors) # Storing immediate children for next level data generation

X = df[predecessors].to_numpy() # Using parent nodes data
treatment_indices = [i for i, col in enumerate(predecessors) if col in treatments]
X_observed = df[predecessors].to_numpy() # Using parent nodes data
X_treated = df_treated[predecessors].to_numpy()
X_treated[:, treatment_indices] = 1
X_untreated = df_untreated[predecessors].to_numpy()
X_untreated[:, treatment_indices] = 0
c = np.array([graph[u][node]["weight"] for u in predecessors])
t = np.random.normal(0, 1, num_samples) + X @ c # Using Linear Regression to generate data
t_observed = np.random.normal(0, 1, num_samples) + X_observed @ c # Using Linear Regression to generate data
t_treated = np.random.normal(0, 1, num_samples) + X_treated @ c # Using Linear Regression to generate data
t_untreated = np.random.normal(0, 1, num_samples) + X_untreated @ c # Using Linear Regression to generate data

changed[node] = True
dtype = variable_type_dict[node]
counter += 1
if dtype == DISCRETE:
df[node] = convert_continuous_to_discrete(t)
df[node] = convert_continuous_to_discrete(t_observed)
df_treated[node] = convert_continuous_to_discrete(t_treated)
df_untreated[node] = convert_continuous_to_discrete(t_untreated)
discrete_cols.append(node)
elif dtype == CONTINUOUS:
df[node] = t
df[node] = t_observed
df_treated[node] = t_treated
df_untreated[node] = t_untreated
continuous_cols.append(node)
else:
nums = np.random.normal(0, 1, num_samples)
df[node] = np.vectorize(convert_to_binary)(nums)
# nums = np.random.normal(0, 1, num_samples)
df[node] = np.vectorize(convert_to_binary)(t_observed)
df_treated[node] = np.vectorize(convert_to_binary)(t_treated)
df_untreated[node] = np.vectorize(convert_to_binary)(t_untreated)
discrete_cols.append(node)
binary_cols.append(node)
currset = cs

# Compute ATE:
ate = 0
for treatment in treatments:
for path in nx.all_simple_paths(graph, source=treatment, target=outcome):
path_multiplier = 1
if set(treatments).intersection(path[1:-1]):
continue
for a, b in zip(path, path[1:]):
path_multiplier *= graph[a][b]["weight"]
ate += path_multiplier
# If all non-treatment variables are continuous, the true ATE can be computed
# directly from the graph edge weights
if all(variable_type_dict[x] == CONTINUOUS for x in variable_type_dict.keys() if x not in treatments):
ate = ate_from_direct_graph_weights(graph, treatments, outcome)
else:
ate = np.mean(df_treated[outcome]) - np.mean(df_untreated[outcome])

gml_str = "\n".join(nx.generate_gml(graph))
ret_dict = {
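
When some non-treatment variables are discrete or binary, the code above instead reports the empirical ATE as the difference of outcome means between the do(treatment=1) and do(treatment=0) copies of the data. A rough illustration of that computation, with made-up values and a hypothetical outcome column "Y":

import numpy as np
import pandas as pd

# Hypothetical interventional datasets: the outcome generated once with every
# treatment column forced to 1 (df_treated) and once forced to 0 (df_untreated).
df_treated = pd.DataFrame({"Y": [1, 1, 0, 1, 1]})
df_untreated = pd.DataFrame({"Y": [0, 1, 0, 0, 1]})

# Empirical true ATE: difference of outcome means across the two interventions.
ate = np.mean(df_treated["Y"]) - np.mean(df_untreated["Y"])  # 0.8 - 0.4 = 0.4
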
6 changes: 3 additions & 3 deletions tests/causal_estimators/base.py
@@ -162,7 +162,7 @@ def average_treatment_effect_testsuite(
cfg["method_params"] = method_params
self.average_treatment_effect_test(**cfg)

def custom_data_average_treatment_effect_test(self, data, method_params=None):
def custom_data_average_treatment_effect_test(self, data, method_params={}):
target_estimand = identify_effect_auto(
build_graph_from_str(data["gml_graph"]),
observed_nodes=list(data["df"].columns),
@@ -175,13 +175,13 @@ def custom_data_average_treatment_effect_test(self, data, method_params=None):
estimator_ate.fit(data["df"])
true_ate = data["ate"]
ate_estimate = estimator_ate.estimate_effect(data["df"])
error = ate_estimate.value - true_ate
error = abs(ate_estimate.value - true_ate)
Review comment from the PR author (nparent1):
Minor bug fix here: take the absolute value, both here and a few lines down.

Previously the test utility returned true if (error < true_ate * self._error_tolerance), which only makes sense if both error and true_ate are taken in absolute value; otherwise a negative error would always be accepted, no matter how far the estimate is from the true value.
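
A quick numeric illustration of the old failure mode (the values are made up):

true_ate = 2.0
ate_estimate = -3.0                                   # badly wrong estimate
error_tolerance = 0.05

error = ate_estimate - true_ate                       # -5.0
print(error < true_ate * error_tolerance)            # True: the old check would accept this estimate
print(abs(error) < abs(true_ate) * error_tolerance)  # False: the fixed check rejects it
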

print(
"Error in ATE estimate = {0} with tolerance {1}%. Estimated={2},True={3}".format(
error, self._error_tolerance * 100, ate_estimate.value, true_ate
)
)
res = True if (error < true_ate * self._error_tolerance) else False
res = True if (error < abs(true_ate) * self._error_tolerance) else False
assert res


@@ -104,13 +104,13 @@ def test_average_treatment_effect(
},
)

def test_general_adjustment_specific_graphs(self, example_graph: TestGraphObject):
def test_general_adjustment_estimation_on_example_graphs(self, example_graph: TestGraphObject):
data = dowhy.datasets.linear_dataset_from_graph(
example_graph.graph,
example_graph.action_nodes,
example_graph.outcome_node,
treatments_are_binary=True,
outcome_is_binary=False,
outcome_is_binary=True,
num_samples=50000,
)
data["df"] = data["df"][example_graph.observed_nodes]