ci: Add ability to array-ify args and run multiple jobs (#3584)
# Overview
Previously, the `run-cluster` workflow ran only a single Ray job submission. This PR extends it to run an arbitrary array of job submissions: the `entrypoint_args` input param now also accepts a JSON-encoded array of argument strings, and the new job runner splits that array apart and submits each element as its own job.
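
Concretely, the dispatch logic (condensed here from the new `job_runner.py` below; the helper name is just for illustration) treats a bracket-delimited value as a JSON list and anything else as a single job's arguments:

```python
import json


def split_entrypoint_args(entrypoint_args: str) -> list[str]:
    # A bracket-delimited value is a JSON-encoded list of argument
    # strings (one job per element); anything else is a single job.
    if entrypoint_args.startswith("[") and entrypoint_args.endswith("]"):
        return json.loads(entrypoint_args)
    return [entrypoint_args]
```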

## Example Usage

```sh
gh workflow run run-cluster.yaml \
    --ref $current_branch \
    -f working_dir="." \
    -f daft_wheel_url="https://github-actions-artifacts-bucket.s3.us-west-2.amazonaws.com/builds/54428e3738e96764af60cfdd8a0e4a41717ec9f9/getdaft-0.3.0.dev0-cp38-abi3-manylinux_2_31_x86_64.whl" \
    -f entrypoint_script="benchmarking/tpcds/ray_entrypoint.py" \
    -f entrypoint_args="[\"--tpcds-gen-folder='gendata' --question='1'\", \"--tpcds-gen-folder='gendata' --question='2'\"]"
```

The above invocation runs TPC-DS queries 1 and 2.
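
Each array element becomes its own Ray job: the runner builds one entrypoint per element (see `job_runner.py` below), so the invocation above submits roughly the following two commands. Passing a plain string (the previous behavior) still submits a single job.

```sh
DAFT_RUNNER=ray python benchmarking/tpcds/ray_entrypoint.py --question='1'
DAFT_RUNNER=ray python benchmarking/tpcds/ray_entrypoint.py --question='2'
```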
raunakab authored Dec 17, 2024
1 parent 47f5897 commit b7ea62b
Showing 4 changed files with 180 additions and 36 deletions.
114 changes: 114 additions & 0 deletions .github/ci-scripts/job_runner.py
@@ -0,0 +1,114 @@
# /// script
# requires-python = ">=3.12"
# dependencies = []
# ///

import argparse
import asyncio
import json
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional

from ray.job_submission import JobStatus, JobSubmissionClient


def parse_env_var_str(env_var_str: str) -> dict:
    # Parse a comma-separated "KEY=VALUE" string,
    # e.g. "A=1,B=2" -> {"A": "1", "B": "2"}.
    pairs = map(
        lambda s: s.strip().split("="),
        filter(lambda s: s, env_var_str.split(",")),
    )
    return dict(pairs)


async def print_logs(logs):
async for lines in logs:
print(lines, end="")


async def wait_on_job(logs, timeout_s):
await asyncio.wait_for(print_logs(logs), timeout=timeout_s)


@dataclass
class Result:
query: int
duration: timedelta
error_msg: Optional[str]


def submit_job(
working_dir: Path,
entrypoint_script: str,
entrypoint_args: str,
env_vars: str,
enable_ray_tracing: bool,
):
env_vars_dict = parse_env_var_str(env_vars)
if enable_ray_tracing:
env_vars_dict["DAFT_ENABLE_RAY_TRACING"] = "1"

client = JobSubmissionClient(address="http://localhost:8265")

if entrypoint_args.startswith("[") and entrypoint_args.endswith("]"):
# this is a json-encoded list of strings; parse accordingly
list_of_entrypoint_args: list[str] = json.loads(entrypoint_args)
else:
list_of_entrypoint_args: list[str] = [entrypoint_args]

results = []

    # Submit each argument string as its own Ray job, sequentially.
    for index, args in enumerate(list_of_entrypoint_args):
entrypoint = f"DAFT_RUNNER=ray python {entrypoint_script} {args}"
print(f"{entrypoint=}")
start = datetime.now()
job_id = client.submit_job(
entrypoint=entrypoint,
runtime_env={
"working_dir": working_dir,
"env_vars": env_vars_dict,
},
)

        # Tail the job's logs until it finishes; time out after 30 minutes.
        asyncio.run(wait_on_job(client.tail_job_logs(job_id), timeout_s=60 * 30))

status = client.get_job_status(job_id)
assert status.is_terminal(), "Job should have terminated"
end = datetime.now()
duration = end - start
error_msg = None
if status != JobStatus.SUCCEEDED:
job_info = client.get_job_info(job_id)
error_msg = job_info.message

result = Result(query=index, duration=duration, error_msg=error_msg)
results.append(result)

print(f"{results=}")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--working-dir", type=Path, required=True)
parser.add_argument("--entrypoint-script", type=str, required=True)
parser.add_argument("--entrypoint-args", type=str, required=True)
parser.add_argument("--env-vars", type=str, required=True)
parser.add_argument("--enable-ray-tracing", action="store_true")

args = parser.parse_args()

if not (args.working_dir.exists() and args.working_dir.is_dir()):
raise ValueError("The working-dir must exist and be a directory")

entrypoint: Path = args.working_dir / args.entrypoint_script
if not (entrypoint.exists() and entrypoint.is_file()):
raise ValueError("The entrypoint script must exist and be a file")

submit_job(
working_dir=args.working_dir,
entrypoint_script=args.entrypoint_script,
entrypoint_args=args.entrypoint_args,
env_vars=args.env_vars,
enable_ray_tracing=args.enable_ray_tracing,
)
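
For reference, the runner can also be exercised outside of CI against a cluster whose dashboard is forwarded to `localhost:8265`. A sketch of such an invocation from the repo root (the env-var values here are placeholders, not real settings):

```sh
python .github/ci-scripts/job_runner.py \
    --working-dir=. \
    --entrypoint-script=benchmarking/tpcds/ray_entrypoint.py \
    --entrypoint-args='["--question=1", "--question=2"]' \
    --env-vars='SOME_VAR=a,OTHER_VAR=b' \
    --enable-ray-tracing
```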
2 changes: 2 additions & 0 deletions .github/ci-scripts/templatize_ray_config.py
@@ -110,5 +110,7 @@ class Metadata(BaseModel, extra="allow"):
if metadata:
metadata = Metadata(**metadata)
content = content.replace(OTHER_INSTALL_PLACEHOLDER, " ".join(metadata.dependencies))
else:
content = content.replace(OTHER_INSTALL_PLACEHOLDER, "")

print(content)
36 changes: 14 additions & 22 deletions .github/workflows/run-cluster.yaml
@@ -34,7 +34,7 @@ on:
type: string
required: true
entrypoint_args:
-        description: Entry-point arguments
+        description: Entry-point arguments (either a simple string or a JSON list)
type: string
required: false
default: ""
@@ -79,24 +79,15 @@ jobs:
uv run \
--python 3.12 \
.github/ci-scripts/templatize_ray_config.py \
--cluster-name "ray-ci-run-${{ github.run_id }}_${{ github.run_attempt }}" \
--daft-wheel-url '${{ inputs.daft_wheel_url }}' \
--daft-version '${{ inputs.daft_version }}' \
--python-version '${{ inputs.python_version }}' \
--cluster-profile '${{ inputs.cluster_profile }}' \
--working-dir '${{ inputs.working_dir }}' \
--entrypoint-script '${{ inputs.entrypoint_script }}'
--cluster-name="ray-ci-run-${{ github.run_id }}_${{ github.run_attempt }}" \
--daft-wheel-url='${{ inputs.daft_wheel_url }}' \
--daft-version='${{ inputs.daft_version }}' \
--python-version='${{ inputs.python_version }}' \
--cluster-profile='${{ inputs.cluster_profile }}' \
--working-dir='${{ inputs.working_dir }}' \
--entrypoint-script='${{ inputs.entrypoint_script }}'
) >> .github/assets/ray.yaml
cat .github/assets/ray.yaml
-      - name: Setup ray env vars
-        run: |
-          source .venv/bin/activate
-          ray_env_var=$(python .github/ci-scripts/format_env_vars.py \
-            --env-vars '${{ inputs.env_vars }}' \
-            --enable-ray-tracing \
-          )
-          echo $ray_env_var
-          echo "ray_env_var=$ray_env_var" >> $GITHUB_ENV
- name: Download private ssh key
run: |
KEY=$(aws secretsmanager get-secret-value --secret-id ci-github-actions-ray-cluster-key-3 --query SecretString --output text)
@@ -117,11 +108,12 @@
echo 'Invalid command submitted; command cannot be empty'
exit 1
fi
-          ray job submit \
-            --working-dir ${{ inputs.working_dir }} \
-            --address http://localhost:8265 \
-            --runtime-env-json "$ray_env_var" \
-            -- python ${{ inputs.entrypoint_script }} ${{ inputs.entrypoint_args }}
+          python .github/ci-scripts/job_runner.py \
+            --working-dir='${{ inputs.working_dir }}' \
+            --entrypoint-script='${{ inputs.entrypoint_script }}' \
+            --entrypoint-args='${{ inputs.entrypoint_args }}' \
+            --env-vars='${{ inputs.env_vars }}' \
+            --enable-ray-tracing
- name: Download log files from ray cluster
run: |
source .venv/bin/activate
64 changes: 50 additions & 14 deletions benchmarking/tpcds/ray_entrypoint.py
@@ -1,17 +1,54 @@
import argparse
from pathlib import Path

-import helpers
-
import daft
+from daft.sql.sql import SQLCatalog

TABLE_NAMES = [
"call_center",
"catalog_page",
"catalog_returns",
"catalog_sales",
"customer",
"customer_address",
"customer_demographics",
"date_dim",
"household_demographics",
"income_band",
"inventory",
"item",
"promotion",
"reason",
"ship_mode",
"store",
"store_returns",
"store_sales",
"time_dim",
"warehouse",
"web_page",
"web_returns",
"web_sales",
"web_site",
]


def register_catalog(scale_factor: int) -> SQLCatalog:
return SQLCatalog(
tables={
table: daft.read_parquet(
f"s3://eventual-dev-benchmarking-fixtures/uncompressed/tpcds-dbgen/{scale_factor}/{table}.parquet"
)
for table in TABLE_NAMES
}
)


def run(
-    parquet_folder: Path,
    question: int,
    dry_run: bool,
+    scale_factor: int,
):
-    catalog = helpers.generate_catalog(parquet_folder)
+    catalog = register_catalog(scale_factor)
query_file = Path(__file__).parent / "queries" / f"{question:02}.sql"
with open(query_file) as f:
query = f.read()
@@ -23,27 +60,26 @@ def run(

if __name__ == "__main__":
parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--tpcds-gen-folder",
-        required=True,
-        type=Path,
-        help="Path to the TPC-DS data generation folder",
-    )
    parser.add_argument(
        "--question",
-        required=True,
        type=int,
        help="The TPC-DS question index to run",
+        required=True,
    )
parser.add_argument(
"--dry-run",
action="store_true",
help="Whether or not to run the query in dry-run mode; if true, only the plan will be printed out",
)
parser.add_argument(
"--scale-factor",
type=int,
help="Which scale factor to run this data at",
required=False,
default=2,
)
args = parser.parse_args()

-    tpcds_gen_folder: Path = args.tpcds_gen_folder
-    assert tpcds_gen_folder.exists()
    assert args.question in range(1, 100)

-    run(args.tpcds_gen_folder, args.question, args.dry_run)
+    run(question=args.question, dry_run=args.dry_run, scale_factor=args.scale_factor)
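
Since the tables are now read straight from the S3 fixtures bucket, the entrypoint no longer needs locally generated data. A hypothetical standalone run (assuming S3 access and a Ray-capable environment):

```sh
DAFT_RUNNER=ray python benchmarking/tpcds/ray_entrypoint.py --question=1 --scale-factor=2 --dry-run
```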
