From b5622499e51c8b30652f644598c8b8d708baf6d5 Mon Sep 17 00:00:00 2001
From: "marius.baseten"
Date: Tue, 24 Sep 2024 14:21:03 -0700
Subject: [PATCH] Cancel pipeline. Doesn't work for request cancellations yet.

---
 flux/schnell/config.yaml    |  3 +-
 flux/schnell/model/model.py | 94 ++++++++++++++++++++++++++++++++------
 2 files changed, 83 insertions(+), 14 deletions(-)

diff --git a/flux/schnell/config.yaml b/flux/schnell/config.yaml
index fb77fe90..80fdacf1 100644
--- a/flux/schnell/config.yaml
+++ b/flux/schnell/config.yaml
@@ -10,8 +10,9 @@ requirements:
   - accelerate
   - sentencepiece
   - protobuf
+  - fastapi
 resources:
-  accelerator: H100_40GB
+  accelerator: A100
   use_gpu: true
 secrets:
   hf_access_token: null
diff --git a/flux/schnell/model/model.py b/flux/schnell/model/model.py
index 8feaeb6a..3b3dd25f 100644
--- a/flux/schnell/model/model.py
+++ b/flux/schnell/model/model.py
@@ -1,4 +1,8 @@
+import asyncio
+import threading
+
 import base64
+import fastapi
 import logging
 import math
 import random
@@ -21,6 +25,9 @@ def __init__(self, **kwargs):
         self.model_name = kwargs["config"]["model_metadata"]["repo_id"]
         self.hf_access_token = self._secrets["hf_access_token"]
         self.pipe = None
+        self._thread = None
+        self._exception = None
+        self._result = None
 
     def load(self):
         self.pipe = FluxPipeline.from_pretrained(
@@ -41,11 +48,35 @@ def convert_to_b64(self, image: Image) -> str:
         img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
         return img_b64
 
-    def predict(self, model_input):
+    def _generate_image(self, prompt, prompt2, guidance_scale, max_sequence_length,
+                        num_inference_steps, width, height, generator):
+        time.sleep(15)  # Debug delay, presumably to exercise cancellation; remove for real use.
+        try:
+            image = self.pipe(
+                prompt=prompt,
+                prompt_2=prompt2,
+                guidance_scale=guidance_scale,
+                max_sequence_length=max_sequence_length,
+                num_inference_steps=num_inference_steps,
+                width=width,
+                height=height,
+                output_type="pil",
+                generator=generator,
+            ).images[0]
+            self._result = image
+            return
+        except Exception as e:
+            logging.info(f"Image generation was aborted or failed: {e}")
+            self._exception = e
+            return
+
+    async def predict(self, model_input, request: fastapi.Request):
         start = time.perf_counter()
+        timeout_sec = model_input.get("timeout_sec", 60)
         seed = model_input.get("seed")
         prompt = model_input.get("prompt")
         prompt2 = model_input.get("prompt2")
+        logging.info(f"Starting: {prompt}")
         max_sequence_length = model_input.get(
             "max_sequence_length", 256
         )  # 256 is max for FLUX.1-schnell
@@ -78,20 +109,57 @@ def predict(self, model_input):
         prompt2 = " ".join(tokens[: min(len(tokens), max_sequence_length)])
 
         generator = torch.Generator().manual_seed(seed)
-        image = self.pipe(
-            prompt=prompt,
-            prompt_2=prompt2,
-            guidance_scale=guidance_scale,
-            max_sequence_length=max_sequence_length,
-            num_inference_steps=num_inference_steps,
-            width=width,
-            height=height,
-            output_type="pil",
-            generator=generator,
-        ).images[0]
+        logging.info("Starting generation thread.")
+        self._reset()
+        self._thread = threading.Thread(target=self._generate_image, args=(
+            prompt, prompt2, guidance_scale, max_sequence_length, num_inference_steps,
+            width, height, generator))
+        self._thread.start()
+        logging.info("Started thread.")
 
-        b64_results = self.convert_to_b64(image)
+        logging.info("Polling.")
+        while self._thread.is_alive():
+            elapsed_sec = time.perf_counter() - start
+            if await request.is_disconnected():
+                logging.info("Aborting due to client disconnect.")
+                self._abort_thread()
+                raise fastapi.HTTPException(status_code=408, detail="Client disconnected.")
+            elif elapsed_sec > timeout_sec:
logging.info("Aborting due to timeout.") + self._abort_thread() + raise fastapi.HTTPException(status_code=408, detail="Timeout.") + time.sleep(1.0) + if self._result: + logging.info("Result there.") + logging.info(f"Thread alive: {self._thread.is_alive()}") + break + + logging.info(f"Polling done.") + self._abort_thread() + if not self._result: + assert self._exception + raise self._exception + else: + image = self._result + + b64_results = self.convert_to_b64(image) end = time.perf_counter() logging.info(f"Total time taken: {end - start} seconds") return {"data": b64_results} + + def _abort_thread(self): + if not self._thread or not self._thread.is_alive(): + return + t0 = time.perf_counter() + logging.info(f"Setting interrupt") + self.pipe._interrupt = True + logging.info(f"Waiting to join") + self._thread.join() + logging.info(f"Joined after {time.perf_counter() - t0} seconds.") + + def _reset(self): + self.pipe._interrupt = False + self._thread = None + self._result = None + self._exception = None