Commit
Cancel pipeline. Doesn't work for request cancellations yet.
marius-baseten committed Sep 24, 2024
1 parent 10e186b commit 0cf8b18
Showing 2 changed files with 82 additions and 14 deletions.
flux/schnell/config.yaml (3 changes: 2 additions & 1 deletion)
@@ -10,8 +10,9 @@ requirements:
   - accelerate
   - sentencepiece
   - protobuf
+  - fastapi
 resources:
-  accelerator: H100_40GB
+  accelerator: A100
   use_gpu: true
 secrets:
   hf_access_token: null
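
Applied to the file, the hunk yields the section below (a reconstruction; entries outside the hunk are elided). The new fastapi requirement backs the fastapi.Request and fastapi.HTTPException usage introduced in model.py, and the accelerator moves from H100_40GB to A100:

requirements:
  ...
  - accelerate
  - sentencepiece
  - protobuf
  - fastapi
resources:
  accelerator: A100
  use_gpu: true
secrets:
  hf_access_token: null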
flux/schnell/model/model.py (93 changes: 80 additions & 13 deletions)
@@ -1,4 +1,7 @@
+import threading
+
 import base64
+import fastapi
 import logging
 import math
 import random
@@ -21,6 +24,9 @@ def __init__(self, **kwargs):
         self.model_name = kwargs["config"]["model_metadata"]["repo_id"]
         self.hf_access_token = self._secrets["hf_access_token"]
         self.pipe = None
+        self._thread = None
+        self._exception = None
+        self._result = None

     def load(self):
         self.pipe = FluxPipeline.from_pretrained(
@@ -41,11 +47,35 @@ def convert_to_b64(self, image: Image) -> str:
         img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
         return img_b64

-    def predict(self, model_input):
+    def _generate_image(self, prompt, prompt2, guidance_scale, max_sequence_length,
+                        num_inference_steps, width, height, generator):
+        time.sleep(15)
+        try:
+            image = self.pipe(
+                prompt=prompt,
+                prompt_2=prompt2,
+                guidance_scale=guidance_scale,
+                max_sequence_length=max_sequence_length,
+                num_inference_steps=num_inference_steps,
+                width=width,
+                height=height,
+                output_type="pil",
+                generator=generator,
+            ).images[0]
+            self._result = image
+            return
+        except Exception as e:
+            logging.info(f"Image generation was aborted or failed: {e}")
+            self._exception = e
+            return
+
+    async def predict(self, model_input, request: fastapi.Request):
         start = time.perf_counter()
+        timeout_sec = model_input.get("timeout_sec", 60)
         seed = model_input.get("seed")
         prompt = model_input.get("prompt")
         prompt2 = model_input.get("prompt2")
+        logging.info(f"Starting: {prompt}")
         max_sequence_length = model_input.get(
             "max_sequence_length", 256
         )  # 256 is max for FLUX.1-schnell
@@ -78,20 +108,57 @@ def predict(self, model_input):
             prompt2 = " ".join(tokens[: min(len(tokens), max_sequence_length)])
         generator = torch.Generator().manual_seed(seed)

-        image = self.pipe(
-            prompt=prompt,
-            prompt_2=prompt2,
-            guidance_scale=guidance_scale,
-            max_sequence_length=max_sequence_length,
-            num_inference_steps=num_inference_steps,
-            width=width,
-            height=height,
-            output_type="pil",
-            generator=generator,
-        ).images[0]
+        logging.info("Starting thread.")
+        self._reset()
+        self._thread = threading.Thread(target=self._generate_image, args=(
+            prompt, prompt2, guidance_scale, max_sequence_length, num_inference_steps,
+            width, height, generator))
+        self._thread.start()
+        logging.info("Started thread.")

-        b64_results = self.convert_to_b64(image)
+        logging.info("Polling")
+        while self._thread.is_alive():
+            elapsed_sec = time.perf_counter() - start
+            if await request.is_disconnected():
+                logging.info("Aborting due to client disconnect.")
+                self._abort_thread()
+                raise fastapi.HTTPException(status_code=408, detail="Client disconnected.")
+            elif elapsed_sec > timeout_sec:
+                logging.info("Aborting due to timeout.")
+                self._abort_thread()
+                raise fastapi.HTTPException(status_code=408, detail="Timeout.")
+            time.sleep(1.0)
+            if self._result:
+                logging.info("Result there.")
+                logging.info(f"Thread alive: {self._thread.is_alive()}")
+                break
+
+        logging.info("Polling done.")
+        self._abort_thread()
+
+        if not self._result:
+            assert self._exception
+            raise self._exception
+        else:
+            image = self._result
+
+        b64_results = self.convert_to_b64(image)
         end = time.perf_counter()
         logging.info(f"Total time taken: {end - start} seconds")
         return {"data": b64_results}

+    def _abort_thread(self):
+        if not self._thread or not self._thread.is_alive():
+            return
+        t0 = time.perf_counter()
+        logging.info("Setting interrupt")
+        self.pipe._interrupt = True
+        logging.info("Waiting to join")
+        self._thread.join()
+        logging.info(f"Joined after {time.perf_counter() - t0} seconds.")
+
+    def _reset(self):
+        self.pipe._interrupt = False
+        self._thread = None
+        self._result = None
+        self._exception = None
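
The core pattern of this diff: run the blocking pipeline call in a worker thread, poll that thread from the async request handler, and cancel by setting a flag that the pipeline's step loop checks between denoising steps (here the private pipe._interrupt attribute; whether a given diffusers release honors that flag inside its loop is pipeline- and version-dependent, which may relate to the caveat in the commit message). The time.sleep(15) at the top of _generate_image looks like a temporary delay for exercising the cancellation path and would slow every request if left in. Below is a minimal, self-contained sketch of the same pattern; FakePipeline, generate_with_timeout, and all timing values are illustrative stand-ins, not part of the commit or the diffusers API:

import threading
import time


class FakePipeline:
    # Stand-in for FluxPipeline: a step loop that checks a cooperative
    # interrupt flag, mirroring the pipe._interrupt mechanism in the diff.
    def __init__(self):
        self._interrupt = False

    def __call__(self, num_inference_steps):
        for step in range(num_inference_steps):
            if self._interrupt:
                raise RuntimeError(f"interrupted at step {step}")
            time.sleep(0.1)  # one "denoising step"
        return "image"


def generate_with_timeout(pipe, steps, timeout_sec):
    # Run the blocking call in a worker thread and poll it, like the
    # while-loop in predict(); lists stand in for the instance attributes.
    result, error = [], []

    def worker():
        try:
            result.append(pipe(num_inference_steps=steps))
        except Exception as e:
            error.append(e)

    pipe._interrupt = False
    thread = threading.Thread(target=worker)
    start = time.perf_counter()
    thread.start()

    while thread.is_alive():
        if time.perf_counter() - start > timeout_sec:
            pipe._interrupt = True  # ask the step loop to stop...
            thread.join()           # ...and wait for it to notice
            raise TimeoutError("generation cancelled")
        time.sleep(0.05)

    if error:
        raise error[0]
    return result[0]


if __name__ == "__main__":
    pipe = FakePipeline()
    print(generate_with_timeout(pipe, steps=5, timeout_sec=5.0))  # completes
    try:
        generate_with_timeout(pipe, steps=100, timeout_sec=0.5)   # cancelled
    except TimeoutError as e:
        print(e)

Two caveats visible in the diff itself: the polling loop calls time.sleep(1.0) inside an async def, which blocks the server's event loop between checks (await asyncio.sleep(1.0) would be the non-blocking equivalent), and keeping _thread/_result/_exception on self means concurrent requests would overwrite each other's state, so the pattern assumes one in-flight request per model instance.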
