(shortfin-sd) Adds compile action to sdxl model builder. (#427)
A few new builder-related changes in this PR:
- Adds `--build_preference` to server arguments (choices: `["compile", "precompiled"]`).
- Isolates builder invocations by submodel for finer-grained control of compilation flags.
- Adds a flagfile for gfx942 that contains optimal per-submodel flag settings for `iree-compile`.
- Runs the builder in a subprocess to prevent compile-arg confusion (see the sketch below).
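A minimal sketch of the per-submodel subprocess pattern added to `server.py` in this PR (the config path and compile flags below are illustrative assumptions, not values taken from the diff):

```
import os
import subprocess
import sys

# Illustrative inputs; the real values come from server arguments and the flagfile.
builders_py = os.path.join("shortfin_apps", "sd", "components", "builders.py")  # assumed path
ireec_args = ["--iree-hip-target=gfx942", "--iree-opt-const-eval=0"]  # example per-submodel flags

for modelname in ["clip", "unet", "vae", "scheduler"]:
    builder_args = [
        sys.executable, "-m", "iree.build", builders_py,
        "--model-json=sdxl_config_i8.json",  # assumed config filename
        "--target=gfx942",
        "--build-preference=compile",
        f"--model={modelname}",
        f"--iree-compile-extra-args={' '.join(ireec_args)}",
    ]
    # One process per submodel keeps iree-compile flags from leaking between builds.
    output = subprocess.check_output(builder_args).decode()
    produced_files = output.splitlines()
```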
monorimet authored Nov 5, 2024
1 parent 3dcca1f commit 8ed90d2
Showing 5 changed files with 164 additions and 31 deletions.
4 changes: 2 additions & 2 deletions shortfin/python/shortfin_apps/sd/README.md
@@ -38,11 +38,11 @@ cd shortfin/
The server will prepare runtime artifacts for you.

```
python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0
python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --flagfile=./python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt --build_preference=compile
```
- Run with splat (empty) weights:
```
python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --splat
python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --splat --flagfile=./python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt --build_preference=compile
```
- Run a request in a separate shell:
```
83 changes: 66 additions & 17 deletions shortfin/python/shortfin_apps/sd/components/builders.py
@@ -1,7 +1,9 @@
from iree.build import *
from iree.build.executor import FileNamespace
import itertools
import os
import shortfin.array as sfnp
import copy

from shortfin_apps.sd.components.config_struct import ModelParams

@@ -25,23 +27,33 @@
)


def get_mlir_filenames(model_params: ModelParams):
def filter_by_model(filenames, model):
if not model:
return filenames
filtered = []
for i in filenames:
if model.lower() in i.lower():
filtered.extend([i])
return filtered
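A quick usage sketch for `filter_by_model`, assuming the function above is in scope (the filenames are made up for illustration):

```
filenames = [
    "stable_diffusion_xl_base_1_0_clip_fp16.mlir",
    "stable_diffusion_xl_base_1_0_unet_i8.mlir",
]
filter_by_model(filenames, "clip")  # -> ["stable_diffusion_xl_base_1_0_clip_fp16.mlir"]
filter_by_model(filenames, None)    # -> both names; no filter applied
```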


def get_mlir_filenames(model_params: ModelParams, model=None):
mlir_filenames = []
file_stems = get_file_stems(model_params)
for stem in file_stems:
mlir_filenames.extend([stem + ".mlir"])
return mlir_filenames
return filter_by_model(mlir_filenames, model)


def get_vmfb_filenames(model_params: ModelParams, target: str = "gfx942"):
def get_vmfb_filenames(model_params: ModelParams, model=None, target: str = "gfx942"):
vmfb_filenames = []
file_stems = get_file_stems(model_params)
for stem in file_stems:
vmfb_filenames.extend([stem + "_" + target + ".vmfb"])
return vmfb_filenames
return filter_by_model(vmfb_filenames, model)


def get_params_filenames(model_params: ModelParams, splat: bool):
def get_params_filenames(model_params: ModelParams, model=None, splat: bool = False):
params_filenames = []
base = (
"stable_diffusion_xl_base_1_0"
@@ -69,7 +81,7 @@ def get_params_filenames(model_params: ModelParams, splat: bool):
params_filenames.extend(
[base + "_" + mod + "_dataset_" + mod_precs[idx] + ".irpa"]
)
return params_filenames
return filter_by_model(params_filenames, model)


def get_file_stems(model_params: ModelParams):
@@ -142,21 +154,38 @@ def needs_update(ctx):
return False


def needs_file(filename, ctx):
out_file = ctx.allocate_file(filename).get_fs_path()
def needs_file(filename, ctx, namespace=FileNamespace.GEN):
out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path()
if os.path.exists(out_file):
needed = False
else:
filekey = f"{ctx.path}/{filename}"
name_path = "bin" if namespace == FileNamespace.BIN else ""
if name_path:
filename = os.path.join(name_path, filename)
filekey = os.path.join(ctx.path, filename)
ctx.executor.all[filekey] = None
needed = True
return needed


def needs_compile(filename, target, ctx):
device = "amdgpu" if "gfx" in target else "llvmcpu"
vmfb_name = f"{filename}_{device}-{target}.vmfb"
namespace = FileNamespace.BIN
return needs_file(vmfb_name, ctx, namespace)


def get_cached_vmfb(filename, target, ctx):
device = "amdgpu" if "gfx" in target else "llvmcpu"
vmfb_name = f"{filename}_{device}-{target}.vmfb"
namespace = FileNamespace.BIN
return ctx.file(vmfb_name)
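The compile cache above keys vmfbs by a `<stem>_<device>-<target>.vmfb` convention; a small sketch of the names it resolves (the stem is illustrative, since `get_file_stems` is not shown in this hunk):

```
def vmfb_name_for(stem: str, target: str) -> str:
    # Mirrors the naming used by needs_compile/get_cached_vmfb above.
    device = "amdgpu" if "gfx" in target else "llvmcpu"
    return f"{stem}_{device}-{target}.vmfb"

vmfb_name_for("stable_diffusion_xl_base_1_0_unet_i8", "gfx942")
# -> 'stable_diffusion_xl_base_1_0_unet_i8_amdgpu-gfx942.vmfb'
```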


@entrypoint(description="Retrieves a set of SDXL submodels.")
def sdxl(
model_json=cl_arg(
"model_json",
"model-json",
default=default_config_json,
help="Local config filepath",
),
@@ -168,6 +197,12 @@ def sdxl(
splat=cl_arg(
"splat", default=False, type=str, help="Download empty weights (for testing)"
),
build_preference=cl_arg(
"build-preference",
default="precompiled",
help="Sets preference for artifact generation method: [compile, precompiled]",
),
model=cl_arg("model", type=str, help="Submodel to fetch/compile for."),
):
model_params = ModelParams.load_json(model_json)
ctx = executor.BuildContext.current()
@@ -176,24 +211,38 @@
mlir_bucket = SDXL_BUCKET + "mlir/"
vmfb_bucket = SDXL_BUCKET + "vmfbs/"

mlir_filenames = get_mlir_filenames(model_params)
mlir_filenames = get_mlir_filenames(model_params, model)
mlir_urls = get_url_map(mlir_filenames, mlir_bucket)
for f, url in mlir_urls.items():
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)

vmfb_filenames = get_vmfb_filenames(model_params, target=target)
vmfb_filenames = get_vmfb_filenames(model_params, model=model, target=target)
vmfb_urls = get_url_map(vmfb_filenames, vmfb_bucket)
for f, url in vmfb_urls.items():
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)
params_filenames = get_params_filenames(model_params, splat)
if build_preference == "compile":
for idx, f in enumerate(copy.deepcopy(vmfb_filenames)):
# We return .vmfb file stems for the compile builder.
file_stem = "_".join(f.split("_")[:-1])
if needs_compile(file_stem, target, ctx):
for mlirname in mlir_filenames:
if file_stem in mlirname:
mlir_source = mlirname
break
obj = compile(name=file_stem, source=mlir_source)
vmfb_filenames[idx] = obj[0]
else:
vmfb_filenames[idx] = get_cached_vmfb(file_stem, target, ctx)
else:
for f, url in vmfb_urls.items():
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)

params_filenames = get_params_filenames(model_params, model=model, splat=splat)
params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET)
for f, url in params_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)

filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames]
return filenames

23 changes: 23 additions & 0 deletions shortfin/python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt
@@ -0,0 +1,23 @@
all
--iree-hal-target-backends=rocm
--iree-hip-target=gfx942
--iree-execution-model=async-external
--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))'
--iree-global-opt-propagate-transposes=1
--iree-opt-const-eval=0
--iree-opt-outer-dim-concat=1
--iree-opt-aggressively-propagate-transposes=1
--iree-dispatch-creation-enable-aggressive-fusion
--iree-hal-force-indirect-command-buffers
--iree-codegen-llvmgpu-use-vector-distribution=1
--iree-llvmgpu-enable-prefetch=1
--iree-codegen-gpu-native-math-precision=1
--iree-hip-legacy-sync=0
--iree-opt-data-tiling=0
--iree-vm-target-truncate-unsupported-floats
clip
unet
--iree-dispatch-creation-enable-fuse-horizontal-contractions=1
vae
--iree-dispatch-creation-enable-fuse-horizontal-contractions=1
scheduler
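The flagfile groups `iree-compile` flags under submodel headings (`all`, `clip`, `unet`, `vae`, `scheduler`): flags under `all` apply to every submodel, while a heading with nothing beneath it (like `scheduler` here) contributes no extra flags. A rough sketch of the parsing this PR adds to `server.py` for this format (variable names paraphrased):

```
model_flags = {"all": [], "clip": [], "unet": [], "vae": [], "scheduler": []}

with open("sdxl_flags_gfx942.txt", "r") as f:
    lines = [line.rstrip() for line in f]

current = "all"
for line in lines:
    # A line naming a submodel switches the active section;
    # every other line is treated as a compile flag for that section.
    if any(key in line for key in model_flags):
        current = line
    else:
        model_flags[current].append(line)

# Flags for one submodel = shared flags plus its own section.
unet_args = model_flags["all"] + model_flags["unet"]
```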
81 changes: 71 additions & 10 deletions shortfin/python/shortfin_apps/sd/server.py
@@ -12,6 +12,8 @@
import sys
import os
import io
import copy
import subprocess

from iree.build import *

@@ -101,6 +103,7 @@ def configure(args) -> SystemManager:
tokenizers=tokenizers,
model_params=model_params,
fibers_per_device=args.fibers_per_device,
workers_per_device=args.workers_per_device,
prog_isolation=args.isolation,
show_progress=args.show_progress,
trace_execution=args.trace_execution,
@@ -118,18 +121,44 @@
def get_modules(args):
vmfbs = {"clip": [], "unet": [], "vae": [], "scheduler": []}
params = {"clip": [], "unet": [], "vae": []}
mod = load_build_module(os.path.join(THIS_DIR, "components", "builders.py"))
out_file = io.StringIO()
iree_build_main(
mod,
args=[
f"--model_json={args.model_config}",
model_flags = copy.deepcopy(vmfbs)
model_flags["all"] = args.compile_flags

if args.flagfile:
with open(args.flagfile, "r") as f:
contents = [line.rstrip() for line in f]
flagged_model = "all"
for elem in contents:
match = [keyw in elem for keyw in model_flags.keys()]
if any(match):
flagged_model = elem
else:
model_flags[flagged_model].extend([elem])

filenames = []
for modelname in vmfbs.keys():
ireec_args = model_flags["all"] + model_flags[modelname]
builder_args = [
sys.executable,
"-m",
"iree.build",
os.path.join(THIS_DIR, "components", "builders.py"),
f"--model-json={args.model_config}",
f"--target={args.target}",
f"--splat={args.splat}",
],
stdout=out_file,
)
filenames = out_file.getvalue().strip().split("\n")
f"--build-preference={args.build_preference}",
f"--output-dir={args.artifacts_dir}",
f"--model={modelname}",
f"--iree-hal-target-device={args.device}",
f"--iree-hip-target={args.target}",
f"--iree-compile-extra-args={" ".join(ireec_args)}",
]
print("BUILDER INPUT:\n", " \ \n ".join(builder_args))
output = subprocess.check_output(builder_args).decode()
print("OUTPUT:", output)

output_paths = output.splitlines()
filenames.extend(output_paths)
for name in filenames:
for key in vmfbs.keys():
if key in name.lower():
@@ -165,6 +194,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
type=str,
required=False,
default="gfx942",
choices=["gfx942", "gfx1100"],
help="Primary inferencing device LLVM target arch.",
)
parser.add_argument(
@@ -190,6 +220,12 @@
required=True,
help="Path to the model config file",
)
parser.add_argument(
"--workers_per_device",
type=int,
default=1,
help="Concurrency control -- how many fibers are created per device to run inference.",
)
parser.add_argument(
"--fibers_per_device",
type=int,
@@ -221,6 +257,31 @@
action="store_true",
help="Use splat (empty) parameter files, usually for testing.",
)
parser.add_argument(
"--build_preference",
type=str,
choices=["compile", "precompiled"],
default="precompiled",
help="Specify preference for builder artifact generation.",
)
parser.add_argument(
"--compile_flags",
type=str,
nargs="*",
default=[],
help="extra compile flags for all compile actions. For fine-grained control, use flagfiles.",
)
parser.add_argument(
"--flagfile",
type=Path,
help="Path to a flagfile to use for SDXL. If not specified, will use latest flagfile from azure.",
)
parser.add_argument(
"--artifacts_dir",
type=str,
default="",
help="Path to local artifacts cache.",
)
log_levels = {
"info": logging.INFO,
"debug": logging.DEBUG,
4 changes: 2 additions & 2 deletions shortfin/tests/apps/sd/e2e_test.py
@@ -135,15 +135,15 @@ def test_sd_server_bs8_dense_fpd8(sd_server_fpd8):
assert status_code == 200


@pytest.mark.slow
@pytest.mark.skip
@pytest.mark.system("amdgpu")
def test_sd_server_bs64_dense_fpd8(sd_server_fpd8):
imgs, status_code = send_json_file(sd_server_fpd8.url, num_copies=64)
assert len(imgs) == 64
assert status_code == 200


@pytest.mark.slow
@pytest.mark.skip
@pytest.mark.xfail(reason="Unexpectedly large client batch.")
@pytest.mark.system("amdgpu")
def test_sd_server_bs512_dense_fpd8(sd_server_fpd8):