Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry Wu committed Sep 15, 2023
1 parent ae675fd commit 55d0f6a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 24 deletions.
21 changes: 7 additions & 14 deletions .github/workflows/update_model_artifacts.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
name: Update Model Artifacts

on:
schedule:
# Scheduled to run at 09:00 UTC.
- cron: '0 09 * * *'
workflow_dispatch:
pull_request:

Expand Down Expand Up @@ -54,26 +51,22 @@ jobs:
steps:
- name: "Checking out PR repository"
uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0
# - name: "Generate JAX model artifacts"
# run: |
# mkdir jax
# docker run --gpus all --mount="type=bind,src="${PWD}",target=/work" --workdir="/work" \
# --env "PYTHON=python3" \
# --env "WITH_CUDA=1" \
# --env "AUTO_UPLOAD=0" \
# --env "OUTPUT_DIR=jax" \
# "gcr.io/iree-oss/openxla-benchmark/cuda11.8-cudnn8.9@sha256:f43984cd6c16ad1faad4dfb6aac3f53e552dd728c9330c90752e78ae51e4276f" \
# "common_benchmark_suite/openxla/benchmark/comparative_suite/jax/scripts/generate_model_artifacts.sh"
- name: "Generate JAX model artifacts"
run: |
mkdir jax
# Generate enabled models in comparative_benchmark/jax/benchmark_xla.sh
docker run --mount="type=bind,src="${PWD}",target=/work" --workdir="/work" \
--env "PYTHON=python3" \
--env "WITH_CUDA=0" \
--env "AUTO_UPLOAD=0" \
--env "OUTPUT_DIR=jax" \
"gcr.io/iree-oss/openxla-benchmark/base@sha256:1bf3e319465ec8fb465baae3f6ba9a5b09cb84a5349a675c671a552fc77f2251" \
"common_benchmark_suite/openxla/benchmark/comparative_suite/jax/scripts/generate_model_artifacts.sh"
"common_benchmark_suite/openxla/benchmark/comparative_suite/jax/scripts/generate_model_artifacts.sh" --filter \
"RESNET50_FP32_JAX_.+" \
"BERT_LARGE_FP32_JAX_.+_BATCH(1|16|24|32|48|64|512)" \
"T5_LARGE_FP32_JAX_.+_BATCH(1|16|24|32|48|64)" \
"T5_4CG_LARGE_FP32_JAX_.+" \
"GPT2LMHEAD_FP32_JAX_.+"
- name: "Upload JAX model artifacts"
run: |
gcloud storage cp -r "jax/*" "gs://iree-model-artifacts/jax"
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,6 @@ declare -a args=(

if (( "${#FILTER[@]}" > 0 )); then
args+=( --filter "${FILTER[@]}" )
else
# Generate enabled models in comparative_benchmark/jax/benchmark_xla.sh
args+=(
--filter
"RESNET50_FP32_JAX_.+"
"BERT_LARGE_FP32_JAX_.+_BATCH(1|16|24|32|48|64|512)"
"T5_LARGE_FP32_JAX_.+_BATCH(1|16|24|32|48)"
"T5_4CG_LARGE_FP32_JAX_.+"
"GPT2LMHEAD_FP32_JAX_.+"
)
fi

if (( AUTO_UPLOAD == 1 )); then
Expand Down

0 comments on commit 55d0f6a

Please sign in to comment.