Showing 11 changed files with 403 additions and 2 deletions.
@@ -0,0 +1 @@
output/
@@ -0,0 +1,23 @@
FROM python:3.11

# https://stackoverflow.com/questions/55313610/importerror-libgl-so-1-cannot-open-shared-object-file-no-such-file-or-directo
RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y


ARG DIR=flux-image-gen

COPY $DIR/requirements.txt requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

COPY purge-google-fonts.sh .
RUN bash purge-google-fonts.sh

WORKDIR /app

COPY $DIR/*.py .

COPY $DIR/gradio_config.yaml .

COPY $DIR/test-image.jpg .

ENTRYPOINT ["fastapi", "run", "api_server.py"]
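A minimal build-and-run sketch using the Docker SDK for Python, assuming the build context is the repository root (so `flux-image-gen/requirements.txt` and `purge-google-fonts.sh` both resolve) and that a GPU is available; the image tag, Dockerfile path, and port mapping are illustrative choices, not part of the commit:

```python
# Sketch: build the image and start the FastAPI server via the Docker SDK for Python.
# The tag "flux-image-gen", the Dockerfile path, and host port 8000 are assumptions.
import docker

client = docker.from_env()

# Build from the repository root using the Dockerfile added in this commit.
image, _build_logs = client.images.build(
    path=".",
    dockerfile="flux-image-gen/Dockerfile",
    tag="flux-image-gen",
)

# `fastapi run api_server.py` listens on port 8000 inside the container.
container = client.containers.run(
    "flux-image-gen",
    detach=True,
    ports={"8000/tcp": 8000},
    device_requests=[docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])],
)
print(container.logs().decode())
```

This mirrors a plain `docker build` followed by `docker run --gpus all -p 8000:8000`.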
@@ -0,0 +1,66 @@
import io
import os
import sys
import torch

from fastapi import FastAPI
from fastapi.responses import Response, JSONResponse
from PIL import Image
from pydantic import BaseModel

from image_gen import FluxGenerator

# Detect if the app is run using `fastapi dev ...`
DEV_MODE = len(sys.argv) > 1 and sys.argv[1] == "dev"

app = FastAPI()

device = "cuda" if torch.cuda.is_available() else "cpu"
model = os.environ.get("FLUX_MODEL_NAME", "flux-schnell")
if not DEV_MODE:
    print("Loading model", model)
    generator = FluxGenerator(model, device, offload=False)


class ImageGenInput(BaseModel):
    width: int
    height: int
    num_steps: int
    guidance: float
    seed: int
    prompt: str
    add_sampling_metadata: bool


@app.get("/model")
async def get_model():
    return {"model": model}


@app.post("/generate")
async def generate_image(input: ImageGenInput):
    if DEV_MODE:
        # For quicker testing or when GPU hardware is not available
        fn = "test-image.jpg"
        seed = "dev"
        image = Image.open(fn)
        # Uncomment to test error handling
        # return JSONResponse({"error": {"message": "Dev mode error test", "seed": "not-so-random"}}, status_code=400)
    else:
        # Main image generation functionality
        image, seed, msg = generator.generate_image(
            input.width,
            input.height,
            input.num_steps,
            input.guidance,
            input.seed,
            input.prompt,
            add_sampling_metadata=input.add_sampling_metadata,
        )
    if not image:
        return JSONResponse({"error": {"message": msg, "seed": seed}}, status_code=400)
    # Convert the image to a raw JPEG response; the seed travels in a response header
    buffer = io.BytesIO()
    image.save(buffer, format="jpeg")
    image_bytes = buffer.getvalue()
    # Header values must be strings, so coerce the seed explicitly
    return Response(image_bytes, media_type="image/jpeg", headers={"x-flux-seed": str(seed)})
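A short client-side sketch against the two endpoints above, assuming the server is reachable at http://localhost:8000 (the default port for `fastapi run`); the request values are placeholders taken from the Gradio defaults further down:

```python
# Sketch: exercise /model and /generate with httpx against a locally running API.
import httpx

BASE = "http://localhost:8000"  # assumption: local deployment on the default port

with httpx.Client(timeout=60) as client:
    print(client.get(f"{BASE}/model").json())  # e.g. {"model": "flux-schnell"}

    resp = client.post(
        f"{BASE}/generate",
        json={
            "width": 1360,
            "height": 768,
            "num_steps": 4,
            "guidance": 3.5,
            "seed": -1,
            "prompt": "Yoda riding a skateboard.",
            "add_sampling_metadata": True,
        },
    )
    if resp.status_code == 400:
        # Errors come back as JSON: {"error": {"message": ..., "seed": ...}}
        print(resp.json())
    else:
        resp.raise_for_status()
        # The image arrives as raw JPEG bytes; the seed is in the x-flux-seed header
        with open("generated.jpg", "wb") as f:
            f.write(resp.content)
        print("seed:", resp.headers.get("x-flux-seed"))
```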
@@ -0,0 +1,5 @@
models:
  - name: flux-schnell
    address: http://localhost:8000
example_prompt: |
  Yoda riding a skateboard.
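The Gradio app below validates this file against a small pydantic schema (`Model`, `AppSettings`), so each backend needs a `name` and `address` under `models`, while `example_prompt` sits at the top level. A quick validation sketch, where the second model entry is purely illustrative:

```python
# Sketch: validate a gradio_config.yaml-style document against the same schema
# the Gradio app uses. The second backend entry is a made-up example.
from typing import List

import yaml
from pydantic import BaseModel, HttpUrl


class Model(BaseModel):
    name: str
    address: HttpUrl


class AppSettings(BaseModel):
    models: List[Model]
    example_prompt: str


doc = """
models:
  - name: flux-schnell
    address: http://localhost:8000
  - name: flux-dev                     # hypothetical second backend
    address: http://flux-dev-api:8000
example_prompt: |
  Yoda riding a skateboard.
"""

settings = AppSettings(**yaml.safe_load(doc))
print(settings.model_dump())
```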
@@ -0,0 +1,127 @@
import io
import os
import httpx
import uuid
import pathlib
import yaml

import gradio as gr
from pydantic import BaseModel, HttpUrl
from PIL import Image, ExifTags
from typing import List
from urllib.parse import urljoin


class Model(BaseModel):
    name: str
    address: HttpUrl


class AppSettings(BaseModel):
    models: List[Model]
    example_prompt: str


settings_path = pathlib.Path("/etc/gradio-app/gradio_config.yaml")
if not settings_path.exists():
    print("No settings overrides found at", settings_path)
    settings_path = "./gradio_config.yaml"
print("Using settings from", settings_path)
with open(settings_path, "r") as file:
    settings = AppSettings(**yaml.safe_load(file))
print("App config:", settings.model_dump())

MODELS = {m.name: m.address for m in settings.models}
MODEL_NAMES = list(MODELS.keys())

# Disable analytics for GDPR compliance
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"


def save_image(model_name: str, prompt: str, seed: int, add_sampling_metadata: bool, image: Image.Image):
    filename = f"output/gradio/{uuid.uuid4()}.jpg"
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    exif_data = Image.Exif()
    exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
    exif_data[ExifTags.Base.Make] = "Black Forest Labs"
    exif_data[ExifTags.Base.Model] = model_name
    if add_sampling_metadata:
        exif_data[ExifTags.Base.ImageDescription] = prompt
    image.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)
    return filename


async def generate_image(
    model_name: str,
    width: int,
    height: int,
    num_steps: int,
    guidance: float,
    seed: int,
    prompt: str,
    add_sampling_metadata: bool,
):
    url = urljoin(str(MODELS[model_name]), "/generate")
    data = {
        "width": width,
        "height": height,
        "num_steps": num_steps,
        "guidance": guidance,
        "seed": seed,
        "prompt": prompt,
        "add_sampling_metadata": add_sampling_metadata,
    }
    async with httpx.AsyncClient(timeout=60) as client:
        try:
            response = await client.post(url, json=data)
        except httpx.ConnectError:
            raise gr.Error("Model backend unavailable")
    if response.status_code == 400:
        data = response.json()
        if "error" in data and "message" in data["error"]:
            message = data["error"]["message"]
            if "seed" in data["error"]:
                message += f" (seed: {data['error']['seed']})"
            raise gr.Error(message)
    try:
        response.raise_for_status()
    except httpx.HTTPStatusError as err:
        # Raise a generic error message to avoid leaking unwanted details.
        # Admins should consult the API logs for more info.
        raise gr.Error(f"Backend error (HTTP {err.response.status_code})")
    image = Image.open(io.BytesIO(response.content))
    seed = response.headers.get("x-flux-seed", "unknown")
    filename = save_image(model_name, prompt, seed, add_sampling_metadata, image)

    return image, seed, filename, None


with gr.Blocks() as demo:
    gr.Markdown("# Flux Image Generation Demo")

    with gr.Row():
        with gr.Column():
            model = gr.Dropdown(MODEL_NAMES, value=MODEL_NAMES[0], label="Model", interactive=len(MODEL_NAMES) > 1)
            prompt = gr.Textbox(label="Prompt", value=settings.example_prompt)

            with gr.Accordion("Advanced Options", open=False):
                # TODO: Make min/max slider values configurable
                width = gr.Slider(128, 8192, 1360, step=16, label="Width")
                height = gr.Slider(128, 8192, 768, step=16, label="Height")
                num_steps = gr.Slider(1, 50, 4 if model.value == "flux-schnell" else 50, step=1, label="Number of steps")
                guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=model.value != "flux-schnell")
                seed = gr.Textbox("-1", label="Seed (-1 for random)")
                add_sampling_metadata = gr.Checkbox(label="Add sampling parameters to metadata?", value=True)

            generate_btn = gr.Button("Generate")

        with gr.Column():
            output_image = gr.Image(label="Generated Image")
            seed_output = gr.Textbox(label="Used Seed")
            warning_text = gr.Textbox(label="Warning", visible=False)
            download_btn = gr.File(label="Download full-resolution")

    generate_btn.click(
        fn=generate_image,
        inputs=[model, width, height, num_steps, guidance, seed, prompt, add_sampling_metadata],
        outputs=[output_image, seed_output, download_btn, warning_text],
    )

demo.launch(enable_monitoring=False)