Skip to content

Commit

Permalink
Merge pull request #74 from causy-dev/pc-classic
Browse files Browse the repository at this point in the history
feat(causal_discovery_algorithms): add PC classic without runtime opt…
  • Loading branch information
this-is-sofia authored Jan 29, 2025
2 parents 743570d + 229de5d commit 0dc0f33
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 65 deletions.
32 changes: 32 additions & 0 deletions causy/causal_discovery/constraint/algorithms/pc.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,38 @@
)
)

PCClassic = graph_model_factory(
Algorithm(
pipeline_steps=[
CalculatePearsonCorrelations(
display_name="Calculate Pearson Correlations"
),
CorrelationCoefficientTest(
threshold=VariableReference(name="threshold"),
display_name="Correlation Coefficient Test",
),
ExtendedPartialCorrelationTestMatrix(
threshold=VariableReference(name="threshold"),
display_name="Extended Partial Correlation Test Matrix",
generator=PairsWithNeighboursGenerator(
comparison_settings=ComparisonSettings(
min=3, max=AS_MANY_AS_FIELDS
),
shuffle_combinations=False,
),
),
*PC_ORIENTATION_RULES,
ComputeDirectEffectsInDAGsMultivariateRegression(
display_name="Compute Direct Effects in DAGs Multivariate Regression"
),
],
edge_types=PC_EDGE_TYPES,
extensions=[PC_GRAPH_UI_EXTENSION],
name="PC",
variables=[FloatVariable(name="threshold", value=PC_DEFAULT_THRESHOLD)],
)
)

PCStable = graph_model_factory(
Algorithm(
pipeline_steps=[
Expand Down
68 changes: 3 additions & 65 deletions tests/test_pc_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
PC,
PC_ORIENTATION_RULES,
PC_GRAPH_UI_EXTENSION,
PC_DEFAULT_THRESHOLD,
PC_DEFAULT_THRESHOLD, PCClassic,
)
from causy.causal_effect_estimation.multivariate_regression import (
ComputeDirectEffectsInDAGsMultivariateRegression,
Expand Down Expand Up @@ -351,37 +351,6 @@ def test_tracking_triples_four_nodes(self):
self.assertEqual(len(triples), 6 + 12 + 12)

def test_track_triples_three_nodes_custom_pc(self):
algo = graph_model_factory(
Algorithm(
pipeline_steps=[
CalculatePearsonCorrelations(
display_name="Calculate Pearson Correlations"
),
CorrelationCoefficientTest(
threshold=VariableReference(name="threshold"),
display_name="Correlation Coefficient Test",
),
ExtendedPartialCorrelationTestMatrix(
threshold=VariableReference(name="threshold"),
display_name="Extended Partial Correlation Test Matrix",
generator=PairsWithNeighboursGenerator(
comparison_settings=ComparisonSettings(
min=3, max=AS_MANY_AS_FIELDS
),
shuffle_combinations=False,
),
),
*PC_ORIENTATION_RULES,
ComputeDirectEffectsInDAGsMultivariateRegression(
display_name="Compute Direct Effects in DAGs Multivariate Regression"
),
],
edge_types=PC_EDGE_TYPES,
extensions=[PC_GRAPH_UI_EXTENSION],
name="PC",
variables=[FloatVariable(name="threshold", value=PC_DEFAULT_THRESHOLD)],
)
)
rdnv = self.seeded_random.normalvariate
sample_generator = IIDSampleGenerator(
edges=[
Expand All @@ -391,7 +360,7 @@ def test_track_triples_three_nodes_custom_pc(self):
random=lambda: rdnv(0, 1),
)
test_data, graph = sample_generator.generate(10000)
tst = algo()
tst = PCClassic()
tst.create_graph_from_data(test_data)
tst.create_all_possible_edges()
pc_results = tst.execute_pipeline_steps()
Expand All @@ -407,37 +376,6 @@ def test_track_triples_three_nodes_custom_pc(self):
self.assertIn(len(triples), [6, 7, 8])

def test_track_triples_two_nodes_custom_pc_unconditionally_independent(self):
algo = graph_model_factory(
Algorithm(
pipeline_steps=[
CalculatePearsonCorrelations(
display_name="Calculate Pearson Correlations"
),
CorrelationCoefficientTest(
threshold=VariableReference(name="threshold"),
display_name="Correlation Coefficient Test",
),
ExtendedPartialCorrelationTestMatrix(
threshold=VariableReference(name="threshold"),
display_name="Extended Partial Correlation Test Matrix",
generator=PairsWithNeighboursGenerator(
comparison_settings=ComparisonSettings(
min=3, max=AS_MANY_AS_FIELDS
),
shuffle_combinations=False,
),
),
*PC_ORIENTATION_RULES,
ComputeDirectEffectsInDAGsMultivariateRegression(
display_name="Compute Direct Effects in DAGs Multivariate Regression"
),
],
edge_types=PC_EDGE_TYPES,
extensions=[PC_GRAPH_UI_EXTENSION],
name="PC",
variables=[FloatVariable(name="threshold", value=PC_DEFAULT_THRESHOLD)],
)
)
rdnv = self.seeded_random.normalvariate
sample_generator = IIDSampleGenerator(
edges=[
Expand All @@ -447,7 +385,7 @@ def test_track_triples_two_nodes_custom_pc_unconditionally_independent(self):
random=lambda: rdnv(0, 1),
)
test_data, graph = sample_generator.generate(10000)
tst = algo()
tst = PCClassic()
tst.create_graph_from_data(test_data)
tst.create_all_possible_edges()
pc_results = tst.execute_pipeline_steps()
Expand Down

0 comments on commit 0dc0f33

Please sign in to comment.