Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automate artifact generations #131

Merged
merged 17 commits into from
Oct 4, 2023
72 changes: 72 additions & 0 deletions .github/workflows/update_model_artifacts.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2023 The OpenXLA Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# Workflow that updates model artifacts.

name: Update Model Artifacts

on:
workflow_dispatch:

concurrency:
# A PR number if a pull request and otherwise the commit hash. This cancels
# queued and in-progress runs for the same PR (presubmit) or commit
# (postsubmit).
group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
cancel-in-progress: true

jobs:
setup:
runs-on: ubuntu-22.04
outputs:
runner-group: ${{ steps.configure.outputs.runner-group }}
steps:
- name: "Checking out PR repository"
uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0
- name: "Configuring CI options"
id: configure
env:
RUNNER_GROUP: ${{ github.event_name == 'pull_request' && 'presubmit' || 'postsubmit' }}
run: |
# Just informative logging. There should only be two commits in the
# history here, but limiting the depth helps when copying from a local
# repo instead of using checkout, e.g. with
# https://github.com/nektos/act where there will be more.
git log --oneline --graph --max-count=3
# Workflow jobs can't access `env` in `runs-on`, so we need to make
# `runner-group` a job output variable.
echo "runner-group=${RUNNER_GROUP}" > "${GITHUB_OUTPUT}"

generate_artifacts:
needs: setup
runs-on:
- self-hosted # must come first
- runner-group=${{ needs.setup.outputs.runner-group }}
- environment=prod
- machine-type=c2-standard-16
steps:
- name: "Checking out PR repository"
uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0
- name: "Generate JAX model artifacts"
run: |
mkdir jax
# Generate enabled models in comparative_benchmark/jax/benchmark_xla.sh
docker run --mount="type=bind,src="${PWD}",target=/work" --workdir="/work" \
--env "PYTHON=python3" \
--env "WITH_CUDA=0" \
--env "AUTO_UPLOAD=0" \
--env "JOBS=1" \
--env "OUTPUT_DIR=jax" \
"gcr.io/iree-oss/openxla-benchmark/base@sha256:1bf3e319465ec8fb465baae3f6ba9a5b09cb84a5349a675c671a552fc77f2251" \
"common_benchmark_suite/openxla/benchmark/comparative_suite/jax/scripts/generate_model_artifacts.sh" \
"RESNET50_FP32_JAX_.+" \
"BERT_LARGE_FP32_JAX_.+_BATCH(1|16|24|32|48|64|512)" \
"T5_LARGE_FP32_JAX_.+_BATCH(1|16|24|32|48|64)" \
"T5_4CG_LARGE_FP32_JAX_.+" \
"GPT2LMHEAD_FP32_JAX_.+"
- name: "Upload JAX model artifacts"
run: |
gcloud storage cp -r "jax/*" "gs://iree-model-artifacts/jax"
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
import concurrent.futures
import jax
import os
import pathlib
import re
import multiprocessing
import shutil
import subprocess
import sys
from typing import Any, Optional
from typing import Any, List, Optional

# Add openxla dir to the search path.
sys.path.insert(0, str(pathlib.Path(__file__).parents[5]))
Expand Down Expand Up @@ -105,9 +105,10 @@ def _parse_arguments() -> argparse.Namespace:
help="Directory to save model artifacts.")
parser.add_argument("-f",
"--filter",
type=str,
default=".*",
help="The regex pattern to filter model names.")
dest="filters",
nargs="+",
default=[".*"],
help="The regex patterns to filter model names.")
parser.add_argument("--iree-ir-tool",
"--iree_ir_tool",
type=pathlib.Path,
Expand All @@ -120,12 +121,20 @@ def _parse_arguments() -> argparse.Namespace:
help=
f"If set, uploads artifacts automatically to {GCS_UPLOAD_DIR} and removes them locally once uploaded."
)
parser.add_argument(
"-j",
"--jobs",
type=int,
default=1,
help="Max number of concurrent jobs to generate artifacts. Be cautious"
" when generating with GPU.")
return parser.parse_args()


def main(output_dir: pathlib.Path, filter: str, iree_ir_tool: pathlib.Path,
auto_upload: bool):
name_pattern = re.compile(f"^{filter}$")
def main(output_dir: pathlib.Path, filters: List[str],
iree_ir_tool: pathlib.Path, auto_upload: bool, jobs: int):
combined_filters = "|".join(f"({name_filter})" for name_filter in filters)
name_pattern = re.compile(f"^{combined_filters}$")
models = [
model for model in model_definitions.ALL_MODELS
if name_pattern.match(model.name)
Expand All @@ -134,19 +143,20 @@ def main(output_dir: pathlib.Path, filter: str, iree_ir_tool: pathlib.Path,
if not models:
all_models_list = "\n".join(
model.name for model in model_definitions.ALL_MODELS)
raise ValueError(f'No model matches "{filter}".'
raise ValueError(f'No model matches "{filters}".'
f' Available models:\n{all_models_list}')

output_dir.mkdir(parents=True, exist_ok=True)

for model in models:
# We need to generate artifacts in a separate proces each time in order for
# XLA to update the HLO dump directory.
p = multiprocessing.Process(target=_generate_artifacts,
args=(model, output_dir, iree_ir_tool,
auto_upload))
p.start()
p.join()
with concurrent.futures.ProcessPoolExecutor(max_workers=jobs) as executor:
for model in models:
# We need to generate artifacts in a separate proces each time in order for
# XLA to update the HLO dump directory.
executor.submit(_generate_artifacts,
model=model,
save_dir=output_dir,
iree_ir_tool=iree_ir_tool,
auto_upload=auto_upload)

if auto_upload:
utils.gcs_upload(f"{output_dir}/**", f"{GCS_UPLOAD_DIR}/{output_dir.name}/")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# Runs `generate_model_artifacts.py` on all registered JAX models and saves
# artifacts into the directory `/tmp/jax_models_<jax-version>_<timestamp>`.
# artifacts into the directory
# `${OUTPUT_DIR}/jax_models_<jax-version>_<timestamp>`.
#
# Once complete. please upload the output directory to
# `gs://iree-model-artifacts/jax`, preserving directory name.
Expand All @@ -23,6 +24,7 @@
# WITH_CUDA=1
# GCS_UPLOAD_DIR=gs://iree-model-artifacts/jax
# AUTO_UPLOAD=1
# JOBS=1
#
# Positional arguments:
# FILTER (Optional): Regex to match models, e.g., BERT_LARGE_FP32_.+
Expand All @@ -34,8 +36,9 @@ VENV_DIR="${VENV_DIR:-jax-models.venv}"
PYTHON="${PYTHON:-"$(which python)"}"
WITH_CUDA="${WITH_CUDA:-}"
AUTO_UPLOAD="${AUTO_UPLOAD:-0}"

FILTER="${1:-".*"}"
OUTPUT_DIR="${OUTPUT_DIR:-/tmp}"
JOBS="${JOBS:-1}"
FILTER=( "$@" )

VENV_DIR=${VENV_DIR} PYTHON=${PYTHON} WITH_CUDA=${WITH_CUDA} "${TD}/setup_venv.sh"
source ${VENV_DIR}/bin/activate
Expand All @@ -46,23 +49,25 @@ PYTHON_VERSION="$(python --version | sed -e "s/^Python \(.*\)\.\(.*\)\..*$/\1\.\
# Generate unique output directory.
JAX_VERSION=$(pip show jax | grep Version | sed -e "s/^Version: \(.*\)$/\1/g")
DIR_NAME="jax_models_${JAX_VERSION}_$(date +'%s')"
OUTPUT_DIR="/tmp/${DIR_NAME}"
mkdir "${OUTPUT_DIR}"
VERSION_DIR="${OUTPUT_DIR}/${DIR_NAME}"
mkdir "${VERSION_DIR}"

pip list > "${OUTPUT_DIR}/models_version_info.txt"
pip list > "${VERSION_DIR}/models_version_info.txt"

declare -a args=(
-o "${OUTPUT_DIR}"
-o "${VERSION_DIR}"
--iree_ir_tool="$(which iree-ir-tool)"
--filter="${FILTER}"
--jobs="${JOBS}"
)

if (( "${#FILTER[@]}" > 0 )); then
args+=( --filter "${FILTER[@]}" )
fi

if (( AUTO_UPLOAD == 1 )); then
args+=(
--auto_upload
)
args+=( --auto_upload )
fi

python "${TD}/generate_model_artifacts.py" "${args[@]}"

echo "Output directory: ${OUTPUT_DIR}"
echo "Output directory: ${VERSION_DIR}"
Loading