Skip to content

Commit

Permalink
refactor: automatically patching of models
Browse files Browse the repository at this point in the history
  • Loading branch information
arjendev committed Nov 9, 2023
1 parent 314a1f4 commit bbc159a
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,8 @@
class Activity:
status: DependencyCondition

@staticmethod
def patch_generated_models(models):
models.Activity._evaluate_expressions = Activity.evaluate_expressions
models.Activity.evaluate = Activity.evaluate
models.Activity.are_dependency_condition_met = Activity.are_dependency_condition_met
models.Activity.get_scoped_activity_result_by_name = Activity.get_scoped_activity_result_by_name
models.Activity.status = None

def evaluate(self, state: PipelineRunState) -> Activity:
self._evaluate_expressions(self)
self.evaluate_expressions(self)
self.status = DependencyCondition.Succeeded
return self

Expand All @@ -34,7 +26,7 @@ def evaluate_expressions(self, obj: Any, visited: List[Any] = None):
if data_factory_element := isinstance(attribute, DataFactoryElement) and attribute:
data_factory_element.evaluate()
else:
self._evaluate_expressions(attribute, visited)
self.evaluate_expressions(attribute, visited)

def get_scoped_activity_result_by_name(self, name: str, state: PipelineRunState):
return next((activity_result for activity_result in state.scoped_pipeline_activity_results if activity_result.name == name), None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,5 @@

class ControlActivity:

@staticmethod
def patch_generated_models(models):
models.ControlActivity.evaluate_control_activity_iterations = ControlActivity.evaluate_control_activity_iterations

def evaluate_control_activity_iterations(self, state: PipelineRunState, evaluate_activities: Callable[[PipelineRunState], Generator[Activity, None, None]]):
return []
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@

class ExecutePipelineActivity:

@staticmethod
def patch_generated_models(models):
models.ExecutePipelineActivity.get_child_run_parameters = ExecutePipelineActivity.get_child_run_parameters

def get_child_run_parameters(self, state: PipelineRunState) -> List[RunParameter]:
child_parameters = []
for parameter in state.parameters:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@

class ForEachActivity:

@staticmethod
def patch_generated_models(models):
models.ForEachActivity.evaluate = ForEachActivity.evaluate
models.ForEachActivity.evaluate_control_activity_iterations = ForEachActivity.evaluate_control_activity_iterations

def evaluate(self: ForEachActivity, state: PipelineRunState):
self.items.evaluate(state)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@ class Expression:

evaluated_items: List[str] = []

@staticmethod
def patch_generated_models(models):
models.Expression.evaluate = Expression.evaluate

def evaluate(self: Expression, state: PipelineRunState):
self.evaluated_items = [
"item1",
Expand Down
16 changes: 11 additions & 5 deletions src/python/data_factory_testing_framework/models/patch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,14 @@

# Patch models with our custom classes
def patch_models():
Activity.patch_generated_models(_models)
ExecutePipelineActivity.patch_generated_models(_models)
ControlActivity.patch_generated_models(_models)
ForEachActivity.patch_generated_models(_models)
Expression.patch_generated_models(_models)
patch_model(_models.Activity, Activity)
patch_model(_models.ExecutePipelineActivity, ExecutePipelineActivity)
patch_model(_models.ControlActivity, ControlActivity)
patch_model(_models.ForEachActivity, ForEachActivity)
patch_model(_models.Expression, Expression)


def patch_model(main_class, partial_class):
partial_class_method_list = [attribute for attribute in dir(partial_class) if callable(getattr(partial_class, attribute)) and attribute.startswith('__') is False]
for method_name in partial_class_method_list:
setattr(main_class, method_name, getattr(partial_class, method_name))

0 comments on commit bbc159a

Please sign in to comment.