Skip to content

Commit

Permalink
feat(core): add a step execution mode "apply_synchronous" which synch…
Browse files Browse the repository at this point in the history
…ronously applies step results before the next step is processed

- All pipeline steps that are not executed in parallel can receive apply_synchronous=True in their pipeline step definition, which ensures that the results of each test in this step are applied to the graph before the next test in this step is executed.
- This is disabled by default
- An example can be found in tests/test_synchronous_graphs.py
  • Loading branch information
LilithWittmann committed Feb 5, 2025
1 parent ca02309 commit 74a922c
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 10 deletions.
28 changes: 23 additions & 5 deletions causy/graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,14 @@ def execute_pipeline_step(
else:
# this is the only mode which supports unapplied actions to be passed to the next pipeline step (for now)
# which are sometimes needed for e.g. conflict resolution

is_synchronous = False

if hasattr(test_fn, "apply_synchronous"):
# ensure that the graph gets changes applied synchronously - so before the next element is executed
if test_fn.apply_synchronous:
is_synchronous = True

iterator = [
i
for i in [
Expand All @@ -499,11 +507,21 @@ def execute_pipeline_step(
if rn_fn.needs_unapplied_actions:
i.append(local_results)
local_results.append(unpack_run(i))
actions_taken_current, all_actions_current = self._take_action(
local_results, dry_run=not apply_to_graph
)
actions_taken.extend(actions_taken_current)
all_actions.extend(all_actions_current)

if is_synchronous:
actions_taken_current, all_actions_current = self._take_action(
local_results, dry_run=not apply_to_graph
)
actions_taken.extend(actions_taken_current)
all_actions.extend(all_actions_current)
local_results = []

if not is_synchronous:
actions_taken_current, all_actions_current = self._take_action(
local_results, dry_run=not apply_to_graph
)
actions_taken.extend(actions_taken_current)
all_actions.extend(all_actions_current)

return actions_taken, all_actions

Expand Down
21 changes: 16 additions & 5 deletions causy/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,19 +343,27 @@ def name(self) -> str:

class PipelineStepInterface(ABC, BaseModel, Generic[PipelineStepInterfaceType]):
generator: Optional[GeneratorInterface] = None
threshold: Optional[FloatParameter] = DEFAULT_THRESHOLD
chunk_size_parallel_processing: IntegerParameter = 1
parallel: BoolParameter = True
threshold: Optional[FloatParameter] = DEFAULT_THRESHOLD # threshold for the test
chunk_size_parallel_processing: IntegerParameter = (
1 # chunk size for parallel processing
)
parallel: BoolParameter = True # if True, the pipeline step will be executed in parallel (only works non synchronous)

display_name: Optional[StringParameter] = None
display_name: Optional[StringParameter] = None # display name of the pipeline step

needs_unapplied_actions: Optional[BoolParameter] = False
needs_unapplied_actions: Optional[
BoolParameter
] = False # if True, the pipeline step needs unapplied actions to be passed to it
apply_synchronous: Optional[
BoolParameter
] = False # if True, the result of the pipeline step will be applied synchronously (only works non chunked and non parallel)

def __init__(
self,
threshold: Optional[FloatParameter] = None,
generator: Optional[GeneratorInterface] = None,
chunk_size_parallel_processing: Optional[IntegerParameter] = None,
apply_synchronous: Optional[BoolParameter] = None,
parallel: Optional[BoolParameter] = None,
display_name: Optional[StringParameter] = None,
**kwargs,
Expand All @@ -370,6 +378,9 @@ def __init__(
if chunk_size_parallel_processing:
self.chunk_size_parallel_processing = chunk_size_parallel_processing

if apply_synchronous:
self.apply_synchronous = apply_synchronous

if parallel:
self.parallel = parallel

Expand Down
80 changes: 80 additions & 0 deletions tests/test_synchronous_graphs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from causy.causal_discovery.constraint.algorithms.pc import (
PC_ORIENTATION_RULES,
PC_EDGE_TYPES,
PC_GRAPH_UI_EXTENSION,
PC_DEFAULT_THRESHOLD,
)
from causy.causal_discovery.constraint.independence_tests.common import (
CorrelationCoefficientTest,
PartialCorrelationTest,
ExtendedPartialCorrelationTestMatrix,
)
from causy.causal_effect_estimation.multivariate_regression import (
ComputeDirectEffectsInDAGsMultivariateRegression,
)
from causy.common_pipeline_steps.calculation import CalculatePearsonCorrelations
from causy.graph_model import graph_model_factory
from causy.models import Algorithm
from causy.sample_generator import IIDSampleGenerator, SampleEdge, NodeReference
from causy.variables import VariableReference, FloatVariable
from tests.utils import CausyTestCase


class PCTestTestCase(CausyTestCase):
    """End-to-end test of a PC-style pipeline with synchronous step application.

    The three independence-test steps of the pipeline are declared with
    ``apply_synchronous=True``; the test runs the pipeline on generated
    sample data and compares the resulting graph with the generator's graph.
    """

    SEED = 1

    # PC-style algorithm definition: correlation calculation, three
    # independence tests (each with apply_synchronous=True), the PC
    # orientation rules and a final direct-effect estimation step.
    SYNCHRONOUS_PC = graph_model_factory(
        Algorithm(
            pipeline_steps=[
                CalculatePearsonCorrelations(
                    display_name="Calculate Pearson Correlations"
                ),
                CorrelationCoefficientTest(
                    threshold=VariableReference(name="threshold"),
                    display_name="Correlation Coefficient Test",
                    apply_synchronous=True,
                ),
                PartialCorrelationTest(
                    threshold=VariableReference(name="threshold"),
                    display_name="Partial Correlation Test",
                    apply_synchronous=True,
                ),
                ExtendedPartialCorrelationTestMatrix(
                    threshold=VariableReference(name="threshold"),
                    display_name="Extended Partial Correlation Test Matrix",
                    apply_synchronous=True,
                ),
                *PC_ORIENTATION_RULES,
                ComputeDirectEffectsInDAGsMultivariateRegression(
                    display_name="Compute Direct Effects in DAGs Multivariate Regression"
                ),
            ],
            edge_types=PC_EDGE_TYPES,
            extensions=[PC_GRAPH_UI_EXTENSION],
            name="PC",
            variables=[FloatVariable(name="threshold", value=PC_DEFAULT_THRESHOLD)],
        )
    )

    def _sample_generator(self):
        """Return an IID sample generator with edges X->Y, X->Z and X->W.

        The third positional argument of each ``SampleEdge`` (5, 8, 4) is
        presumably the edge weight -- TODO confirm against ``SampleEdge``.
        Random values are drawn via ``self.seeded_random.normalvariate(0, 1)``.
        """
        draw_normal = self.seeded_random.normalvariate
        sample_edges = [
            SampleEdge(NodeReference("X"), NodeReference(target), value)
            for target, value in (("Y", 5), ("Z", 8), ("W", 4))
        ]
        return IIDSampleGenerator(
            edges=sample_edges,
            random=lambda: draw_normal(0, 1),
        )

    def test_execute_pipeline(self):
        """Run the synchronous pipeline and compare against the generated graph."""
        generator = self._sample_generator()
        samples, expected_graph = generator.generate(100)

        discovery = self.SYNCHRONOUS_PC()
        discovery.create_graph_from_data(samples)
        # NOTE(review): create_graph_from_data is called twice in a row with the
        # same data; this looks like an accidental duplicate -- confirm whether
        # the second call is needed before removing it.
        discovery.create_graph_from_data(samples)
        discovery.create_all_possible_edges()
        discovery.execute_pipeline_steps()

        self.assertGraphStructureIsEqual(discovery.graph, expected_graph)

0 comments on commit 74a922c

Please sign in to comment.