Skip to content

Commit

Permalink
Merge branch 'main' into pip-pre
Browse files Browse the repository at this point in the history
  • Loading branch information
marbre authored Nov 14, 2024
2 parents 677565b + 5dd512a commit 9d68526
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 65 deletions.
2 changes: 1 addition & 1 deletion shortfin/python/shortfin/support/logging_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self):
native_handler.setFormatter(NativeFormatter())

# TODO: Source from env vars.
logger.setLevel(logging.DEBUG)
logger.setLevel(logging.WARNING)
logger.addHandler(native_handler)


Expand Down
25 changes: 6 additions & 19 deletions shortfin/python/shortfin_apps/sd/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,25 @@ In your shortfin environment,
pip install transformers
pip install dataclasses-json
pip install pillow
pip install shark-ai
```
```
python -m shortfin_apps.sd.server --help
```

## Run tests

- From SHARK-Platform/shortfin:
```
pytest --system=amdgpu -k "sd"
```
The tests run with splat weights.


## Run on MI300x

- Follow quick start
# Run on MI300x
The server will prepare runtime artifacts for you.

- Navigate to shortfin/ (only necessary if you're using following CLI exactly.)
```
cd shortfin/
```
- Run CLI server interface (you can find `sdxl_config_i8.json` in shortfin_apps/sd/examples):
By default, the port is set to 8000. If you would like to change this, use `--port` in each of the following commands.

The server will prepare runtime artifacts for you.
You can check if this (or any) port is in use on Linux with `ss -ntl | grep 8000`.

```
python -m shortfin_apps.sd.server --device=amdgpu --device_ids=0 --build_preference=precompiled --topology="spx_single"
```

- Run a CLI client in a separate shell:
```
python -m shortfin_apps.sd.simple_client --interactive --save
python -m shortfin_apps.sd.simple_client --interactive
```
12 changes: 0 additions & 12 deletions shortfin/python/shortfin_apps/sd/components/config_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,6 @@
import shortfin.array as sfnp
import copy

from shortfin_apps.sd.components.config_struct import ModelParams

this_dir = os.path.dirname(os.path.abspath(__file__))
parent = os.path.dirname(this_dir)

dtype_to_filetag = {
sfnp.float16: "fp16",
sfnp.float32: "fp32",
sfnp.int8: "i8",
sfnp.bfloat16: "bf16",
}

ARTIFACT_VERSION = "11132024"
SDXL_CONFIG_BUCKET = f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/configs/"

Expand Down
2 changes: 1 addition & 1 deletion shortfin/python/shortfin_apps/sd/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
parser.add_argument(
"--tuning_spec",
type=str,
default="",
default=None,
help="Path to transform dialect spec if compiling an executable with tunings.",
)
parser.add_argument(
Expand Down
61 changes: 29 additions & 32 deletions shortfin/python/shortfin_apps/sd/simple_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
}


def bytes_to_img(bytes, idx=0, width=1024, height=1024, outputdir="./gen_imgs"):
def bytes_to_img(bytes, outputdir, idx=0, width=1024, height=1024):
timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
image = Image.frombytes(
mode="RGB", size=(width, height), data=base64.b64decode(bytes)
Expand All @@ -46,6 +46,7 @@ def bytes_to_img(bytes, idx=0, width=1024, height=1024, outputdir="./gen_imgs"):

def get_batched(request, arg, idx):
if isinstance(request[arg], list):
# some args are broadcasted to each prompt, hence overriding idx for single-item entries
if len(request[arg]) == 1:
indexed = request[arg][0]
else:
Expand All @@ -56,34 +57,30 @@ def get_batched(request, arg, idx):


async def send_request(session, rep, args, data):
try:
print("Sending request batch #", rep)
url = f"http://0.0.0.0:{args.port}/generate"
start = time.time()
async with session.post(url, json=data) as response:
end = time.time()
# Check if the response was successful
if response.status == 200:
response.raise_for_status() # Raise an error for bad responses
timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
res_json = await response.json(content_type=None)
if args.save:
for idx, item in enumerate(res_json["images"]):
width = get_batched(data, "width", idx)
height = get_batched(data, "height", idx)
print("Saving response as image...")
bytes_to_img(
item.encode("utf-8"), idx, width, height, args.outputdir
)
latency = end - start
print("Responses processed.")
return latency, len(data["prompt"])
else:
print(f"Error: Received {response.status} from server")
raise Exception
except Exception as e:
print(f"Request failed: {e}")
raise Exception
print("Sending request batch #", rep)
url = f"http://0.0.0.0:{args.port}/generate"
start = time.time()
async with session.post(url, json=data) as response:
end = time.time()
# Check if the response was successful
if response.status == 200:
response.raise_for_status() # Raise an error for bad responses
timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
res_json = await response.json(content_type=None)
if args.save:
for idx, item in enumerate(res_json["images"]):
width = get_batched(data, "width", idx)
height = get_batched(data, "height", idx)
print("Saving response as image...")
bytes_to_img(
item.encode("utf-8"), args.outputdir, idx, width, height
)
latency = end - start
print("Responses processed.")
return latency, len(data["prompt"])
else:
print(f"Error: Received {response.status} from server")
raise Exception


async def static(args):
Expand All @@ -94,7 +91,7 @@ async def static(args):
sample_counts = []
# Read the JSON file if supplied. Otherwise, get user input.
try:
if args.file == "default":
if not args.file:
data = sample_request
else:
with open(args.file, "r") as json_file:
Expand Down Expand Up @@ -135,7 +132,7 @@ async def interactive(args):
sample_counts = []
# Read the JSON file if supplied. Otherwise, get user input.
try:
if args.file == "default":
if not args.file:
data = sample_request
else:
with open(args.file, "r") as json_file:
Expand Down Expand Up @@ -185,7 +182,7 @@ def main(argv):
p.add_argument(
"--file",
type=str,
default="default",
default=None,
help="A non-default request to send to the server.",
)
p.add_argument(
Expand Down

0 comments on commit 9d68526

Please sign in to comment.