# app.py
# Standard library
import base64
import json
import os
import time
from contextlib import nullcontext
from io import BytesIO

# Third-party
import numpy as np
import torch
from einops import rearrange
from imwatermark import WatermarkEncoder
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from torch import autocast
from tqdm import tqdm, trange

# Local (stable-diffusion repo)
from download import load_model_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from txt2img import chunk, numpy_to_pil, put_watermark, load_replacement, check_safety

CONFIG_FILE = "configs/stable-diffusion/v1-inference.yaml"
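

# init() runs once at worker startup: it loads the checkpoint into a
# module-level `model` that inference() reuses across requests.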
def init():
    global model  # module-level handle, needed for banana optimizations

    config = OmegaConf.load(CONFIG_FILE)
    model = load_model_from_config(config, "/models/model-epoch07-full.ckpt")
    model.eval()
    # Choose the device before moving the model: the original called
    # model.cuda() unconditionally, which would crash on a CPU-only host
    # before the fallback below could take effect.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
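

# Decode a base64-encoded image payload (e.g. an init_image for img2img);
# defined here but not used in the text-to-image path below.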
def decodeBase64Image(imageStr: str) -> Image.Image:
    return Image.open(BytesIO(base64.decodebytes(bytes(imageStr, "utf-8"))))
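

# Copy the request dict and shorten any base64 image fields so the request
# log stays readable.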
def truncateInputs(inputs: dict):
    clone = inputs.copy()
    if "modelInputs" in clone:
        modelInputs = clone["modelInputs"] = clone["modelInputs"].copy()
        for item in ["init_image", "mask_image"]:
            if item in modelInputs:
                modelInputs[item] = modelInputs[item][0:6] + "..."
    return clone
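

# Handle a single request: seed the RNG, run DDIM sampling for the prompt
# in modelInputs, and return the finished image as a base64-encoded PNG.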
def inference(all_inputs: dict) -> dict:
    global model
    print(json.dumps(truncateInputs(all_inputs), indent=2))
    model_inputs = all_inputs.get("modelInputs", None)
    call_inputs = all_inputs.get("callInputs", None)
    # Guard against a missing callInputs block; .get() on None would raise.
    startRequestId = call_inputs.get("startRequestId", None) if call_inputs else None

    # sampler = PLMSSampler(model)
    sampler = DDIMSampler(model)
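
    # Sampling options, mirroring the reference txt2img CLI flags: C is the
    # latent channel count and f the VAE downsampling factor, so sampling
    # happens in a (4, H/8, W/8) latent space; ddim_eta=0.0 keeps DDIM
    # deterministic for a given seed.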
    opt = {
        "n_iter": 1,
        "C": 4,
        "H": model_inputs.get("height", 512),
        "W": model_inputs.get("width", 512),
        "f": 8,
        "ddim_steps": model_inputs.get("num_inference_steps", 50),
        "ddim_eta": 0.0,
        "scale": model_inputs.get("guidance_scale", 7.5),
        "n_samples": 1,
        "skip_save": False,
        "seed": model_inputs.get("seed"),
        "precision": "autocast",
    }
    print(opt)
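
    # Embed the invisible "StableDiffusionV1" watermark that the reference
    # txt2img script applies to every output image.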
    wm = "StableDiffusionV1"
    wm_encoder = WatermarkEncoder()
    wm_encoder.set_watermark("bytes", wm.encode("utf-8"))

    prompt = model_inputs.get("prompt", None)
    batch_size = 1
    data = batch_size * [prompt]
    start_code = None

    # seed_everything picks a random seed if model_inputs carried none.
    seed_everything(opt["seed"])
    precision_scope = autocast if opt["precision"] == "autocast" else nullcontext
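
    # Same inference harness as the reference script: no autograd, optional
    # fp16 autocast, and the EMA copy of the weights. When guidance_scale is
    # not 1.0, an empty-prompt embedding drives classifier-free guidance.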
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                tic = time.time()
                all_samples = list()
                for n in trange(opt["n_iter"], desc="Sampling"):
                    for prompts in tqdm(data, desc="data"):
                        uc = None
                        if opt["scale"] != 1.0:
                            uc = model.get_learned_conditioning(batch_size * [""])
                        if isinstance(prompts, tuple):
                            prompts = list(prompts)
                        c = model.get_learned_conditioning(prompts)
                        shape = [opt["C"], opt["H"] // opt["f"], opt["W"] // opt["f"]]
                        samples_ddim, _ = sampler.sample(
                            S=opt["ddim_steps"],
                            conditioning=c,
                            batch_size=opt["n_samples"],
                            shape=shape,
                            verbose=False,
                            unconditional_guidance_scale=opt["scale"],
                            unconditional_conditioning=uc,
                            eta=opt["ddim_eta"],
                            x_T=start_code,
                        )
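                        # Decode the sampled latents back to pixel space; the
                        # VAE output lies in [-1, 1], so rescale to [0, 1]
                        # before converting to uint8 images.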
                        x_samples_ddim = model.decode_first_stage(samples_ddim)
                        x_samples_ddim = torch.clamp(
                            (x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0
                        )
                        x_samples_ddim = (
                            x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
                        )
                        # TODO: re-enable the safety checker.
                        # x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
                        x_checked_image = x_samples_ddim
                        x_checked_image_torch = torch.from_numpy(
                            x_checked_image
                        ).permute(0, 3, 1, 2)
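                        # Encode each sample as an in-memory PNG and base64 it
                        # for the JSON response; disk writes stay disabled.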
                        if not opt["skip_save"]:
                            for x_sample in x_checked_image_torch:
                                x_sample = 255.0 * rearrange(
                                    x_sample.cpu().numpy(), "c h w -> h w c"
                                )
                                img = Image.fromarray(x_sample.astype(np.uint8))
                                img = put_watermark(img, wm_encoder)
                                # img.save(
                                #     os.path.join("samples", f"{base_count:05}.png")
                                # )
                                buffered = BytesIO()
                                img.save(buffered, format="PNG")
                                image_base64 = base64.b64encode(
                                    buffered.getvalue()
                                ).decode("utf-8")
                                # base_count += 1

                toc = time.time()

    # n_iter and n_samples are both 1, so the loops above set image_base64
    # exactly once.
    return {"image_base64": image_base64}