Skip to content

Commit

Permalink
allow passing cluster pool to run-tests.py (#4210)
Browse files Browse the repository at this point in the history
Signed-off-by: Iaroslav Ciupin <[email protected]>
  • Loading branch information
iaroslav-ciupin authored Oct 12, 2023
1 parent ba10ac0 commit db3a7ff
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions boilerplate/flyte/end2end/run-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
import time
import traceback
from typing import Dict, List, Mapping, Tuple
from typing import Dict, List, Mapping, Tuple, Optional

import click
import requests
Expand Down Expand Up @@ -89,10 +89,10 @@
}


def execute_workflow(remote, version, workflow_name, inputs):
def execute_workflow(remote: FlyteRemote, version, workflow_name, inputs, cluster_pool_name: Optional[str] = None):
print(f"Fetching workflow={workflow_name} and version={version}")
wf = remote.fetch_workflow(name=workflow_name, version=version)
return remote.execute(wf, inputs=inputs, wait=False)
return remote.execute(wf, inputs=inputs, wait=False, cluster_pool=cluster_pool_name)


def executions_finished(executions_by_wfgroup: Dict[str, List[FlyteWorkflowExecution]]) -> bool:
Expand Down Expand Up @@ -125,6 +125,7 @@ def schedule_workflow_groups(
workflow_groups: List[str],
remote: FlyteRemote,
terminate_workflow_on_failure: bool,
cluster_pool_name: Optional[str] = None,
) -> Dict[str, bool]:
"""
Schedule workflows executions for all workflow groups and return True if all executions succeed, otherwise
Expand All @@ -135,7 +136,7 @@ def schedule_workflow_groups(
for wf_group in workflow_groups:
workflows = FLYTESNACKS_WORKFLOW_GROUPS.get(wf_group, [])
executions_by_wfgroup[wf_group] = [
execute_workflow(remote, tag, workflow[0], workflow[1]) for workflow in workflows
execute_workflow(remote, tag, workflow[0], workflow[1], cluster_pool_name) for workflow in workflows
]

# Wait for all executions to finish
Expand Down Expand Up @@ -179,6 +180,7 @@ def run(
priorities: List[str],
config_file_path,
terminate_workflow_on_failure: bool,
cluster_pool_name: Optional[str] = None,
) -> List[Dict[str, str]]:
remote = FlyteRemote(
Config.auto(config_file=config_file_path),
Expand Down Expand Up @@ -215,7 +217,7 @@ def run(
valid_workgroups.append(workflow_group)

results_by_wfgroup = schedule_workflow_groups(
flytesnacks_release_tag, valid_workgroups, remote, terminate_workflow_on_failure
flytesnacks_release_tag, valid_workgroups, remote, terminate_workflow_on_failure, cluster_pool_name
)

for workflow_group, succeeded in results_by_wfgroup.items():
Expand Down Expand Up @@ -261,15 +263,17 @@ def run(
@click.argument("flytesnacks_release_tag")
@click.argument("priorities")
@click.argument("config_file")
@click.argument("cluster_pool_name")
def cli(
flytesnacks_release_tag,
priorities,
config_file,
return_non_zero_on_failure,
terminate_workflow_on_failure,
cluster_pool_name,
):
print(f"return_non_zero_on_failure={return_non_zero_on_failure}")
results = run(flytesnacks_release_tag, priorities, config_file, terminate_workflow_on_failure)
results = run(flytesnacks_release_tag, priorities, config_file, terminate_workflow_on_failure, cluster_pool_name)

# Write a json object in its own line describing the result of this run to stdout
print(f"Result of run:\n{json.dumps(results)}")
Expand Down

0 comments on commit db3a7ff

Please sign in to comment.