From e465c8331e78920ceadf9ea700c9487340628771 Mon Sep 17 00:00:00 2001
From: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Date: Tue, 29 Oct 2024 09:43:39 -0500
Subject: [PATCH] (shortfin-sd) add e2e test + fixes for batched requests (#343)

Adds guardrails around concurrent program invocations and fixes batched
image responses. Switches to a new scheduler module that lets us run an
arbitrary number of unet iterations.
Adds an e2e test and workflow for SDXL shortfin serving on MI300x.

---------

Co-authored-by: Ean Garvey
---
 .github/workflows/ci-sdxl.yaml                 | 102 +++++++++
 .github/workflows/ci-sharktank.yml             |   8 -
 .../workflows/ci_linux_x64-libshortfin.yml     |   2 +-
 .../ci_linux_x64_asan-libshortfin.yml          |   2 +-
 .../ci_linux_x64_nogil-libshortfin.yml         |   4 +-
 shortfin/CMakeLists.txt                        |   2 +-
 shortfin/python/shortfin_apps/sd/README.md     |  12 +-
 .../shortfin_apps/sd/components/generate.py    |  11 +-
 .../shortfin_apps/sd/components/io_struct.py   |  40 ++--
 .../shortfin_apps/sd/components/messages.py    |  18 +-
 .../shortfin_apps/sd/components/service.py     | 111 +++++-----
 .../sd/examples/sdxl_request_bs2.json          |  18 ++
 .../sd/examples/sdxl_request_bs4.json          |  22 ++
 .../sd/examples/sdxl_request_bs8.json          |  42 +---
 .../shortfin_apps/sd/examples/send_request.py  |  59 +++--
 shortfin/requirements-tests.txt                |   4 +
 .../apps/sd/components/tokenizer_test.py       |  55 +++++
 shortfin/tests/apps/sd/conftest.py             |  17 ++
 shortfin/tests/apps/sd/e2e_test.py             | 201 ++++++++++++++++++
 19 files changed, 586 insertions(+), 144 deletions(-)
 create mode 100644 .github/workflows/ci-sdxl.yaml
 create mode 100644 shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs2.json
 create mode 100644 shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs4.json
 create mode 100644 shortfin/tests/apps/sd/components/tokenizer_test.py
 create mode 100644 shortfin/tests/apps/sd/conftest.py
 create mode 100644 shortfin/tests/apps/sd/e2e_test.py

diff --git a/.github/workflows/ci-sdxl.yaml b/.github/workflows/ci-sdxl.yaml
new file mode 100644
index 000000000..c17e5b67f
--- /dev/null
+++ b/.github/workflows/ci-sdxl.yaml
@@ -0,0 +1,102 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+name: CI - shortfin - SDXL
+
+on:
+  workflow_dispatch:
+  pull_request:
+    paths:
+      - '.github/workflows/ci-sdxl.yaml'
+      - 'shortfin/**'
+  push:
+    branches:
+      - main
+    paths:
+      - '.github/workflows/ci-sdxl.yaml'
+      - 'shortfin/**'
+
+permissions:
+  contents: read
+
+concurrency:
+  # A PR number if a pull request and otherwise the commit hash. This cancels
+  # queued and in-progress runs for the same PR (presubmit) or commit
+  # (postsubmit). The workflow name is prepended to avoid conflicts between
+  # different workflows.
+  group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
+  cancel-in-progress: true
+
+env:
+  IREE_REPO_DIR: ${{ github.workspace }}/iree
+  LIBSHORTFIN_DIR: ${{ github.workspace }}/shortfin/
+
+jobs:
+  build-and-test:
+    name: Build and test
+    runs-on: mi300-sdxl-kernel
+
+    steps:
+    - name: Install dependencies
+      run: |
+        sudo apt update -y
+        sudo apt install cmake ninja-build -y
+
+    - name: Checkout repository
+      uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
+      with:
+        submodules: false
+
+    - name: Checkout IREE repo
+      uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
+      with:
+        repository: iree-org/iree
+        path: ${{ env.IREE_REPO_DIR }}
+        submodules: false
+        ref: 67ba1c45424d5cedc7baf7bfe8a998ee86e510af
+
+    - name: Initialize IREE submodules
+      working-directory: ${{ env.IREE_REPO_DIR }}
+      run: |
+        git submodule update --init --depth 1 -- third_party/benchmark
+        git submodule update --init --depth 1 -- third_party/cpuinfo/
+        git submodule update --init --depth 1 -- third_party/flatcc
+        git submodule update --init --depth 1 -- third_party/googletest
+        git submodule update --init --depth 1 -- third_party/hip-build-deps/
+
+    - name: Setup Python
+      uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1
+      with:
+        python-version: "3.12"
+        cache: "pip"
+    - name: Install Python packages
+      # TODO: Switch to `pip install -r requirements.txt -e shortfin/`.
+      working-directory: ${{ env.LIBSHORTFIN_DIR }}
+      run: |
+        pip install -r requirements-tests.txt
+        pip install -r requirements-iree-compiler.txt
+        pip freeze
+
+    - name: Build shortfin (full)
+      working-directory: ${{ env.LIBSHORTFIN_DIR }}
+      run: |
+        mkdir build
+        cmake -GNinja \
+          -S. \
+          -Bbuild \
+          -DCMAKE_C_COMPILER=clang-18 \
+          -DCMAKE_CXX_COMPILER=clang++-18 \
+          -DSHORTFIN_BUNDLE_DEPS=ON \
+          -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \
+          -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON
+        cmake --build build --target all
+        pip install -v -e build/
+
+    - name: Test shortfin (full)
+      working-directory: ${{ env.LIBSHORTFIN_DIR }}
+      run: |
+        ctest --timeout 30 --output-on-failure --test-dir build
+        pytest tests/apps/sd/e2e_test.py -v -s --system=amdgpu
diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml
index 73243086a..d12d2c58c 100644
--- a/.github/workflows/ci-sharktank.yml
+++ b/.github/workflows/ci-sharktank.yml
@@ -3,17 +3,9 @@ name: CI - sharktank
 on:
   workflow_dispatch:
   pull_request:
-    paths:
-      - '.github/workflows/ci-sharktank.yml'
-      - 'sharktank/**'
-      - '*requirements*.txt'
   push:
     branches:
       - main
-    paths:
-      - '.github/workflows/ci-sharktank.yml'
-      - 'sharktank/**'
-      - '*requirements*.txt'
 
 concurrency:
   # A PR number if a pull request and otherwise the commit hash. This cancels
diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml
index 6be5ecec8..1bb4da7f3 100644
--- a/.github/workflows/ci_linux_x64-libshortfin.yml
+++ b/.github/workflows/ci_linux_x64-libshortfin.yml
@@ -56,7 +56,7 @@ jobs:
         repository: iree-org/iree
         path: ${{ env.IREE_REPO_DIR }}
         submodules: false
-        ref: candidate-20241025.1058
+        ref: 67ba1c45424d5cedc7baf7bfe8a998ee86e510af
 
     - name: Initalize IREE submodules
       working-directory: ${{ env.IREE_REPO_DIR }}
diff --git a/.github/workflows/ci_linux_x64_asan-libshortfin.yml b/.github/workflows/ci_linux_x64_asan-libshortfin.yml
index 6d916c2f5..ce8479c22 100644
--- a/.github/workflows/ci_linux_x64_asan-libshortfin.yml
+++ b/.github/workflows/ci_linux_x64_asan-libshortfin.yml
@@ -109,7 +109,7 @@ jobs:
         repository: iree-org/iree
         path: ${{ env.IREE_SOURCE_DIR }}
         submodules: false
-        ref: candidate-20241025.1058
+        ref: 67ba1c45424d5cedc7baf7bfe8a998ee86e510af
 
     - name: Initalize IREE submodules
       working-directory: ${{ env.IREE_SOURCE_DIR }}
diff --git a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml
index 6594733be..571157951 100644
--- a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml
+++ b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml
@@ -57,7 +57,7 @@ jobs:
         repository: iree-org/iree
         path: ${{ env.IREE_REPO_DIR }}
         submodules: false
-        ref: candidate-20241025.1058
+        ref: 67ba1c45424d5cedc7baf7bfe8a998ee86e510af
 
     - name: Initalize IREE submodules
       working-directory: ${{ env.IREE_REPO_DIR }}
@@ -98,6 +98,6 @@ jobs:
     - name: Run shortfin Python tests (full)
       working-directory: ${{ env.LIBSHORTFIN_DIR }}
       run: |
-        pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/llm/components/cache_test.py
+        pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/llm/components/cache_test.py --ignore=tests/apps/sd
       # TODO: Enable further tests and switch to
       # pytest -s
diff --git a/shortfin/CMakeLists.txt b/shortfin/CMakeLists.txt
index b3c2ee24f..0673d6ff9 100644
--- a/shortfin/CMakeLists.txt
+++ b/shortfin/CMakeLists.txt
@@ -183,7 +183,7 @@ elseif (SHORTFIN_BUNDLE_DEPS)
   FetchContent_Declare(
     shortfin_iree
     GIT_REPOSITORY https://github.com/iree-org/iree.git
-    GIT_TAG candidate-20241025.1058
+    GIT_TAG 67ba1c45424d5cedc7baf7bfe8a998ee86e510af
     GIT_SUBMODULES ${IREE_SUBMODULES}
     GIT_SHALLOW TRUE
     SYSTEM
diff --git a/shortfin/python/shortfin_apps/sd/README.md b/shortfin/python/shortfin_apps/sd/README.md
index 60e56c778..6a1011c13 100644
--- a/shortfin/python/shortfin_apps/sd/README.md
+++ b/shortfin/python/shortfin_apps/sd/README.md
@@ -23,22 +23,20 @@ python -m shortfin_apps.sd.server --help
 - Download runtime artifacts (vmfbs, weights):
 
 ```
-wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/sfsd_vmfbs_gfx942_1023.zip
-wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/sfsd_weights_1023.zip
-
-# Option to use splat weights instead
-wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/sfsd_splat_1023.zip
+wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/sfsd_vmfbs_gfx942_1028.zip
+# The sfsd_vmfbs_gfx942_1028.zip includes splat weights. You can download real weights with:
+wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/sfsd_weights_1023.zip
 ```
 - Unzip the downloaded archives to ./vmfbs and ./weights
 - Run CLI server interface (you can find `sdxl_config_i8.json` in shortfin_apps/sd/examples):
 
 ```
-python -m shortfin_apps.sd.server --model_config=./sdxl_config_i8.json --clip_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb --unet_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb --scheduler_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_20_gfx942.vmfb --vae_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb --clip_params=./weights/stable_diffusion_xl_base_1_0_text_encoder_fp16.safetensors --unet_params=./weights/stable_diffusion_xl_base_1_0_punet_dataset_i8.irpa --vae_params=./weights/stable_diffusion_xl_base_1_0_vae_fp16.safetensors --device=amdgpu --device_ids=0
+python -m shortfin_apps.sd.server --model_config=./sdxl_config_i8.json --clip_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb --unet_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb --scheduler_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_gfx942.vmfb --vae_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb --clip_params=./weights/stable_diffusion_xl_base_1_0_text_encoder_fp16.safetensors --unet_params=./weights/stable_diffusion_xl_base_1_0_punet_dataset_i8.irpa --vae_params=./weights/stable_diffusion_xl_base_1_0_vae_fp16.safetensors --device=amdgpu --device_ids=0
 ```
 with splat:
 ```
-python -m shortfin_apps.sd.server --model_config=./sdxl_config_i8.json --clip_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb --unet_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb --scheduler_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_20_gfx942.vmfb --vae_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb --clip_params=./weights/clip_splat.irpa --unet_params=./weights/punet_splat_18.irpa --vae_params=./weights/vae_splat.irpa --device=amdgpu --device_ids=0
+python -m shortfin_apps.sd.server --model_config=./sdxl_config_i8.json --clip_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb --unet_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb --scheduler_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_gfx942.vmfb --vae_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb --clip_params=./weights/clip_splat.irpa --unet_params=./weights/punet_splat_18.irpa --vae_params=./weights/vae_splat.irpa --device=amdgpu --device_ids=0
 ```
 - Run a request in a separate shell:
 ```
diff --git a/shortfin/python/shortfin_apps/sd/components/generate.py b/shortfin/python/shortfin_apps/sd/components/generate.py
index 347e4cb3b..ca4f9799d 100644
--- a/shortfin/python/shortfin_apps/sd/components/generate.py
+++ b/shortfin/python/shortfin_apps/sd/components/generate.py
@@ -7,6 +7,7 @@
 import asyncio
 import io
 import logging
+import json
 
 import shortfin as sf
 import shortfin.array as sfnp
@@ -95,12 +96,8 @@ async def run(self):
 
         # TODO: stream image outputs
         logging.debug("Responding to one shot batch")
-        out = io.BytesIO()
-        result_images = [p.result_image for p in gen_processes]
-        for idx, result_image in enumerate(result_images):
-            out.write(result_image)
-        # TODO: save or return images
-        logging.debug("Wrote images as bytes to response.")
-        self.responder.send_response(out.getvalue())
+        response_data = {"images": [p.result_image for p in gen_processes]}
+        json_str = json.dumps(response_data)
+        self.responder.send_response(json_str)
     finally:
         self.responder.ensure_response()
diff --git a/shortfin/python/shortfin_apps/sd/components/io_struct.py b/shortfin/python/shortfin_apps/sd/components/io_struct.py
index e69bc8d82..d2952a818 100644
--- a/shortfin/python/shortfin_apps/sd/components/io_struct.py
+++ b/shortfin/python/shortfin_apps/sd/components/io_struct.py
@@ -41,24 +41,30 @@ def post_init(self):
         ):
             raise ValueError("Either text or input_ids should be provided.")
 
-        prev_input_len = None
-        for i in [self.prompt, self.neg_prompt, self.input_ids, self.neg_input_ids]:
-            if isinstance(i, str):
-                self.num_output_images = 1
-                continue
-            elif not i:
-                continue
-            if not isinstance(i, list):
-                raise ValueError("Text inputs should be strings or lists.")
-            if prev_input_len and not (prev_input_len == len(i)):
-                raise ValueError("Positive, Negative text inputs should be same length")
-            self.num_output_images = len(i)
-            prev_input_len = len(i)
-        if not self.num_output_images:
-            self.num_output_images = (
-                len[self.prompt] if self.prompt is not None else len(self.input_ids)
-            )
+        # Normalize a bare string prompt into a single-element batch.
+        if isinstance(self.prompt, str):
+            self.prompt = [self.prompt]
+        self.num_output_images = (
+            len(self.prompt) if self.prompt is not None else len(self.input_ids)
+        )
+
+        batchable_args = [
+            self.prompt,
+            self.neg_prompt,
+            self.height,
+            self.width,
+            self.steps,
+            self.guidance_scale,
+            self.seed,
+            self.input_ids,
+            self.neg_input_ids,
+        ]
+        for arg in batchable_args:
+            if isinstance(arg, list):
+                if len(arg) != self.num_output_images and len(arg) != 1:
+                    raise ValueError(
+                        f"Batchable arguments should either be singular or as many as the full batch ({self.num_output_images})."
+                    )
         if self.rid is None:
             self.rid = [uuid.uuid4().hex for _ in range(self.num_output_images)]
         else:
diff --git a/shortfin/python/shortfin_apps/sd/components/messages.py b/shortfin/python/shortfin_apps/sd/components/messages.py
index 9966b7fa8..88eb28ff4 100644
--- a/shortfin/python/shortfin_apps/sd/components/messages.py
+++ b/shortfin/python/shortfin_apps/sd/components/messages.py
@@ -6,11 +6,15 @@
 
 from enum import Enum
 
+import logging
+
 import shortfin as sf
 import shortfin.array as sfnp
 
 from .io_struct import GenerateReqInput
 
+logger = logging.getLogger(__name__)
+
 
 class InferencePhase(Enum):
     # Tokenize prompt, negative prompt and get latents, timesteps, time ids, guidance scale as device arrays
@@ -54,6 +58,8 @@ def __init__(
         image_array: sfnp.device_array | None = None,
     ):
         super().__init__()
+        self.print_debug = True
+
         self.phases = {}
         self.phase = None
         self.height = height
@@ -87,6 +93,7 @@ def __init__(
         self.image_array = image_array
 
         self.result_image = None
+        self.img_metadata = None
 
         self.done = sf.VoidFuture()
 
@@ -96,8 +103,6 @@ def __init__(
         self.return_host_array: bool = True
 
         self.post_init()
-        print(self.phases)
-        print(self.phase)
 
     @staticmethod
     def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest":
@@ -114,8 +119,13 @@ def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest":
         for item in gen_inputs:
             received = getattr(gen_req, item, None)
             if isinstance(received, list):
-                if index >= len(received):
-                    rec_input = None
+                if index >= len(received):
+                    if len(received) == 1:
+                        rec_input = received[0]
+                    else:
+                        logger.error(
+                            "Inputs in request must be singular or as many as the list of prompts."
+                        )
+                        rec_input = None
                 else:
                     rec_input = received[index]
             else:
diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py
index 7313dd447..d6cf71e48 100644
--- a/shortfin/python/shortfin_apps/sd/components/service.py
+++ b/shortfin/python/shortfin_apps/sd/components/service.py
@@ -11,6 +11,8 @@
 from tqdm.auto import tqdm
 from pathlib import Path
 from PIL import Image
+import io
+import base64
 
 import shortfin as sf
 import shortfin.array as sfnp
@@ -48,15 +50,17 @@ def __init__(
         self.inference_modules: dict[str, sf.ProgramModule] = {}
         self.inference_functions: dict[str, dict[str, sf.ProgramFunction]] = {}
         self.inference_programs: dict[str, sf.Program] = {}
-        self.procs_per_device = 2
+        self.procs_per_device = 1
         self.workers = []
         self.fibers = []
+        self.locks = []
         for idx, device in enumerate(self.sysman.ls.devices):
             for i in range(self.procs_per_device):
                 worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}")
                 fiber = sysman.ls.create_fiber(worker, devices=[device])
                 self.workers.append(worker)
                 self.fibers.append(fiber)
+                self.locks.append(asyncio.Lock())
 
         # Scope dependent objects.
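         # Note: each fiber created above is paired with an asyncio.Lock in
         # self.locks; InferenceExecutorProcess acquires the lock for its
         # worker index so only one batch at a time invokes programs on a
         # given fiber.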
self.batcher = BatcherProcess(self) @@ -84,18 +88,19 @@ def load_inference_parameters( self.inference_parameters[component].append(p) def start(self): - for component in self.inference_modules: - component_modules = [ - sf.ProgramModule.parameter_provider( - self.sysman.ls, *self.inference_parameters.get(component, []) - ), - *self.inference_modules[component], - ] - self.inference_programs[component] = sf.Program( - modules=component_modules, - fiber=self.fibers[0], - trace_execution=False, - ) + for fiber in self.fibers: + for component in self.inference_modules: + component_modules = [ + sf.ProgramModule.parameter_provider( + self.sysman.ls, *self.inference_parameters.get(component, []) + ), + *self.inference_modules[component], + ] + self.inference_programs[component] = sf.Program( + modules=component_modules, + devices=fiber.raw_devices, + trace_execution=False, + ) # TODO: export vmfbs with multiple batch size entrypoints @@ -175,7 +180,7 @@ async def _background_strober(self): while not self.batcher_infeed.closed: await asyncio.sleep( BatcherProcess.STROBE_SHORT_DELAY - if len(self.pending_preps) > 0 + if len(self.pending_requests) > 0 else BatcherProcess.STROBE_LONG_DELAY ) if self.strobe_enabled: @@ -192,6 +197,7 @@ async def run(self): self.strobes += 1 else: logger.error("Illegal message received by batcher: %r", item) + self.board_flights() self.strobe_enabled = True await strober_task @@ -207,7 +213,7 @@ def board_flights(self): batches = self.sort_pending() for idx in batches.keys(): - self.board(batches[idx]["reqs"]) + self.board(batches[idx]["reqs"], index=idx) def sort_pending(self): """Returns pending requests as sorted batches suitable for program invocations.""" @@ -233,11 +239,11 @@ def sort_pending(self): } return batches - def board(self, request_bundle): + def board(self, request_bundle, index): pending = request_bundle if len(pending) == 0: return - exec_process = InferenceExecutorProcess(self.service, 0) + exec_process = InferenceExecutorProcess(self.service, index) for req in pending: if len(exec_process.exec_requests) >= self.ideal_batch_size: break @@ -264,6 +270,7 @@ def __init__( ): super().__init__(fiber=service.fibers[index]) self.service = service + self.worker_index = index self.exec_requests: list[InferenceExecRequest] = [] async def run(self): @@ -273,22 +280,22 @@ async def run(self): if phase: if phase != req.phase: logger.error("Executor process recieved disjoint batch.") + phase = req.phase phases = self.exec_requests[0].phases req_count = len(self.exec_requests) - device0 = self.fiber.device(0) - await device0 - - if phases[InferencePhase.PREPARE]["required"]: - await self._prepare(device=device0, requests=self.exec_requests) - if phases[InferencePhase.ENCODE]["required"]: - await self._encode(device=device0, requests=self.exec_requests) - if phases[InferencePhase.DENOISE]["required"]: - await self._denoise(device=device0, requests=self.exec_requests) - if phases[InferencePhase.DECODE]["required"]: - await self._decode(device=device0, requests=self.exec_requests) - if phases[InferencePhase.POSTPROCESS]["required"]: - await self._postprocess(device=device0, requests=self.exec_requests) + async with self.service.locks[self.worker_index]: + device0 = self.fiber.device(0) + if phases[InferencePhase.PREPARE]["required"]: + await self._prepare(device=device0, requests=self.exec_requests) + if phases[InferencePhase.ENCODE]["required"]: + await self._encode(device=device0, requests=self.exec_requests) + if phases[InferencePhase.DENOISE]["required"]: + 
await self._denoise(device=device0, requests=self.exec_requests) + if phases[InferencePhase.DECODE]["required"]: + await self._decode(device=device0, requests=self.exec_requests) + if phases[InferencePhase.POSTPROCESS]["required"]: + await self._postprocess(device=device0, requests=self.exec_requests) for i in range(req_count): req = self.exec_requests[i] @@ -384,7 +391,7 @@ async def _encode(self, device, requests): "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]), ) await device - pe, te = await fn(*clip_inputs) + pe, te = await fn(*clip_inputs, fiber=self.fiber) await device for i in range(req_bs): @@ -396,6 +403,7 @@ async def _encode(self, device, requests): async def _denoise(self, device, requests): req_bs = len(requests) + step_count = requests[0].steps cfg_mult = 2 if self.service.model_params.cfg_mode else 1 # Produce denoised latents entrypoints = self.service.inference_functions["denoise"] @@ -463,6 +471,12 @@ async def _denoise(self, device, requests): denoise_inputs["guidance_scale"].copy_from(gs_host) + num_steps = sfnp.device_array.for_device(device, [1], sfnp.sint64) + ns_host = num_steps.for_transfer() + with ns_host.map(write=True) as m: + ns_host.items = [step_count] + num_steps.copy_from(ns_host) + await device # Initialize scheduler. logger.info( @@ -470,25 +484,20 @@ async def _denoise(self, device, requests): fns["init"], "".join([f"\n 0: {latents_shape}"]), ) - (latents, time_ids, step_indexes, timesteps,) = await fns[ - "init" - ](denoise_inputs["sample"]) + (latents, time_ids, timesteps, sigmas) = await fns["init"]( + denoise_inputs["sample"], num_steps, fiber=self.fiber + ) + await device - ts_host = timesteps.for_transfer() - ts_host.copy_from(timesteps) for i, t in tqdm( - enumerate(ts_host.items), + enumerate(range(step_count)), ): step = sfnp.device_array.for_device(device, [1], sfnp.sint64) s_host = step.for_transfer() with s_host.map(write=True) as m: s_host.items = [i] step.copy_from(s_host) - scale_inputs = [ - latents, - step, - timesteps, - ] + scale_inputs = [latents, step, timesteps, sigmas] logger.info( "INVOKE %r: %s", fns["scale"], @@ -497,7 +506,9 @@ async def _denoise(self, device, requests): ), ) await device - latent_model_input, t = await fns["scale"](*scale_inputs) + latent_model_input, t, sigma, next_sigma = await fns["scale"]( + *scale_inputs, fiber=self.fiber + ) await device unet_inputs = [ @@ -514,17 +525,17 @@ async def _denoise(self, device, requests): "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(unet_inputs)]), ) await device - (noise_pred,) = await fns["unet"](*unet_inputs) + (noise_pred,) = await fns["unet"](*unet_inputs, fiber=self.fiber) await device - step_inputs = [noise_pred, t, latents] + step_inputs = [noise_pred, latents, sigma, next_sigma] logger.info( "INVOKE %r: %s", fns["step"], "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(step_inputs)]), ) await device - (latent_model_output,) = await fns["step"](*step_inputs) + (latent_model_output,) = await fns["step"](*step_inputs, fiber=self.fiber) latents.copy_from(latent_model_output) await device @@ -558,7 +569,7 @@ async def _decode(self, device, requests): await device # Decode the denoised latents. 
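         # As with the other stages, the function is invoked with
         # fiber=self.fiber so the call is scheduled on this executor's fiber.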
- (image,) = await fn(latents) + (image,) = await fn(latents, fiber=self.fiber) await device images_shape = [ @@ -575,6 +586,7 @@ async def _decode(self, device, requests): ] images_host = sfnp.device_array.for_host(device, images_shape, sfnp.float16) images_host.copy_from(image) + await device for idx, req in enumerate(requests): image_array = images_host.view(idx).items dtype = image_array.typecode @@ -591,5 +603,8 @@ async def _postprocess(self, device, requests): # TODO: reimpl with sfnp permuted = np.transpose(req.image_array, (0, 2, 3, 1))[0] cast_image = (permuted * 255).round().astype("uint8") - req.result_image = Image.fromarray(cast_image).tobytes() + image_bytes = Image.fromarray(cast_image).tobytes() + + image = base64.b64encode(image_bytes).decode("utf-8") + req.result_image = image return diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs2.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs2.json new file mode 100644 index 000000000..0ded22888 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs2.json @@ -0,0 +1,18 @@ +{ + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal" + ], + "neg_prompt": "Watermark, blurry, oversaturated, low resolution, pollution", + "height": 1024, + "width": 1024, + "steps": 20, + "guidance_scale": [ + 7.5, + 7.9 + ], + "seed": 0, + "output_type": [ + "base64" + ] +} diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs4.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs4.json new file mode 100644 index 000000000..b59887b8f --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs4.json @@ -0,0 +1,22 @@ +{ + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a dog under the snow with brown eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal" + ], + "neg_prompt": "Watermark, blurry, oversaturated, low resolution, pollution", + "height": 1024, + "width": 1024, + "steps": 20, + "guidance_scale": [ + 10, + 10, + 10, + 10 + ], + "seed": 0, + "output_type": [ + "base64" + ] +} diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs8.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs8.json index 69d363f15..be94293ae 100644 --- a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs8.json +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs8.json @@ -10,44 +10,23 @@ " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo" ], "neg_prompt": [ - "Watermark, blurry, oversaturated, low resolution, pollution", - "Watermark, blurry, oversaturated, low resolution, pollution", - "Watermark, blurry, oversaturated, low resolution, pollution", - "Watermark, blurry, oversaturated, low resolution, pollution", - "Watermark, blurry, oversaturated, low resolution, pollution", - "Watermark, blurry, oversaturated, low resolution, pollution", - "Watermark, blurry, oversaturated, low resolution, pollution", "Watermark, blurry, oversaturated, low 
resolution, pollution" ], "height": [ - 1024, - 1024, - 1024, - 1024, - 1024, - 1024, - 1024, 1024 ], "width": [ - 1024, - 1024, - 1024, - 1024, - 1024, - 1024, - 1024, 1024 ], "steps": [ 20, + 30, + 40, + 50, 20, - 20, - 20, - 20, - 20, - 20, - 20 + 30, + 40, + 50 ], "guidance_scale": [ 7.5, @@ -60,14 +39,7 @@ 7.5 ], "seed": [ - 0, - 1000, - 2000, - 3000, - 4000, - 5000, - 602234764, - 159360312 + 0 ], "output_type": [ "base64" diff --git a/shortfin/python/shortfin_apps/sd/examples/send_request.py b/shortfin/python/shortfin_apps/sd/examples/send_request.py index 25b5e2d99..94fae9659 100644 --- a/shortfin/python/shortfin_apps/sd/examples/send_request.py +++ b/shortfin/python/shortfin_apps/sd/examples/send_request.py @@ -1,16 +1,43 @@ import json import requests import argparse +import base64 from datetime import datetime as dt from PIL import Image +sample_request = { + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + ], + "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"], + "height": [1024], + "width": [1024], + "steps": [20], + "guidance_scale": [7.5], + "seed": [0], + "output_type": ["base64"], + "rid": ["string"], +} + + +def bytes_to_img(bytes, idx=0, width=1024, height=1024): + timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") + image = Image.frombytes( + mode="RGB", size=(width, height), data=base64.b64decode(bytes) + ) + image.save(f"shortfin_sd_output_{timestamp}_{idx}.png") + print(f"Saved to shortfin_sd_output_{timestamp}_{idx}.png") + def send_json_file(file_path): # Read the JSON file try: - with open(file_path, "r") as json_file: - data = json.load(json_file) + if file_path == "default": + data = sample_request + else: + with open(file_path, "r") as json_file: + data = json.load(json_file) except Exception as e: print(f"Error reading the JSON file: {e}") return @@ -21,23 +48,29 @@ def send_json_file(file_path): response.raise_for_status() # Raise an error for bad responses print("Saving response as image...") timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") - - if isinstance(response.content, list): - for idx, item in enumerate(response.content): - image = Image.frombytes(item) - image.save(f"shortfin_sd_output_{timestamp}_{idx}.png") - elif isinstance(response.content, bytes): - image = Image.frombytes( - mode="RGB", size=(1024, 1024), data=response.content - ) - image.save(f"shortfin_sd_output_{timestamp}.png") + request = json.loads(response.request.body.decode("utf-8")) + for idx, item in enumerate(response.json()["images"]): + width = get_batched(request, "width", idx) + height = get_batched(request, "height", idx) + bytes_to_img(item.encode("utf-8"), idx, width, height) except requests.exceptions.RequestException as e: print(f"Error sending the request: {e}") +def get_batched(request, arg, idx): + if isinstance(request[arg], list): + if len(request[arg]) == 1: + indexed = request[arg][0] + else: + indexed = request[arg][idx] + else: + indexed = request[arg] + return indexed + + if __name__ == "__main__": p = argparse.ArgumentParser() - p.add_argument("file", type=str) + p.add_argument("--file", type=str, default="default") args = p.parse_args() send_json_file(args.file) diff --git a/shortfin/requirements-tests.txt b/shortfin/requirements-tests.txt index e1a48a6df..668023a1e 100644 --- a/shortfin/requirements-tests.txt +++ b/shortfin/requirements-tests.txt @@ -11,3 +11,7 @@ wheel # Deps needed for shortfin_apps.llm dataclasses-json tokenizers + +# Deps needed 
for shortfin_apps.sd +pillow +transformers diff --git a/shortfin/tests/apps/sd/components/tokenizer_test.py b/shortfin/tests/apps/sd/components/tokenizer_test.py new file mode 100644 index 000000000..05515ec30 --- /dev/null +++ b/shortfin/tests/apps/sd/components/tokenizer_test.py @@ -0,0 +1,55 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pytest + + +@pytest.fixture +def clip_tokenizer(): + from shortfin_apps.sd.components.tokenizer import Tokenizer + + return Tokenizer.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", "tokenizer" + ) + + +def test_transformers_tokenizer(clip_tokenizer): + enc0 = clip_tokenizer.encode(["This is sequence 1", "Sequence 2"]) + e0 = enc0.input_ids[0, :10] + e1 = enc0.input_ids[1, :10] + assert e0.tolist() == [ + 49406, + 589, + 533, + 18833, + 272, + 49407, + 49407, + 49407, + 49407, + 49407, + ] + assert e1.tolist() == [ + 49406, + 18833, + 273, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + ] + + +def test_tokenizer_to_array(cpu_fiber, clip_tokenizer): + batch_seq_len = 64 + encs = clip_tokenizer.encode(["This is sequence 1", "Sequence 2"]) + ary = clip_tokenizer.encodings_to_array(cpu_fiber.device(0), encs, batch_seq_len) + print(ary) + assert ary.view(0).items.tolist()[:5] == [49406, 589, 533, 18833, 272] + assert ary.view(1).items.tolist()[:5] == [49406, 18833, 273, 49407, 49407] diff --git a/shortfin/tests/apps/sd/conftest.py b/shortfin/tests/apps/sd/conftest.py new file mode 100644 index 000000000..1a08d9b4b --- /dev/null +++ b/shortfin/tests/apps/sd/conftest.py @@ -0,0 +1,17 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pytest + +from shortfin.support.deps import ShortfinDepNotFoundError + + +@pytest.fixture(autouse=True) +def require_deps(): + try: + import shortfin_apps.sd + except ShortfinDepNotFoundError as e: + pytest.skip(f"Dep not available: {e}") diff --git a/shortfin/tests/apps/sd/e2e_test.py b/shortfin/tests/apps/sd/e2e_test.py new file mode 100644 index 000000000..6ded547aa --- /dev/null +++ b/shortfin/tests/apps/sd/e2e_test.py @@ -0,0 +1,201 @@ +import json +import requests +import time +import base64 +import pytest +import subprocess +import os +import socket +import sys +from contextlib import closing + +from datetime import datetime as dt +from PIL import Image + +BATCH_SIZES = [1] + +sample_request = { + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + ], + "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"], + "height": [1024], + "width": [1024], + "steps": [20], + "guidance_scale": [7.5], + "seed": [0], + "output_type": ["base64"], + "rid": ["string"], +} + + +def sd_artifacts(target: str = "gfx942"): + return { + "model_config": "sdxl_config_i8.json", + "clip_vmfb": f"stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_{target}.vmfb", + "scheduler_vmfb": f"stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_{target}.vmfb", + "unet_vmfb": f"stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_{target}.vmfb", + "vae_vmfb": f"stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_{target}.vmfb", + "clip_params": "clip_splat_fp16.irpa", + "unet_params": "punet_splat_i8.irpa", + "vae_params": "vae_splat_fp16.irpa", + } + + +cache = os.path.abspath("./tmp/sharktank/sd/") + + +@pytest.fixture(scope="module") +def sd_server(): + # Create necessary directories + + os.makedirs(cache, exist_ok=True) + + # Download model if it doesn't exist + vmfbs_bucket = "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/" + weights_bucket = ( + "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/" + ) + configs_bucket = ( + "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/configs/" + ) + for artifact, path in sd_artifacts().items(): + if "vmfb" in artifact: + bucket = vmfbs_bucket + elif "params" in artifact: + bucket = weights_bucket + else: + bucket = configs_bucket + address = bucket + path + local_file = os.path.join(cache, path) + if not os.path.exists(local_file): + print("Downloading artifact from " + address) + r = requests.get(address, allow_redirects=True) + with open(local_file, "wb") as lf: + lf.write(r.content) + # Start the server + srv_args = [ + "python", + "-m", + "shortfin_apps.sd.server", + ] + for arg in sd_artifacts().keys(): + artifact_arg = f"--{arg}={cache}/{sd_artifacts()[arg]}" + srv_args.extend([artifact_arg]) + runner = ServerRunner(srv_args) + # Wait for server to start + time.sleep(5) + + yield runner + + # Teardown: kill the server + del runner + + +@pytest.mark.system("amdgpu") +def test_sd_server(sd_server): + imgs, status_code = send_json_file(sd_server.url) + assert len(imgs) == 1 + assert status_code == 200 + + +class ServerRunner: + def __init__(self, args): + port = str(find_free_port()) + self.url = "http://0.0.0.0:" + port + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + env["HIP_VISIBLE_DEVICES"] = "0" + self.process = subprocess.Popen( + [ + *args, + "--port=" + port, + "--device=amdgpu", + ], + env=env, + # TODO: Have a 
more robust way of forking a subprocess.
+            stdout=sys.stdout,
+            stderr=sys.stderr,
+        )
+        print(self.process.args)
+        self._wait_for_ready()
+
+    def _wait_for_ready(self):
+        start = time.time()
+        while True:
+            time.sleep(2)
+            try:
+                if requests.get(f"{self.url}/health").status_code == 200:
+                    return
+            except Exception as e:
+                if self.process.poll() is not None:
+                    raise RuntimeError("API server process terminated") from e
+                time.sleep(1.0)
+                if (time.time() - start) > 30:
+                    raise RuntimeError("Timeout waiting for server start")
+
+    def __del__(self):
+        try:
+            process = self.process
+        except AttributeError:
+            pass
+        else:
+            process.terminate()
+            process.wait()
+
+
+def bytes_to_img(img_bytes, idx=0, width=1024, height=1024):
+    image = Image.frombytes(
+        mode="RGB", size=(width, height), data=base64.b64decode(img_bytes)
+    )
+    return image
+
+
+def send_json_file(url="http://0.0.0.0:8000"):
+    # Read the JSON file
+    data = sample_request
+    imgs = []
+    # Send the data to the /generate endpoint
+    try:
+        response = requests.post(url + "/generate", json=data)
+        response.raise_for_status()  # Raise an error for bad responses
+        request = json.loads(response.request.body.decode("utf-8"))
+
+        for idx, item in enumerate(response.json()["images"]):
+            width = (
+                request["width"][idx]
+                if isinstance(request["width"], list)
+                else request["width"]
+            )
+            height = (
+                request["height"][idx]
+                if isinstance(request["height"], list)
+                else request["height"]
+            )
+            img = bytes_to_img(item.encode("utf-8"), idx, width, height)
+            imgs.append(img)
+
+    except requests.exceptions.RequestException as e:
+        print(f"Error sending the request: {e}")
+
+    return imgs, response.status_code
+
+
+def find_free_port():
+    """This tries to find a free port to run a server on for the test.
+
+    Race conditions are possible - the port can be acquired between when this
+    runs and when the server starts.
+
+    https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
+    """
+    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+        s.bind(("localhost", 0))
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        return s.getsockname()[1]
+
+
+def test_placeholder():
+    # Here in case this pytest is invoked via CPU CI and no tests are run.
+    pass
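---

Example client usage: the `/generate` endpoint now returns a JSON body of the
form `{"images": [...]}`, where each entry is a base64-encoded string of raw
RGB pixels. A minimal client sketch that decodes such a response into PNGs,
mirroring `send_request.py` above (it assumes a server is already running on
the default port 8000; the host, prompt, and output filenames are
illustrative):

```python
import base64

import requests
from PIL import Image

request = {
    "prompt": ["a cat under the snow with blue eyes, cinematic style"],
    "neg_prompt": ["Watermark, blurry, oversaturated, low resolution"],
    "height": [1024],
    "width": [1024],
    "steps": [20],
    "guidance_scale": [7.5],
    "seed": [0],
    "output_type": ["base64"],
}

# POST to the /generate endpoint; the response JSON carries one base64
# string per requested image.
response = requests.post("http://localhost:8000/generate", json=request)
response.raise_for_status()
for idx, b64_image in enumerate(response.json()["images"]):
    raw = base64.b64decode(b64_image)  # raw RGB bytes, width * height * 3
    image = Image.frombytes(mode="RGB", size=(1024, 1024), data=raw)
    image.save(f"sdxl_output_{idx}.png")
```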