-
Notifications
You must be signed in to change notification settings - Fork 26
/
api.py
122 lines (90 loc) · 3.6 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from typing import Literal, Optional, TYPE_CHECKING
import numpy as np
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field
from platform import system
if TYPE_CHECKING:
from flux_pipeline import FluxPipeline
# Upper limit used for seeds: it is the exclusive bound for randomly drawn
# default seeds and the `lt` bound for client-supplied ones. Windows gets a
# smaller range than every other platform.
# NOTE(review): presumably the narrower Windows range works around NumPy's
# platform-dependent default integer width — confirm.
MAX_RAND = (2**16 - 1) if system() == "Windows" else (2**32 - 1)
class AppState:
    """Typed description of the attributes stored on the application state."""

    # The pipeline object that the route handlers call into. The annotation is
    # a string because FluxPipeline is only imported under TYPE_CHECKING.
    model: "FluxPipeline"
class FastAPIApp(FastAPI):
    """FastAPI subclass whose ``state`` attribute is annotated as AppState."""

    # Annotation only: no value is assigned here, it just narrows the declared
    # type of `state` for type checkers.
    state: AppState
class LoraArgs(BaseModel):
    # Request body for the LoRA endpoint. Every field is optional and falls
    # back to the default given below when omitted.

    # Scaling factor for the LoRA weights.
    scale: Optional[float] = Field(default=1.0)
    # Path to the LoRA checkpoint.
    path: Optional[str] = Field(default=None)
    # Name of the LoRA checkpoint.
    name: Optional[str] = Field(default=None)
    # Which operation to perform; loading is the default.
    action: Optional[Literal["load", "unload"]] = Field(default="load")
class LoraLoadResponse(BaseModel):
    # Response schema declared for the LoRA endpoint.

    # Outcome of the requested action.
    status: Literal["success", "error"]
    # Optional human-readable detail; absent unless explicitly provided.
    message: Optional[str] = Field(default=None)
class GenerateArgs(BaseModel):
    # Request body for image generation. Only `prompt` is required; every
    # other field falls back to the default given below when omitted.

    # The prompt used for image generation.
    prompt: str
    # Width of the image.
    width: Optional[int] = Field(default=720)
    # Height of the image.
    height: Optional[int] = Field(default=1024)
    # Number of steps for the image generation.
    num_steps: Optional[int] = Field(default=24)
    # Guidance value for the image generation.
    guidance: Optional[float] = Field(default=3.5)
    # Seed for the image generation. Client-supplied values must satisfy
    # 0 < seed < MAX_RAND; when omitted, a random seed is drawn per request.
    seed: Optional[int] = Field(
        # Draw from [1, MAX_RAND) so that the generated default obeys the same
        # gt/lt bounds enforced on client-supplied seeds (a lower bound of 0
        # could produce a seed of 0, which the field itself rejects), and
        # convert the NumPy integer scalar into a built-in int.
        default_factory=lambda: int(np.random.randint(1, MAX_RAND)),
        gt=0,
        lt=MAX_RAND,
    )
    # Strength for image generation.
    strength: Optional[float] = 1.0
    # Initial image to start from, if any.
    init_image: Optional[str] = None
# Module-level application instance; the route decorators in this module
# register their handlers on it.
app = FastAPIApp()
@app.post("/generate")
def generate(args: GenerateArgs):
    """
    Run the pipeline stored on the application state and stream back its output.

    Every field of the request body is forwarded to the pipeline's
    ``generate`` method as a keyword argument.

    Args:
        args (GenerateArgs): Request body for image generation:
            - `prompt`: The prompt used for image generation.
            - `width`: The width of the image.
            - `height`: The height of the image.
            - `num_steps`: The number of steps for the image generation.
            - `guidance`: The guidance for image generation, i.e. how
                strongly the prompt influences the result.
            - `seed`: The seed for the image generation.
            - `strength`: Strength for image generation, 0.0 - 1.0.
                Represents the percent of diffusion steps to run,
                setting the init_image as the noised latent at the
                given number of steps.
            - `init_image`: Base64 encoded image or path to image to use
                as the init image.

    Returns:
        StreamingResponse: The pipeline's output, served with the
            ``image/jpeg`` media type.
    """
    # Resolve the bound method first, then build the keyword arguments, which
    # mirrors the evaluation order of a single inline call expression.
    run_pipeline = app.state.model.generate
    generation_kwargs = args.model_dump()
    image_stream = run_pipeline(**generation_kwargs)
    return StreamingResponse(image_stream, media_type="image/jpeg")
@app.post("/lora", response_model=LoraLoadResponse)
def lora_action(args: LoraArgs):
    """
    Load a LoRA checkpoint into, or unload one from, the pipeline on the app state.

    Args:
        args (LoraArgs): Request body for the LoRA action:
            - `scale`: The scaling factor for the LoRA weights.
            - `path`: The path to the LoRA checkpoint.
            - `name`: The name of the LoRA checkpoint.
            - `action`: The action to perform, either "load" or "unload".

    Returns:
        JSONResponse: ``{"status": "success"}`` with HTTP 200 when the action
            completes; an error payload with HTTP 400 when the action is
            neither "load" nor "unload"; an error payload carrying the
            exception text with HTTP 500 when the pipeline call raises.
    """
    requested = args.action

    # Guard clause: anything other than the two known actions is rejected
    # before the pipeline is touched.
    if requested != "load" and requested != "unload":
        return JSONResponse(
            status_code=400,
            content={
                "status": "error",
                "message": f"Invalid action, expected 'load' or 'unload', got {args.action}",
            },
        )

    try:
        if requested == "load":
            # NOTE(review): `path` may still be None here and is passed through
            # unchanged — confirm how load_lora treats a missing path.
            app.state.model.load_lora(args.path, args.scale, args.name)
        else:
            # Identify the LoRA to unload by name when one is given, otherwise by path.
            app.state.model.unload_lora(args.name or args.path)
    except Exception as exc:
        # Any failure from the pipeline is reported to the client as a 500.
        return JSONResponse(
            content={"status": "error", "message": str(exc)}, status_code=500
        )

    return JSONResponse(content={"status": "success"}, status_code=200)