diff --git a/shortfin/python/shortfin_apps/sd/README.md b/shortfin/python/shortfin_apps/sd/README.md index 9c4ee3b3e..4808cad08 100644 --- a/shortfin/python/shortfin_apps/sd/README.md +++ b/shortfin/python/shortfin_apps/sd/README.md @@ -38,11 +38,11 @@ cd shortfin/ The server will prepare runtime artifacts for you. ``` -python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 +python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --flagfile=./python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt --build_preference=compile ``` - Run with splat(empty) weights: ``` -python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --splat +python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --splat --flagfile=./python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt --build_preference=compile ``` - Run a request in a separate shell: ``` diff --git a/shortfin/python/shortfin_apps/sd/components/builders.py b/shortfin/python/shortfin_apps/sd/components/builders.py index a83b63e48..1f9d0c2ee 100644 --- a/shortfin/python/shortfin_apps/sd/components/builders.py +++ b/shortfin/python/shortfin_apps/sd/components/builders.py @@ -1,7 +1,9 @@ from iree.build import * +from iree.build.executor import FileNamespace import itertools import os import shortfin.array as sfnp +import copy from shortfin_apps.sd.components.config_struct import ModelParams @@ -25,23 +27,33 @@ ) -def get_mlir_filenames(model_params: ModelParams): +def filter_by_model(filenames, model): + if not model: + return filenames + filtered = [] + for i in filenames: + if model.lower() in i.lower(): + filtered.extend([i]) + return filtered + + +def get_mlir_filenames(model_params: ModelParams, model=None): mlir_filenames = [] file_stems = get_file_stems(model_params) for stem in file_stems: mlir_filenames.extend([stem + ".mlir"]) - return mlir_filenames + return filter_by_model(mlir_filenames, model) -def get_vmfb_filenames(model_params: ModelParams, target: str = "gfx942"): +def get_vmfb_filenames(model_params: ModelParams, model=None, target: str = "gfx942"): vmfb_filenames = [] file_stems = get_file_stems(model_params) for stem in file_stems: vmfb_filenames.extend([stem + "_" + target + ".vmfb"]) - return vmfb_filenames + return filter_by_model(vmfb_filenames, model) -def get_params_filenames(model_params: ModelParams, splat: bool): +def get_params_filenames(model_params: ModelParams, model=None, splat: bool = False): params_filenames = [] base = ( "stable_diffusion_xl_base_1_0" @@ -69,7 +81,7 @@ def get_params_filenames(model_params: ModelParams, splat: bool): params_filenames.extend( [base + "_" + mod + "_dataset_" + mod_precs[idx] + ".irpa"] ) - return params_filenames + return filter_by_model(params_filenames, model) def get_file_stems(model_params: ModelParams): @@ -142,21 +154,38 @@ def needs_update(ctx): return False -def needs_file(filename, ctx): - out_file = ctx.allocate_file(filename).get_fs_path() +def needs_file(filename, ctx, namespace=FileNamespace.GEN): + out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path() if os.path.exists(out_file): needed = False else: - filekey = f"{ctx.path}/{filename}" + name_path = "bin" if namespace == FileNamespace.BIN else "" + if name_path: + filename = os.path.join(name_path, filename) + filekey = os.path.join(ctx.path, filename) ctx.executor.all[filekey] = None needed = True return needed +def needs_compile(filename, target, ctx): + device = "amdgpu" if "gfx" in target else "llvmcpu" + vmfb_name = f"{filename}_{device}-{target}.vmfb" + namespace = FileNamespace.BIN + return needs_file(vmfb_name, ctx, namespace) + + +def get_cached_vmfb(filename, target, ctx): + device = "amdgpu" if "gfx" in target else "llvmcpu" + vmfb_name = f"{filename}_{device}-{target}.vmfb" + namespace = FileNamespace.BIN + return ctx.file(vmfb_name) + + @entrypoint(description="Retreives a set of SDXL submodels.") def sdxl( model_json=cl_arg( - "model_json", + "model-json", default=default_config_json, help="Local config filepath", ), @@ -168,6 +197,12 @@ def sdxl( splat=cl_arg( "splat", default=False, type=str, help="Download empty weights (for testing)" ), + build_preference=cl_arg( + "build-preference", + default="precompiled", + help="Sets preference for artifact generation method: [compile, precompiled]", + ), + model=cl_arg("model", type=str, help="Submodel to fetch/compile for."), ): model_params = ModelParams.load_json(model_json) ctx = executor.BuildContext.current() @@ -176,24 +211,38 @@ def sdxl( mlir_bucket = SDXL_BUCKET + "mlir/" vmfb_bucket = SDXL_BUCKET + "vmfbs/" - mlir_filenames = get_mlir_filenames(model_params) + mlir_filenames = get_mlir_filenames(model_params, model) mlir_urls = get_url_map(mlir_filenames, mlir_bucket) for f, url in mlir_urls.items(): if update or needs_file(f, ctx): fetch_http(name=f, url=url) - vmfb_filenames = get_vmfb_filenames(model_params, target=target) + vmfb_filenames = get_vmfb_filenames(model_params, model=model, target=target) vmfb_urls = get_url_map(vmfb_filenames, vmfb_bucket) - for f, url in vmfb_urls.items(): - if update or needs_file(f, ctx): - fetch_http(name=f, url=url) - params_filenames = get_params_filenames(model_params, splat) + if build_preference == "compile": + for idx, f in enumerate(copy.deepcopy(vmfb_filenames)): + # We return .vmfb file stems for the compile builder. + file_stem = "_".join(f.split("_")[:-1]) + if needs_compile(file_stem, target, ctx): + for mlirname in mlir_filenames: + if file_stem in mlirname: + mlir_source = mlirname + break + obj = compile(name=file_stem, source=mlir_source) + vmfb_filenames[idx] = obj[0] + else: + vmfb_filenames[idx] = get_cached_vmfb(file_stem, target, ctx) + else: + for f, url in vmfb_urls.items(): + if update or needs_file(f, ctx): + fetch_http(name=f, url=url) + + params_filenames = get_params_filenames(model_params, model=model, splat=splat) params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET) for f, url in params_urls.items(): out_file = os.path.join(ctx.executor.output_dir, f) if update or needs_file(f, ctx): fetch_http(name=f, url=url) - filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames] return filenames diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt b/shortfin/python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt new file mode 100644 index 000000000..731bc6da0 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt @@ -0,0 +1,23 @@ +all +--iree-hal-target-backends=rocm +--iree-hip-target=gfx942 +--iree-execution-model=async-external +--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' +--iree-global-opt-propagate-transposes=1 +--iree-opt-const-eval=0 +--iree-opt-outer-dim-concat=1 +--iree-opt-aggressively-propagate-transposes=1 +--iree-dispatch-creation-enable-aggressive-fusion +--iree-hal-force-indirect-command-buffers +--iree-codegen-llvmgpu-use-vector-distribution=1 +--iree-llvmgpu-enable-prefetch=1 +--iree-codegen-gpu-native-math-precision=1 +--iree-hip-legacy-sync=0 +--iree-opt-data-tiling=0 +--iree-vm-target-truncate-unsupported-floats +clip +unet +--iree-dispatch-creation-enable-fuse-horizontal-contractions=1 +vae +--iree-dispatch-creation-enable-fuse-horizontal-contractions=1 +scheduler diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py index 177361c06..849337900 100644 --- a/shortfin/python/shortfin_apps/sd/server.py +++ b/shortfin/python/shortfin_apps/sd/server.py @@ -12,6 +12,8 @@ import sys import os import io +import copy +import subprocess from iree.build import * @@ -101,6 +103,7 @@ def configure(args) -> SystemManager: tokenizers=tokenizers, model_params=model_params, fibers_per_device=args.fibers_per_device, + workers_per_device=args.workers_per_device, prog_isolation=args.isolation, show_progress=args.show_progress, trace_execution=args.trace_execution, @@ -118,18 +121,44 @@ def configure(args) -> SystemManager: def get_modules(args): vmfbs = {"clip": [], "unet": [], "vae": [], "scheduler": []} params = {"clip": [], "unet": [], "vae": []} - mod = load_build_module(os.path.join(THIS_DIR, "components", "builders.py")) - out_file = io.StringIO() - iree_build_main( - mod, - args=[ - f"--model_json={args.model_config}", + model_flags = copy.deepcopy(vmfbs) + model_flags["all"] = args.compile_flags + + if args.flagfile: + with open(args.flagfile, "r") as f: + contents = [line.rstrip() for line in f] + flagged_model = "all" + for elem in contents: + match = [keyw in elem for keyw in model_flags.keys()] + if any(match): + flagged_model = elem + else: + model_flags[flagged_model].extend([elem]) + + filenames = [] + for modelname in vmfbs.keys(): + ireec_args = model_flags["all"] + model_flags[modelname] + builder_args = [ + sys.executable, + "-m", + "iree.build", + os.path.join(THIS_DIR, "components", "builders.py"), + f"--model-json={args.model_config}", f"--target={args.target}", f"--splat={args.splat}", - ], - stdout=out_file, - ) - filenames = out_file.getvalue().strip().split("\n") + f"--build-preference={args.build_preference}", + f"--output-dir={args.artifacts_dir}", + f"--model={modelname}", + f"--iree-hal-target-device={args.device}", + f"--iree-hip-target={args.target}", + f"--iree-compile-extra-args={" ".join(ireec_args)}", + ] + print("BUILDER INPUT:\n", " \ \n ".join(builder_args)) + output = subprocess.check_output(builder_args).decode() + print("OUTPUT:", output) + + output_paths = output.splitlines() + filenames.extend(output_paths) for name in filenames: for key in vmfbs.keys(): if key in name.lower(): @@ -165,6 +194,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): type=str, required=False, default="gfx942", + choices=["gfx942", "gfx1100"], help="Primary inferencing device LLVM target arch.", ) parser.add_argument( @@ -190,6 +220,12 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): required=True, help="Path to the model config file", ) + parser.add_argument( + "--workers_per_device", + type=int, + default=1, + help="Concurrency control -- how many fibers are created per device to run inference.", + ) parser.add_argument( "--fibers_per_device", type=int, @@ -221,6 +257,31 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): action="store_true", help="Use splat (empty) parameter files, usually for testing.", ) + parser.add_argument( + "--build_preference", + type=str, + choices=["compile", "precompiled"], + default="precompiled", + help="Specify preference for builder artifact generation.", + ) + parser.add_argument( + "--compile_flags", + type=str, + nargs="*", + default=[], + help="extra compile flags for all compile actions. For fine-grained control, use flagfiles.", + ) + parser.add_argument( + "--flagfile", + type=Path, + help="Path to a flagfile to use for SDXL. If not specified, will use latest flagfile from azure.", + ) + parser.add_argument( + "--artifacts_dir", + type=str, + default="", + help="Path to local artifacts cache.", + ) log_levels = { "info": logging.INFO, "debug": logging.DEBUG, diff --git a/shortfin/tests/apps/sd/e2e_test.py b/shortfin/tests/apps/sd/e2e_test.py index 8fe8de5b3..cab8ecab2 100644 --- a/shortfin/tests/apps/sd/e2e_test.py +++ b/shortfin/tests/apps/sd/e2e_test.py @@ -135,7 +135,7 @@ def test_sd_server_bs8_dense_fpd8(sd_server_fpd8): assert status_code == 200 -@pytest.mark.slow +@pytest.mark.skip @pytest.mark.system("amdgpu") def test_sd_server_bs64_dense_fpd8(sd_server_fpd8): imgs, status_code = send_json_file(sd_server_fpd8.url, num_copies=64) @@ -143,7 +143,7 @@ def test_sd_server_bs64_dense_fpd8(sd_server_fpd8): assert status_code == 200 -@pytest.mark.slow +@pytest.mark.skip @pytest.mark.xfail(reason="Unexpectedly large client batch.") @pytest.mark.system("amdgpu") def test_sd_server_bs512_dense_fpd8(sd_server_fpd8):