Skip to content

Commit

Permalink
Move app scripts into chart repo
Browse files Browse the repository at this point in the history
  • Loading branch information
sd109 committed Nov 12, 2024
1 parent 6e61451 commit 07a86bf
Show file tree
Hide file tree
Showing 11 changed files with 403 additions and 2 deletions.
2 changes: 1 addition & 1 deletion charts/flux-image-gen/templates/tests/gradio-api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ spec:
image: "{{ $.Values.image.repository }}:{{ $.Values.image.tag | default $.Chart.AppVersion }}"
command:
- python
- stackhpc-app/test_client.py
- test_client.py
env:
- name: GRADIO_HOST
value: {{ printf "http://%s-ui.%s.svc:%v" (include "flux-image-gen.fullname" .) .Release.Namespace .Values.ui.service.port }}
Expand Down
2 changes: 1 addition & 1 deletion charts/flux-image-gen/templates/ui/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ spec:
imagePullPolicy: {{ .Values.image.pullPolicy }}
command:
- python
- stackhpc-app/gradio_ui.py
- gradio_ui.py
ports:
- name: http
containerPort: {{ .Values.ui.service.port }}
Expand Down
1 change: 1 addition & 0 deletions web-apps/flux-image-gen/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
output/
23 changes: 23 additions & 0 deletions web-apps/flux-image-gen/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
FROM python:3.11

# https://stackoverflow.com/questions/55313610/importerror-libgl-so-1-cannot-open-shared-object-file-no-such-file-or-directo
RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y


ARG DIR=flux-image-gen

COPY $DIR/requirements.txt requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

COPY purge-google-fonts.sh .
RUN bash purge-google-fonts.sh

WORKDIR /app

COPY $DIR/*.py .

COPY $DIR/gradio_config.yaml .

COPY $DIR/test-image.jpg .

ENTRYPOINT ["fastapi", "run", "api_server.py"]
66 changes: 66 additions & 0 deletions web-apps/flux-image-gen/api_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import io
import os
import sys
import torch

from fastapi import FastAPI
from fastapi.responses import Response, JSONResponse
from PIL import Image
from pydantic import BaseModel

from image_gen import FluxGenerator

# Detect if app is run using `fastapi dev ...`
DEV_MODE = sys.argv[1] == "dev"

app = FastAPI()

device = "cuda" if torch.cuda.is_available() else "cpu"
model = os.environ.get("FLUX_MODEL_NAME", "flux-schnell")
if not DEV_MODE:
print("Loading model", model)
generator = FluxGenerator(model, device, offload=False)


class ImageGenInput(BaseModel):
width: int
height: int
num_steps: int
guidance: float
seed: int
prompt: str
add_sampling_metadata: bool


@app.get("/model")
async def get_model():
return {"model": model}


@app.post("/generate")
async def generate_image(input: ImageGenInput):
if DEV_MODE:
# For quicker testing or when GPU hardware not available
fn = "test-image.jpg"
seed = "dev"
image = Image.open(fn)
# Uncomment to test error handling
# return JSONResponse({"error": {"message": "Dev mode error test", "seed": "not-so-random"}}, status_code=400)
else:
# Main image generation functionality
image, seed, msg = generator.generate_image(
input.width,
input.height,
input.num_steps,
input.guidance,
input.seed,
input.prompt,
add_sampling_metadata=input.add_sampling_metadata,
)
if not image:
return JSONResponse({"error": {"message": msg, "seed": seed}}, status_code=400)
# Convert image to bytes response
buffer = io.BytesIO()
image.save(buffer, format="jpeg")
bytes = buffer.getvalue()
return Response(bytes, media_type="image/jpeg", headers={"x-flux-seed": seed})
5 changes: 5 additions & 0 deletions web-apps/flux-image-gen/gradio_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
models:
- name: flux-schnell
address: http://localhost:8000
example_prompt: |
Yoda riding a skateboard.
127 changes: 127 additions & 0 deletions web-apps/flux-image-gen/gradio_ui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import io
import os
import httpx
import uuid
import pathlib
import yaml

import gradio as gr
from pydantic import BaseModel, HttpUrl
from PIL import Image, ExifTags
from typing import List
from urllib.parse import urljoin


class Model(BaseModel):
name: str
address: HttpUrl

class AppSettings(BaseModel):
models: List[Model]
example_prompt: str


settings_path = pathlib.Path("/etc/gradio-app/gradio_config.yaml")
if not settings_path.exists():
print("No settings overrides found at", settings_path)
settings_path = "./gradio_config.yaml"
print("Using settings from", settings_path)
with open(settings_path, "r") as file:
settings = AppSettings(**yaml.safe_load(file))
print("App config:", settings.model_dump())

MODELS = {m.name: m.address for m in settings.models}
MODEL_NAMES = list(MODELS.keys())

# Disable analytics for GDPR compliance
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

def save_image(model_name: str, prompt: str, seed: int, add_sampling_metadata: bool, image: Image.Image):
filename = f"output/gradio/{uuid.uuid4()}.jpg"
os.makedirs(os.path.dirname(filename), exist_ok=True)
exif_data = Image.Exif()
exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
exif_data[ExifTags.Base.Make] = "Black Forest Labs"
exif_data[ExifTags.Base.Model] = model_name
if add_sampling_metadata:
exif_data[ExifTags.Base.ImageDescription] = prompt
image.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)
return filename


async def generate_image(
model_name: str,
width: int,
height: int,
num_steps: int,
guidance: float,
seed: int,
prompt: str,
add_sampling_metadata: bool,
):
url = urljoin(str(MODELS[model_name]), "/generate")
data = {
"width": width,
"height": height,
"num_steps": num_steps,
"guidance": guidance,
"seed": seed,
"prompt": prompt,
"add_sampling_metadata": add_sampling_metadata,
}
async with httpx.AsyncClient(timeout=60) as client:
try:
response = await client.post(url, json=data)
except httpx.ConnectError:
raise gr.Error("Model backend unavailable")
if response.status_code == 400:
data = response.json()
if "error" in data and "message" in data["error"]:
message = data["error"]["message"]
if "seed" in data["error"]:
message += f" (seed: {data['error']['seed']})"
raise gr.Error(message)
try:
response.raise_for_status()
except httpx.HTTPStatusError as err:
# Raise a generic error message to avoid leaking unwanted details
# Admin should consult API logs for more info
raise gr.Error(f"Backend error (HTTP {err.response.status_code})")
image = Image.open(io.BytesIO(response.content))
seed = response.headers.get("x-flux-seed", "unknown")
filename = save_image(model_name, prompt, seed, add_sampling_metadata, image)

return image, seed, filename, None


with gr.Blocks() as demo:
gr.Markdown("# Flux Image Generation Demo")

with gr.Row():
with gr.Column():
model = gr.Dropdown(MODEL_NAMES, value=MODEL_NAMES[0], label="Model", interactive=len(MODEL_NAMES) > 1)
prompt = gr.Textbox(label="Prompt", value=settings.example_prompt)

with gr.Accordion("Advanced Options", open=False):
# TODO: Make min/max slide values configurable
width = gr.Slider(128, 8192, 1360, step=16, label="Width")
height = gr.Slider(128, 8192, 768, step=16, label="Height")
num_steps = gr.Slider(1, 50, 4 if model.value == "flux-schnell" else 50, step=1, label="Number of steps")
guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not model.value == "flux-schnell")
seed = gr.Textbox("-1", label="Seed (-1 for random)")
add_sampling_metadata = gr.Checkbox(label="Add sampling parameters to metadata?", value=True)

generate_btn = gr.Button("Generate")

with gr.Column():
output_image = gr.Image(label="Generated Image")
seed_output = gr.Textbox(label="Used Seed")
warning_text = gr.Textbox(label="Warning", visible=False)
download_btn = gr.File(label="Download full-resolution")

generate_btn.click(
fn=generate_image,
inputs=[model, width, height, num_steps, guidance, seed, prompt, add_sampling_metadata],
outputs=[output_image, seed_output, download_btn, warning_text],
)
demo.launch(enable_monitoring=False)
Loading

0 comments on commit 07a86bf

Please sign in to comment.