Skip to content

Commit

Permalink
(shortfin-sd) add e2e test + fixes for batched requests (#343)
Browse files Browse the repository at this point in the history
Put some railings around concurrent program invocations, fix batched
image responses.

Switches to a new scheduler module that lets us run more than one quantity of unet iterations.

Adds e2e test and workflow for SDXL shortfin serving on MI300x.

---------

Co-authored-by: Ean Garvey <[email protected]>
  • Loading branch information
monorimet and eagarvey-amd authored Oct 29, 2024
1 parent 072be20 commit e465c83
Show file tree
Hide file tree
Showing 19 changed files with 586 additions and 144 deletions.
102 changes: 102 additions & 0 deletions .github/workflows/ci-sdxl.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

name: CI - shortfin - SDXL

on:
workflow_dispatch:
pull_request:
paths:
- '.github/workflows/ci-sdxl.yaml'
- 'shortfin/**'
push:
branches:
- main
paths:
- '.github/workflows/ci-sdxl.yaml'
- 'shortfin/**'

permissions:
contents: read

concurrency:
# A PR number if a pull request and otherwise the commit hash. This cancels
# queued and in-progress runs for the same PR (presubmit) or commit
# (postsubmit). The workflow name is prepended to avoid conflicts between
# different workflows.
group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
cancel-in-progress: true

env:
IREE_REPO_DIR: ${{ github.workspace }}/iree
LIBSHORTFIN_DIR: ${{ github.workspace }}/shortfin/

jobs:
build-and-test:
name: Build and test
runs-on: mi300-sdxl-kernel

steps:
- name: Install dependencies
run: |
sudo apt update -y
sudo apt install cmake ninja-build -y
- name: Checkout repository
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
with:
submodules: false

- name: Checkout IREE repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
with:
repository: iree-org/iree
path: ${{ env.IREE_REPO_DIR }}
submodules: false
ref: 67ba1c45424d5cedc7baf7bfe8a998ee86e510af

- name: Initalize IREE submodules
working-directory: ${{ env.IREE_REPO_DIR }}
run : |
git submodule update --init --depth 1 -- third_party/benchmark
git submodule update --init --depth 1 -- third_party/cpuinfo/
git submodule update --init --depth 1 -- third_party/flatcc
git submodule update --init --depth 1 -- third_party/googletest
git submodule update --init --depth 1 -- third_party/hip-build-deps/
- name: Setup Python
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1
with:
python-version: "3.12"
cache: "pip"
- name: Install Python packages
# TODO: Switch to `pip install -r requirements.txt -e shortfin/`.
working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
pip install -r requirements-tests.txt
pip install -r requirements-iree-compiler.txt
pip freeze
- name: Build shortfin (full)
working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
mkdir build
cmake -GNinja \
-S. \
-Bbuild \
-DCMAKE_C_COMPILER=clang-18 \
-DCMAKE_CXX_COMPILER=clang++-18 \
-DSHORTFIN_BUNDLE_DEPS=ON \
-DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \
-DSHORTFIN_BUILD_PYTHON_BINDINGS=ON
cmake --build build --target all
pip install -v -e build/
- name: Test shortfin (full)
working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
ctest --timeout 30 --output-on-failure --test-dir build
pytest tests/apps/sd/e2e_test.py -v -s --system=amdgpu
8 changes: 0 additions & 8 deletions .github/workflows/ci-sharktank.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,9 @@ name: CI - sharktank
on:
workflow_dispatch:
pull_request:
paths:
- '.github/workflows/ci-sharktank.yml'
- 'sharktank/**'
- '*requirements*.txt'
push:
branches:
- main
paths:
- '.github/workflows/ci-sharktank.yml'
- 'sharktank/**'
- '*requirements*.txt'

concurrency:
# A PR number if a pull request and otherwise the commit hash. This cancels
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci_linux_x64-libshortfin.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
repository: iree-org/iree
path: ${{ env.IREE_REPO_DIR }}
submodules: false
ref: candidate-20241025.1058
ref: 67ba1c45424d5cedc7baf7bfe8a998ee86e510af

- name: Initalize IREE submodules
working-directory: ${{ env.IREE_REPO_DIR }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci_linux_x64_asan-libshortfin.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ jobs:
repository: iree-org/iree
path: ${{ env.IREE_SOURCE_DIR }}
submodules: false
ref: candidate-20241025.1058
ref: 67ba1c45424d5cedc7baf7bfe8a998ee86e510af

- name: Initalize IREE submodules
working-directory: ${{ env.IREE_SOURCE_DIR }}
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/ci_linux_x64_nogil-libshortfin.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ jobs:
repository: iree-org/iree
path: ${{ env.IREE_REPO_DIR }}
submodules: false
ref: candidate-20241025.1058
ref: 67ba1c45424d5cedc7baf7bfe8a998ee86e510af

- name: Initalize IREE submodules
working-directory: ${{ env.IREE_REPO_DIR }}
Expand Down Expand Up @@ -98,6 +98,6 @@ jobs:
- name: Run shortfin Python tests (full)
working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/llm/components/cache_test.py
pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/llm/components/cache_test.py --ignore=tests/apps/sd
# TODO: Enable further tests and switch to
# pytest -s
2 changes: 1 addition & 1 deletion shortfin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ elseif (SHORTFIN_BUNDLE_DEPS)
FetchContent_Declare(
shortfin_iree
GIT_REPOSITORY https://github.com/iree-org/iree.git
GIT_TAG candidate-20241025.1058
GIT_TAG 67ba1c45424d5cedc7baf7bfe8a998ee86e510af
GIT_SUBMODULES ${IREE_SUBMODULES}
GIT_SHALLOW TRUE
SYSTEM
Expand Down
12 changes: 5 additions & 7 deletions shortfin/python/shortfin_apps/sd/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,20 @@ python -m shortfin_apps.sd.server --help
- Download runtime artifacts (vmfbs, weights):

```
wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/sfsd_vmfbs_gfx942_1023.zip
wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/sfsd_weights_1023.zip
# Option to use splat weights instead
wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/sfsd_splat_1023.zip
wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/sfsd_vmfbs_gfx942_1028.zip
# The sfsd_vmfbs_gfx942_1028.zip includes splat weights. You can download real weights with:
wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/sfsd_weights_1023.zip
```
- Unzip the downloaded archives to ./vmfbs and /weights
- Run CLI server interface (you can find `sdxl_config_i8.json` in shortfin_apps/sd/examples):

```
python -m shortfin_apps.sd.server --model_config=./sdxl_config_i8.json --clip_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb --unet_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb --scheduler_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_20_gfx942.vmfb --vae_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb --clip_params=./weights/stable_diffusion_xl_base_1_0_text_encoder_fp16.safetensors --unet_params=./weights/stable_diffusion_xl_base_1_0_punet_dataset_i8.irpa --vae_params=./weights/stable_diffusion_xl_base_1_0_vae_fp16.safetensors --device=amdgpu --device_ids=0
python -m shortfin_apps.sd.server --model_config=./sdxl_config_i8.json --clip_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb --unet_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb --scheduler_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_gfx942.vmfb --vae_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb --clip_params=./weights/stable_diffusion_xl_base_1_0_text_encoder_fp16.safetensors --unet_params=./weights/stable_diffusion_xl_base_1_0_punet_dataset_i8.irpa --vae_params=./weights/stable_diffusion_xl_base_1_0_vae_fp16.safetensors --device=amdgpu --device_ids=0
```
with splat:
```
python -m shortfin_apps.sd.server --model_config=./sdxl_config_i8.json --clip_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb --unet_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb --scheduler_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_20_gfx942.vmfb --vae_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb --clip_params=./weights/clip_splat.irpa --unet_params=./weights/punet_splat_18.irpa --vae_params=./weights/vae_splat.irpa --device=amdgpu --device_ids=0
python -m shortfin_apps.sd.server --model_config=./sdxl_config_i8.json --clip_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb --unet_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb --scheduler_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_gfx942.vmfb --vae_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb --clip_params=./weights/clip_splat.irpa --unet_params=./weights/punet_splat_18.irpa --vae_params=./weights/vae_splat.irpa --device=amdgpu --device_ids=0
```
- Run a request in a separate shell:
```
Expand Down
11 changes: 4 additions & 7 deletions shortfin/python/shortfin_apps/sd/components/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import asyncio
import io
import logging
import json

import shortfin as sf
import shortfin.array as sfnp
Expand Down Expand Up @@ -95,12 +96,8 @@ async def run(self):

# TODO: stream image outputs
logging.debug("Responding to one shot batch")
out = io.BytesIO()
result_images = [p.result_image for p in gen_processes]
for idx, result_image in enumerate(result_images):
out.write(result_image)
# TODO: save or return images
logging.debug("Wrote images as bytes to response.")
self.responder.send_response(out.getvalue())
response_data = {"images": [p.result_image for p in gen_processes]}
json_str = json.dumps(response_data)
self.responder.send_response(json_str)
finally:
self.responder.ensure_response()
40 changes: 23 additions & 17 deletions shortfin/python/shortfin_apps/sd/components/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,24 +41,30 @@ def post_init(self):
):
raise ValueError("Either text or input_ids should be provided.")

prev_input_len = None
for i in [self.prompt, self.neg_prompt, self.input_ids, self.neg_input_ids]:
if isinstance(i, str):
self.num_output_images = 1
continue
elif not i:
continue
if not isinstance(i, list):
raise ValueError("Text inputs should be strings or lists.")
if prev_input_len and not (prev_input_len == len(i)):
raise ValueError("Positive, Negative text inputs should be same length")
self.num_output_images = len(i)
prev_input_len = len(i)
if not self.num_output_images:
self.num_output_images = (
len[self.prompt] if self.prompt is not None else len(self.input_ids)
)
if isinstance(self.prompt, str):
self.prompt = [str]

self.num_output_images = (
len(self.prompt) if self.prompt is not None else len(self.input_ids)
)

batchable_args = [
self.prompt,
self.neg_prompt,
self.height,
self.width,
self.steps,
self.guidance_scale,
self.seed,
self.input_ids,
self.neg_input_ids,
]
for arg in batchable_args:
if isinstance(arg, list):
if len(arg) != self.num_output_images and len(arg) != 1:
raise ValueError(
f"Batchable arguments should either be singular or as many as the full batch ({self.num_output_images})."
)
if self.rid is None:
self.rid = [uuid.uuid4().hex for _ in range(self.num_output_images)]
else:
Expand Down
18 changes: 14 additions & 4 deletions shortfin/python/shortfin_apps/sd/components/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@

from enum import Enum

import logging

import shortfin as sf
import shortfin.array as sfnp

from .io_struct import GenerateReqInput

logger = logging.getLogger(__name__)


class InferencePhase(Enum):
# Tokenize prompt, negative prompt and get latents, timesteps, time ids, guidance scale as device arrays
Expand Down Expand Up @@ -54,6 +58,8 @@ def __init__(
image_array: sfnp.device_array | None = None,
):
super().__init__()
self.print_debug = True

self.phases = {}
self.phase = None
self.height = height
Expand Down Expand Up @@ -87,6 +93,7 @@ def __init__(
self.image_array = image_array

self.result_image = None
self.img_metadata = None

self.done = sf.VoidFuture()

Expand All @@ -96,8 +103,6 @@ def __init__(
self.return_host_array: bool = True

self.post_init()
print(self.phases)
print(self.phase)

@staticmethod
def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest":
Expand All @@ -114,8 +119,13 @@ def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest":
for item in gen_inputs:
received = getattr(gen_req, item, None)
if isinstance(received, list):
if index >= len(received):
rec_input = None
if index >= (len(received)):
if len(received) == 1:
rec_input = received[0]
else:
logging.error(
"Inputs in request must be singular or as many as the list of prompts."
)
else:
rec_input = received[index]
else:
Expand Down
Loading

0 comments on commit e465c83

Please sign in to comment.