SDXL implementation II
eagarvey-amd committed Oct 16, 2024
1 parent fca3290 commit 29871a1
Showing 12 changed files with 330 additions and 189 deletions.
7 changes: 7 additions & 0 deletions shortfin/python/shortfin_apps/sd/README.md
@@ -5,6 +5,13 @@ This directory contains a SD inference server, CLI and support components.

## Quick start

+Currently, we use the diffusers library for SD schedulers.
+In your shortfin environment,
+```
+pip install diffusers@git+https://github.com/nod-ai/[email protected]
+pip install transformers
+```
```
python -m shortfin_apps.sd.server --help
```
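As a quick sanity check that both packages resolve in the environment, an illustrative one-liner (not part of the README itself):
```
python -c "import diffusers, transformers; print(diffusers.__version__, transformers.__version__)"
```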
10 changes: 10 additions & 0 deletions shortfin/python/shortfin_apps/sd/_deps.py
@@ -6,6 +6,16 @@

from shortfin.support.deps import ShortfinDepNotFoundError

+try:
+    import diffusers
+except ModuleNotFoundError as e:
+    raise ShortfinDepNotFoundError(__name__, "diffusers") from e
+
+try:
+    import transformers
+except ModuleNotFoundError as e:
+    raise ShortfinDepNotFoundError(__name__, "transformers") from e

try:
    import tokenizers
except ModuleNotFoundError as e:
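For context, the guarded-import pattern above can be extended to any optional dependency. A minimal self-contained sketch — the constructor signature of ShortfinDepNotFoundError is inferred from the calls shown in this hunk, and the class below is only a stand-in:

```python
# Stand-in mimicking shortfin.support.deps.ShortfinDepNotFoundError,
# assumed to take (calling_module, dependency_name) as in _deps.py above.
class ShortfinDepNotFoundError(Exception):
    def __init__(self, caller: str, dep: str):
        super().__init__(f"{caller} requires the '{dep}' package; try: pip install {dep}")

try:
    import diffusers  # noqa: F401  # only checking availability
except ModuleNotFoundError as e:
    raise ShortfinDepNotFoundError(__name__, "diffusers") from e
```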
3 changes: 3 additions & 0 deletions shortfin/python/shortfin_apps/sd/components/config_struct.py
@@ -26,6 +26,9 @@ class ModelParams:
    # Maximum length of prompt sequence.
    max_seq_len: int

+    # Channel dim of latents.
+    num_latents_channels: int
+
    # Batch sizes that each stage is compiled for. These are expected to be
    # functions exported from the model with suffixes of "_bs{batch_size}". Must
    # be in ascending order.
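For intuition, the new field drives latent-buffer shape computation. A hedged sketch follows — the //8 VAE scale factor reflects standard SDXL convention, and the helper function is illustrative, not this repo's code:

```python
from dataclasses import dataclass

@dataclass
class ModelParams:
    # Only the fields visible in this hunk; the real class has many more.
    max_seq_len: int
    num_latents_channels: int

def latents_shape(params: ModelParams, batch_size: int, height: int, width: int):
    # SDXL's VAE downsamples spatial dims by 8; the channel dim comes from config.
    return (batch_size, params.num_latents_channels, height // 8, width // 8)

print(latents_shape(ModelParams(max_seq_len=64, num_latents_channels=4), 1, 1024, 1024))
# -> (1, 4, 128, 128)
```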
4 changes: 2 additions & 2 deletions shortfin/python/shortfin_apps/sd/components/generate.py
@@ -17,7 +17,6 @@
from .io_struct import GenerateReqInput
from .messages import InferenceExecRequest, InferencePhase
from .service import GenerateService
-from .tokenizer import Encoding

logger = logging.getLogger(__name__)

@@ -125,8 +124,9 @@ async def run(self):
            out = io.BytesIO()
            result_images = [p.result_image for p in gen_processes]
            for idx, result_image in enumerate(result_images):
-                out.write(f"generated image #{idx}")
                out.write(result_image)
+            # TODO: save or return images
+            logging.debug("Wrote images as bytes to response.")
            self.responder.send_response(out.getvalue())
        finally:
            self.responder.ensure_response()
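The deleted line wrote a str into a BytesIO, which raises TypeError at runtime; a minimal standalone illustration of the failure and the fix (the byte strings are placeholders, not real image data):

```python
import io

out = io.BytesIO()
images = [b"\x89PNG...image-one", b"\x89PNG...image-two"]  # placeholder bytes

for idx, image_bytes in enumerate(images):
    # out.write(f"generated image #{idx}")  # TypeError: BytesIO takes bytes, not str
    out.write(image_bytes)                  # raw bytes are fine

payload = out.getvalue()
print(len(payload))  # concatenated image bytes, as passed to send_response
```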
5 changes: 2 additions & 3 deletions shortfin/python/shortfin_apps/sd/components/io_struct.py
@@ -30,8 +30,6 @@ class GenerateReqInput:
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # Negative token ids: only used in place of negative prompt.
    neg_input_ids: Optional[Union[List[List[int]], List[int]]] = None
-    # Noisy latents, optionally specified for advanced workflows / inference comparisons
-    latents: Optional[Union[List[sfnp.device_array], sfnp.device_array]] = None
    # The sampling parameters.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Output image format. Defaults to base64. One string ("PIL", "base64")
@@ -57,8 +55,9 @@ def post_init(self):
                continue
            if not isinstance(i, list):
                raise ValueError("Text inputs should be strings or lists.")
-            if prev_input_len and not (prev_input_len == len[i]):
+            if prev_input_len and not (prev_input_len == len(i)):
                raise ValueError("Positive and negative text inputs should be the same length.")
+            self.num_output_images = len(i)
            prev_input_len = len(i)
        if not self.num_output_images:
            self.num_output_images = (
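To make the length check concrete, here is a distilled standalone version of the validation this hunk touches — the function name and structure are illustrative only, not the repo's API:

```python
from typing import List, Optional

def validate_token_id_lists(input_ids: Optional[List[List[int]]],
                            neg_input_ids: Optional[List[List[int]]]) -> int:
    """Return the number of output images implied by matched pos/neg id lists."""
    prev_input_len = None
    num_output_images = 0
    for ids in (input_ids, neg_input_ids):
        if ids is None:
            continue
        if not isinstance(ids, list):
            raise ValueError("Text inputs should be strings or lists.")
        if prev_input_len and prev_input_len != len(ids):
            raise ValueError("Positive and negative text inputs should be the same length.")
        num_output_images = len(ids)
        prev_input_len = len(ids)
    return num_output_images

print(validate_token_id_lists([[1, 2], [3, 4]], [[5], [6]]))  # -> 2
```

Note the original `len[i]` subscripted the built-in instead of calling it, which would raise `TypeError` on the first request with both positive and negative ids — exactly what this commit fixes.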
13 changes: 6 additions & 7 deletions shortfin/python/shortfin_apps/sd/components/messages.py
@@ -45,11 +45,10 @@ def __init__(
        steps: int | None = None,
        guidance_scale: float | sfnp.device_array | None = None,
        seed: int | None = None,
-        input_ids: sfnp.device_array | None = None,
-        neg_input_ids: sfnp.device_array | None = None,
+        input_ids: list[list[int]] | None = None,
        sample: sfnp.device_array | None = None,
        prompt_embeds: sfnp.device_array | None = None,
-        neg_embeds: sfnp.device_array | None = None,
+        add_text_embeds: sfnp.device_array | None = None,
        timesteps: sfnp.device_array | None = None,
        time_ids: sfnp.device_array | None = None,
        denoised_latents: sfnp.device_array | None = None,
@@ -64,21 +63,21 @@
        self.neg_prompt = neg_prompt
        self.height = height
        self.width = width
-        self.steps = steps
-        self.guidance_scale = guidance_scale
        self.seed = seed

        # Encode phase.
+        # This is a list of sequenced positive and negative token ids and pooler token ids.
        self.input_ids = input_ids
-        self.neg_input_ids = neg_input_ids

        # Denoise phase.
        self.prompt_embeds = prompt_embeds
-        self.neg_embeds = neg_embeds
+        self.add_text_embeds = add_text_embeds
        self.sample = sample
        # guidance scale at denoise phase is a device array
+        self.steps = steps
        self.timesteps = timesteps
        self.time_ids = time_ids
+        self.guidance_scale = guidance_scale

        # Decode phase.
        self.denoised_latents = denoised_latents
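To make the phase layout concrete, a hedged sketch of the new field grouping — the class below is a stand-in modeling only the parameters visible in this hunk (the real InferenceExecRequest also takes phase, prompt, height, width, and others):

```python
# Stand-in mirroring the assignments shown in the diff; not the repo's class.
class FakeExecRequest:
    def __init__(self, steps=None, guidance_scale=None, seed=None,
                 input_ids=None, sample=None, prompt_embeds=None,
                 add_text_embeds=None, timesteps=None, time_ids=None,
                 denoised_latents=None):
        self.seed = seed
        # Encode phase: token ids are now plain nested Python lists rather than
        # device arrays -- the change this hunk makes.
        self.input_ids = input_ids
        # Denoise phase inputs.
        self.prompt_embeds = prompt_embeds
        self.add_text_embeds = add_text_embeds
        self.sample = sample
        self.steps = steps
        self.timesteps = timesteps
        self.time_ids = time_ids
        self.guidance_scale = guidance_scale
        # Decode phase output.
        self.denoised_latents = denoised_latents

req = FakeExecRequest(steps=20, guidance_scale=7.5, input_ids=[[101, 2023], [101, 2024]])
print(req.steps, len(req.input_ids))  # -> 20 2
```

Grouping assignments by pipeline phase (encode, denoise, decode) keeps each stage's buffers together, which matches how the service advances a request through InferencePhase.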
(The remaining 6 changed files in this commit are not shown.)
