Skip to content

Commit

Permalink
[FEAT] shuffle testing (#3492)
Browse files Browse the repository at this point in the history
# Overview
Add @colin-ho's shuffle testing PR to this repo; also add the ability to
invoke it in CI runs.

---------

Co-authored-by: Desmond Cheong <[email protected]>
  • Loading branch information
Raunak Bhagat and desmondcheongzx authored Dec 5, 2024
1 parent 78c738f commit c2abed8
Show file tree
Hide file tree
Showing 6 changed files with 412 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ setup_commands:
- uv v
- echo "source $HOME/.venv/bin/activate" >> $HOME/.bashrc
- source .venv/bin/activate
- uv pip install pip ray[default] py-spy \{{DAFT_INSTALL}}
- uv pip install pip ray[default] py-spy \{{DAFT_INSTALL}} \{{OTHER_INSTALLS}}
38 changes: 38 additions & 0 deletions .github/ci-scripts/format_env_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Given a comma-separated string of environment variables, parse them into a dictionary.
Example:
env_str = "a=1,b=2"
result = parse_env_var_str(env_str)
# returns {"a": "1", "b": "2"}  (values are kept as strings)
"""

import argparse
import json


def parse_env_var_str(env_var_str: str) -> dict:
    """Parse a comma-separated ``KEY=VALUE`` string into a dictionary.

    Entries are separated by commas. Surrounding whitespace on each entry is
    stripped, and empty or whitespace-only entries (e.g. from a trailing
    comma or an empty input string) are skipped. Each entry is split on its
    FIRST ``=`` only, so values may themselves contain ``=``.

    Args:
        env_var_str: The raw string, e.g. ``"a=1,b=2"``.

    Returns:
        A dict mapping each key to its value. Values are always strings.

    Raises:
        ValueError: If a non-empty entry does not contain ``=``.
    """
    env_vars = {}
    for raw_entry in env_var_str.split(","):
        entry = raw_entry.strip()
        if not entry:
            # Tolerate "", "a=1," and "a=1, ,b=2".
            continue
        # partition splits on the first "=" only, keeping values such as
        # "http://host?x=1" intact.
        key, separator, value = entry.partition("=")
        if not separator:
            raise ValueError(f"Invalid environment variable entry (expected KEY=VALUE): {entry!r}")
        env_vars[key] = value
    return env_vars


def main():
    """Print a Ray runtime-env JSON document built from CLI arguments.

    Reads ``--env-vars`` (required, comma-separated ``KEY=VALUE`` pairs) and
    the optional ``--enable-ray-tracing`` flag, then writes
    ``{"env_vars": {...}}`` as JSON to stdout. When the flag is given,
    ``DAFT_ENABLE_RAY_TRACING`` is set to ``"1"`` in the emitted mapping.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--enable-ray-tracing", action="store_true")
    cli.add_argument("--env-vars", required=True)
    options = cli.parse_args()

    runtime_env = {"env_vars": parse_env_var_str(options.env_vars)}
    if options.enable_ray_tracing:
        # Applied after parsing so the flag wins over a user-supplied value.
        runtime_env["env_vars"]["DAFT_ENABLE_RAY_TRACING"] = "1"

    print(json.dumps(runtime_env))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
30 changes: 30 additions & 0 deletions .github/ci-scripts/read_inline_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# /// script
# requires-python = ">=3.12"
# dependencies = []
# ///

"""
The `read` function below is sourced from:
https://packaging.python.org/en/latest/specifications/inline-script-metadata/#inline-script-metadata
"""

import re

import tomllib

REGEX = r"(?m)^# /// (?P<type>[a-zA-Z0-9-]+)$\s(?P<content>(^#(| .*)$\s)+)^# ///$"


def read(script: str) -> dict | None:
name = "script"
matches = list(filter(lambda m: m.group("type") == name, re.finditer(REGEX, script)))
if len(matches) > 1:
raise ValueError(f"Multiple {name} blocks found")
elif len(matches) == 1:
content = "".join(
line[2:] if line.startswith("# ") else line[1:]
for line in matches[0].group("content").splitlines(keepends=True)
)
return tomllib.loads(content)
else:
return None
68 changes: 42 additions & 26 deletions .github/ci-scripts/templatize_ray_config.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,29 @@
# /// script
# requires-python = ">=3.12"
# dependencies = ['pydantic']
# ///

import sys
from argparse import ArgumentParser
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import read_inline_metadata
from pydantic import BaseModel, Field

# Placeholder tokens that this script substitutes in the template read from
# stdin (via `content.replace`). Each literal is a backslash followed by
# `{{NAME}}`, i.e. the template spells them as `\{{NAME}}`.
CLUSTER_NAME_PLACEHOLDER = "\\{{CLUSTER_NAME}}"
DAFT_INSTALL_PLACEHOLDER = "\\{{DAFT_INSTALL}}"
# Replaced with the space-joined dependency list from the entry-point
# script's inline metadata.
OTHER_INSTALL_PLACEHOLDER = "\\{{OTHER_INSTALLS}}"
PYTHON_VERSION_PLACEHOLDER = "\\{{PYTHON_VERSION}}"
# Per-profile placeholders, filled from the selected `Profile`'s fields.
CLUSTER_PROFILE__NODE_COUNT = "\\{{CLUSTER_PROFILE/node_count}}"
CLUSTER_PROFILE__INSTANCE_TYPE = "\\{{CLUSTER_PROFILE/instance_type}}"
CLUSTER_PROFILE__IMAGE_ID = "\\{{CLUSTER_PROFILE/image_id}}"
CLUSTER_PROFILE__SSH_USER = "\\{{CLUSTER_PROFILE/ssh_user}}"
CLUSTER_PROFILE__VOLUME_MOUNT = "\\{{CLUSTER_PROFILE/volume_mount}}"

# Shell command substituted for the volume-mount placeholder when the profile
# has no volume mount, so that step still has a command to run.
NOOP_STEP = "echo 'noop step; skipping'"


@dataclass
class Profile:
Expand All @@ -22,6 +34,11 @@ class Profile:
volume_mount: Optional[str] = None


class Metadata(BaseModel, extra="allow"):
    """Inline script metadata of the entry-point script.

    Constructed from the dict parsed out of the script's `# /// script`
    block. `extra="allow"` is passed so that keys other than the two declared
    here (e.g. `requires-python`) do not fail validation -- NOTE(review):
    confirm against the pydantic version in use.
    """

    # Extra packages to install on the cluster; joined with spaces and
    # substituted for the OTHER_INSTALLS placeholder.
    dependencies: list[str] = Field(default_factory=list)
    # Environment variables declared by the script; defaults to an empty
    # dict. Not read anywhere in the visible part of this file.
    env: dict[str, str] = Field(default_factory=dict)


profiles: dict[str, Optional[Profile]] = {
"debug_xs-x86": Profile(
instance_type="t3.large",
Expand Down Expand Up @@ -50,15 +67,16 @@ class Profile:
content = sys.stdin.read()

parser = ArgumentParser()
parser.add_argument("--cluster-name")
parser.add_argument("--cluster-name", required=True)
parser.add_argument("--daft-wheel-url")
parser.add_argument("--daft-version")
parser.add_argument("--python-version")
parser.add_argument("--cluster-profile")
parser.add_argument("--python-version", required=True)
parser.add_argument("--cluster-profile", required=True, choices=["debug_xs-x86", "medium-x86"])
parser.add_argument("--working-dir", required=True)
parser.add_argument("--entrypoint-script", required=True)
args = parser.parse_args()

if args.cluster_name:
content = content.replace(CLUSTER_NAME_PLACEHOLDER, args.cluster_name)
content = content.replace(CLUSTER_NAME_PLACEHOLDER, args.cluster_name)

if args.daft_wheel_url and args.daft_version:
raise ValueError(
Expand All @@ -72,26 +90,24 @@ class Profile:
daft_install = "getdaft"
content = content.replace(DAFT_INSTALL_PLACEHOLDER, daft_install)

if args.python_version:
content = content.replace(PYTHON_VERSION_PLACEHOLDER, args.python_version)

if cluster_profile := args.cluster_profile:
cluster_profile: str
if cluster_profile not in profiles:
raise Exception(f'Cluster profile "{cluster_profile}" not found')

profile = profiles[cluster_profile]
if profile is None:
raise Exception(f'Cluster profile "{cluster_profile}" not yet implemented')

assert profile is not None
content = content.replace(CLUSTER_PROFILE__NODE_COUNT, str(profile.node_count))
content = content.replace(CLUSTER_PROFILE__INSTANCE_TYPE, profile.instance_type)
content = content.replace(CLUSTER_PROFILE__IMAGE_ID, profile.image_id)
content = content.replace(CLUSTER_PROFILE__SSH_USER, profile.ssh_user)
if profile.volume_mount:
content = content.replace(CLUSTER_PROFILE__VOLUME_MOUNT, profile.volume_mount)
else:
content = content.replace(CLUSTER_PROFILE__VOLUME_MOUNT, "echo 'Nothing to mount; skipping'")
content = content.replace(PYTHON_VERSION_PLACEHOLDER, args.python_version)

profile = profiles[args.cluster_profile]
content = content.replace(CLUSTER_PROFILE__NODE_COUNT, str(profile.node_count))
content = content.replace(CLUSTER_PROFILE__INSTANCE_TYPE, profile.instance_type)
content = content.replace(CLUSTER_PROFILE__IMAGE_ID, profile.image_id)
content = content.replace(CLUSTER_PROFILE__SSH_USER, profile.ssh_user)
content = content.replace(
CLUSTER_PROFILE__VOLUME_MOUNT, profile.volume_mount if profile.volume_mount else NOOP_STEP
)

# Locate the entry-point script inside the working directory and read its
# inline metadata so its declared dependencies can be installed on the cluster.
# Explicit exceptions are raised instead of `assert`, which is stripped when
# Python runs with -O.
working_dir = Path(args.working_dir)
if not working_dir.is_dir():
    raise NotADirectoryError(f'Working directory "{working_dir}" does not exist or is not a directory')
entrypoint_script_fullpath: Path = working_dir / args.entrypoint_script
if not entrypoint_script_fullpath.is_file():
    raise FileNotFoundError(f'Entrypoint script "{entrypoint_script_fullpath}" does not exist or is not a file')
with open(entrypoint_script_fullpath) as f:
    # `read` returns None when the script has no `# /// script` block; treat
    # that as "no extra dependencies" rather than crashing on `**None`.
    inline_metadata = read_inline_metadata.read(f.read()) or {}
metadata = Metadata(**inline_metadata)

content = content.replace(OTHER_INSTALL_PLACEHOLDER, " ".join(metadata.dependencies))

print(content)
63 changes: 44 additions & 19 deletions .github/workflows/run-cluster.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,45 @@ on:
workflow_dispatch:
inputs:
daft_wheel_url:
description: Daft python-wheel URL
type: string
description: A public https url pointing directly to a daft python-wheel to install
required: false
daft_version:
description: Daft version (errors if both this and "Daft python-wheel URL" are provided)
type: string
description: A released version of daft on PyPi to install (errors if both this and `daft_wheel_url` are provided)
required: false
python_version:
description: Python version
type: string
description: The version of python to use
required: false
default: "3.9"
cluster_profile:
description: Cluster profile
type: choice
options:
- medium-x86
- debug_xs-x86
description: The profile to use for the cluster
required: false
default: medium-x86
command:
type: string
description: The command to run on the cluster
required: true
working_dir:
description: Working directory
type: string
description: The working directory to submit to the cluster
required: false
default: .github/working-dir
entrypoint_script:
description: Entry-point python script (must be inside of the working directory)
type: string
required: true
entrypoint_args:
description: Entry-point arguments
type: string
required: false
default: ""
env_vars:
description: Environment variables
type: string
required: false
default: ""

jobs:
run-command:
Expand All @@ -42,6 +52,8 @@ jobs:
id-token: write
contents: read
steps:
- name: Log workflow inputs
run: echo "${{ toJson(github.event.inputs) }}"
- name: Checkout repo
uses: actions/checkout@v4
with:
Expand All @@ -63,15 +75,28 @@ jobs:
- name: Dynamically update ray config file
run: |
source .venv/bin/activate
(cat .github/assets/.template.yaml \
| python .github/ci-scripts/templatize_ray_config.py \
--cluster-name "ray-ci-run-${{ github.run_id }}_${{ github.run_attempt }}" \
--daft-wheel-url '${{ inputs.daft_wheel_url }}' \
--daft-version '${{ inputs.daft_version }}' \
--python-version '${{ inputs.python_version }}' \
--cluster-profile '${{ inputs.cluster_profile }}'
(cat .github/assets/template.yaml | \
uv run \
--python 3.12 \
.github/ci-scripts/templatize_ray_config.py \
--cluster-name "ray-ci-run-${{ github.run_id }}_${{ github.run_attempt }}" \
--daft-wheel-url '${{ inputs.daft_wheel_url }}' \
--daft-version '${{ inputs.daft_version }}' \
--python-version '${{ inputs.python_version }}' \
--cluster-profile '${{ inputs.cluster_profile }}' \
--working-dir '${{ inputs.working_dir }}' \
--entrypoint-script '${{ inputs.entrypoint_script }}'
) >> .github/assets/ray.yaml
cat .github/assets/ray.yaml
- name: Setup ray env vars
run: |
source .venv/bin/activate
ray_env_var=$(python .github/ci-scripts/format_env_vars.py \
--env-vars '${{ inputs.env_vars }}' \
--enable-ray-tracing \
)
echo $ray_env_var
echo "ray_env_var=$ray_env_var" >> $GITHUB_ENV
- name: Download private ssh key
run: |
KEY=$(aws secretsmanager get-secret-value --secret-id ci-github-actions-ray-cluster-key-3 --query SecretString --output text)
Expand All @@ -88,15 +113,15 @@ jobs:
- name: Submit job to ray cluster
run: |
source .venv/bin/activate
if [[ -z '${{ inputs.command }}' ]]; then
if [[ -z '${{ inputs.entrypoint_script }}' ]]; then
echo 'Invalid command submitted; command cannot be empty'
exit 1
fi
ray job submit \
--working-dir ${{ inputs.working_dir }} \
--address http://localhost:8265 \
--runtime-env-json '{"env_vars": {"DAFT_ENABLE_RAY_TRACING": "1"}}' \
-- ${{ inputs.command }}
--runtime-env-json "$ray_env_var" \
-- python ${{ inputs.entrypoint_script }} ${{ inputs.entrypoint_args }}
- name: Download log files from ray cluster
run: |
source .venv/bin/activate
Expand Down
Loading

0 comments on commit c2abed8

Please sign in to comment.