From 749331cffb3df2f684eb93e42f06ccc58082d14e Mon Sep 17 00:00:00 2001 From: Vikramjeet Date: Fri, 25 Oct 2024 18:00:22 +0530 Subject: [PATCH] asynchronous runpod handlers --- .../image-to-video/run_image-to-video.py | 162 ++++++++----- serverless/inpainting/run_inpainting.py | 187 ++++++++++----- serverless/outpainting/run_outpainting.py | 198 ++++++++-------- serverless/text-to-image/run_text-to-image.py | 213 +++++++++++------- 4 files changed, 470 insertions(+), 290 deletions(-) diff --git a/serverless/image-to-video/run_image-to-video.py b/serverless/image-to-video/run_image-to-video.py index 33a3553..668d2e1 100644 --- a/serverless/image-to-video/run_image-to-video.py +++ b/serverless/image-to-video/run_image-to-video.py @@ -3,7 +3,8 @@ import runpod import tempfile import time -from typing import Dict, Any, List, Union, Tuple +import asyncio +from typing import Dict, Any, List, Union, Tuple, AsyncGenerator from PIL import Image import base64 from pydantic import BaseModel, Field @@ -11,6 +12,9 @@ from scripts.api_utils import mp4_to_s3_json from scripts.image_to_video import ImageToVideoPipeline +# Global pipeline instance +global_pipeline = None + class ImageToVideoRequest(BaseModel): """ Pydantic model representing a request for image-to-video generation. @@ -25,26 +29,26 @@ class ImageToVideoRequest(BaseModel): use_dynamic_cfg: bool = Field(True, description="Use dynamic CFG") fps: int = Field(30, description="Frames per second for the output video") -device = "cuda" -pipeline = ImageToVideoPipeline(device=device) +async def initialize_pipeline(): + """Initialize the pipeline if not already loaded""" + global global_pipeline + if global_pipeline is None: + print("Initializing Image to Video pipeline...") + global_pipeline = ImageToVideoPipeline(device="cuda") + print("Pipeline initialized successfully") + return global_pipeline -def decode_request(request: Dict[str, Any]) -> Dict[str, Any]: +async def decode_request(request: Dict[str, Any]) -> Dict[str, Any]: """ - Decode and validate the incoming video generation request. - - Args: - request: Raw request data containing base64 image and parameters - - Returns: - Dict containing decoded PIL Image and validated parameters - - Raises: - ValueError: If request validation or image decoding fails + Decode and validate the incoming video generation request asynchronously. """ try: video_request = ImageToVideoRequest(**request) - image_data = base64.b64decode(video_request.image) - image = Image.open(io.BytesIO(image_data)).convert("RGB") + # Run decode in thread pool + image_data = await asyncio.to_thread(base64.b64decode, video_request.image) + image = await asyncio.to_thread( + lambda: Image.open(io.BytesIO(image_data)).convert("RGB") + ) return { 'image': image, @@ -53,19 +57,15 @@ def decode_request(request: Dict[str, Any]) -> Dict[str, Any]: except Exception as e: raise ValueError(f"Invalid request: {str(e)}") -def generate_frames(inputs: Dict[str, Any]) -> Tuple[Union[List[Image.Image], Image.Image], float]: +async def generate_frames(inputs: Dict[str, Any], pipeline: ImageToVideoPipeline) -> Tuple[List[Image.Image], float]: """ - Generate video frames using the pipeline. - - Args: - inputs: Dictionary containing input image and generation parameters - - Returns: - Tuple containing generated frames and completion time + Generate video frames using the pipeline asynchronously. """ start_time = time.time() - frames = pipeline.generate( + # Run generation in thread pool + frames = await asyncio.to_thread( + pipeline.generate, prompt=inputs['params']['prompt'], image=inputs['image'], num_frames=inputs['params']['num_frames'], @@ -84,27 +84,23 @@ def generate_frames(inputs: Dict[str, Any]) -> Tuple[Union[List[Image.Image], Im completion_time = time.time() - start_time return frames, completion_time -def create_video_response(frames: List[Image.Image], completion_time: float, fps: int) -> Dict[str, Any]: +async def create_video_response(frames: List[Image.Image], completion_time: float, fps: int) -> Dict[str, Any]: """ - Create video file and generate response with S3 URL. - - Args: - frames: List of generated video frames - completion_time: Time taken for generation - fps: Frames per second for the output video - - Returns: - Dict containing S3 URL, completion time, and video metadata + Create video file and generate response with S3 URL asynchronously. """ - with tempfile.TemporaryDirectory() as temp_dir: - temp_video_path = os.path.join(temp_dir, "generated_video.mp4") - export_to_video(frames, temp_video_path, fps=fps) - - with open(temp_video_path, "rb") as video_file: - s3_response = mp4_to_s3_json( - video_file, - f"generated_video_{int(time.time())}.mp4" - ) + def create_video(): + with tempfile.TemporaryDirectory() as temp_dir: + temp_video_path = os.path.join(temp_dir, "generated_video.mp4") + export_to_video(frames, temp_video_path, fps=fps) + + with open(temp_video_path, "rb") as video_file: + return mp4_to_s3_json( + video_file, + f"generated_video_{int(time.time())}.mp4" + ) + + # Run video creation and upload in thread pool + s3_response = await asyncio.to_thread(create_video) return { "result": s3_response, @@ -113,24 +109,76 @@ def create_video_response(frames: List[Image.Image], completion_time: float, fps "fps": fps } -def handler(job: Dict[str, Any]) -> Dict[str, Any]: +async def async_generator_handler(job: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]: """ - RunPod handler function for processing video generation requests. - - Args: - job: RunPod job dictionary containing input data - - Returns: - Dict containing either the processed video result or error information + Async generator handler for RunPod with progress updates. """ try: - inputs = decode_request(job['input']) - frames, completion_time = generate_frames(inputs) - return create_video_response(frames, completion_time, inputs['params']['fps']) + # Initial status + yield {"status": "starting", "message": "Initializing video generation process"} + + # Initialize pipeline + pipeline = await initialize_pipeline() + yield {"status": "processing", "message": "Pipeline loaded successfully"} + + # Decode request + try: + inputs = await decode_request(job['input']) + yield {"status": "processing", "message": "Request decoded successfully"} + except Exception as e: + yield {"status": "error", "message": f"Error decoding request: {str(e)}"} + return + + # Generate frames with progress updates + try: + yield {"status": "processing", "message": "Generating video frames"} + frames, completion_time = await generate_frames(inputs, pipeline) + yield {"status": "processing", "message": f"Generated {len(frames)} frames successfully"} + except Exception as e: + yield {"status": "error", "message": f"Error generating frames: {str(e)}"} + return + + # Create and upload video + try: + yield {"status": "processing", "message": "Creating and uploading video"} + response = await create_video_response( + frames, + completion_time, + inputs['params']['fps'] + ) + yield {"status": "processing", "message": "Video uploaded successfully"} + except Exception as e: + yield {"status": "error", "message": f"Error creating video: {str(e)}"} + return + + # Final response + yield { + "status": "completed", + "output": response + } + except Exception as e: - return {"error": str(e)} + yield { + "status": "error", + "message": f"Unexpected error: {str(e)}" + } + +def calculate_progress(current_frame: int, total_frames: int) -> dict: + """Calculate progress percentage and create status update.""" + progress = (current_frame / total_frames) * 100 + return { + "status": "processing", + "progress": round(progress, 2), + "message": f"Generating frame {current_frame}/{total_frames}" + } + +# Initialize the pipeline when the service starts +print("Initializing service...") +asyncio.get_event_loop().run_until_complete(initialize_pipeline()) +print("Service initialization complete") if __name__ == "__main__": runpod.serverless.start({ - "handler": handler + "handler": async_generator_handler, + "return_aggregate_stream": True }) \ No newline at end of file diff --git a/serverless/inpainting/run_inpainting.py b/serverless/inpainting/run_inpainting.py index 23bf8d3..c59e5d1 100644 --- a/serverless/inpainting/run_inpainting.py +++ b/serverless/inpainting/run_inpainting.py @@ -2,12 +2,14 @@ import base64 import time import logging -from typing import Dict, Any, Tuple, Optional +import asyncio +from typing import Dict, Any, Tuple, Optional, AsyncGenerator from pydantic import BaseModel, Field from PIL import Image from scripts.s3_manager import S3ManagerService from scripts.flux_inference import FluxInpaintingInference from config_settings import settings +import runpod logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -23,29 +25,50 @@ class InpaintingRequest(BaseModel): input_image: str = Field(..., description="Base64 encoded input image") mask_image: str = Field(..., description="Base64 encoded mask image") -device = "cuda" -flux_inpainter = FluxInpaintingInference() -s3_manager = S3ManagerService() +# Global instances +global_inpainter = None +global_s3_manager = None -def decode_request(request: Dict[str, Any]) -> Dict[str, Any]: - """ - Decode and validate the incoming inpainting request. +async def initialize_services(): + """Initialize global services if not already initialized""" + global global_inpainter, global_s3_manager - Args: - request: Raw request data containing images and parameters + if global_inpainter is None: + logger.info("Initializing Flux Inpainting model...") + global_inpainter = FluxInpaintingInference() + logger.info("Flux Inpainting model initialized successfully") - Returns: - Dict containing decoded images and validated parameters + if global_s3_manager is None: + logger.info("Initializing S3 manager...") + global_s3_manager = S3ManagerService() + logger.info("S3 manager initialized successfully") - Raises: - Exception: If request validation or image decoding fails + return global_inpainter, global_s3_manager + +async def decode_request(request: Dict[str, Any]) -> Dict[str, Any]: + """ + Decode and validate the incoming inpainting request asynchronously. """ try: + logger.info("Decoding inpainting request") inpainting_request = InpaintingRequest(**request) - input_image = Image.open(io.BytesIO(base64.b64decode(inpainting_request.input_image))) - mask_image = Image.open(io.BytesIO(base64.b64decode(inpainting_request.mask_image))) + # Run image decoding in thread pool + input_image_data = await asyncio.to_thread( + base64.b64decode, inpainting_request.input_image + ) + mask_image_data = await asyncio.to_thread( + base64.b64decode, inpainting_request.mask_image + ) + input_image = await asyncio.to_thread( + lambda: Image.open(io.BytesIO(input_image_data)) + ) + mask_image = await asyncio.to_thread( + lambda: Image.open(io.BytesIO(mask_image_data)) + ) + + logger.info("Request decoded successfully") return { "prompt": inpainting_request.prompt, "input_image": input_image, @@ -58,19 +81,15 @@ def decode_request(request: Dict[str, Any]) -> Dict[str, Any]: logger.error(f"Error in decode_request: {e}") raise -def generate_inpainting(inputs: Dict[str, Any]) -> Tuple[Optional[Image.Image], Dict[str, Any]]: +async def generate_inpainting(inputs: Dict[str, Any], inpainter: FluxInpaintingInference) -> Tuple[Optional[Image.Image], Dict[str, Any]]: """ - Perform inpainting operation using the Flux model. - - Args: - inputs: Dictionary containing input images and parameters - - Returns: - Tuple containing the inpainted image and metadata + Perform inpainting operation using the Flux model asynchronously. """ start_time = time.time() - result_image = flux_inpainter.generate_inpainting( + # Run inpainting in thread pool + result_image = await asyncio.to_thread( + inpainter.generate_inpainting, input_image=inputs["input_image"], mask_image=inputs["mask_image"], prompt=inputs["prompt"], @@ -88,27 +107,32 @@ def generate_inpainting(inputs: Dict[str, Any]) -> Tuple[Optional[Image.Image], return result_image, output -def upload_result(image: Image.Image, metadata: Dict[str, Any]) -> Dict[str, Any]: +async def upload_result(image: Image.Image, metadata: Dict[str, Any], s3_manager: S3ManagerService) -> Dict[str, Any]: """ - Upload the generated image to S3 and prepare the response. - - Args: - image: Generated inpainting image - metadata: Dictionary containing generation metadata - - Returns: - Dict containing S3 URL and generation metadata - - Raises: - Exception: If image upload or URL generation fails + Upload the generated image to S3 and prepare the response asynchronously. """ try: + # Prepare image buffer buffered = io.BytesIO() - image.save(buffered, format="PNG") + await asyncio.to_thread(image.save, buffered, format="PNG") + buffered.seek(0) - unique_filename = s3_manager.generate_unique_file_name("result.png") - s3_manager.upload_file(io.BytesIO(buffered.getvalue()), unique_filename) - signed_url = s3_manager.generate_signed_url(unique_filename, exp=43200) + # Generate unique filename and upload + unique_filename = await asyncio.to_thread( + s3_manager.generate_unique_file_name, "result.png" + ) + await asyncio.to_thread( + s3_manager.upload_file, + io.BytesIO(buffered.getvalue()), + unique_filename + ) + + # Generate signed URL + signed_url = await asyncio.to_thread( + s3_manager.generate_signed_url, + unique_filename, + exp=43200 + ) return { "result_url": signed_url, @@ -120,31 +144,76 @@ def upload_result(image: Image.Image, metadata: Dict[str, Any]) -> Dict[str, Any logger.error(f"Error in upload_result: {e}") raise -def handler(job: Dict[str, Any]) -> Dict[str, Any]: +async def async_generator_handler(job: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]: """ - RunPod handler function for processing inpainting requests. - - Args: - job: RunPod job dictionary containing input data - - Returns: - Dict containing either the processed results or error information + Async generator handler for RunPod with progress updates. """ try: - inputs = decode_request(job['input']) - result_image, metadata = generate_inpainting(inputs) - - if result_image is None: - return {"error": "Failed to generate image"} + # Initial status + yield {"status": "starting", "message": "Initializing inpainting process"} + + # Initialize services + inpainter, s3_manager = await initialize_services() + yield {"status": "processing", "message": "Services initialized successfully"} + + # Decode request + try: + inputs = await decode_request(job['input']) + yield {"status": "processing", "message": "Request decoded successfully"} + except Exception as e: + logger.error(f"Request decode error: {e}") + yield {"status": "error", "message": f"Error decoding request: {str(e)}"} + return + + # Generate inpainting + try: + yield {"status": "processing", "message": "Starting inpainting generation"} + result_image, metadata = await generate_inpainting(inputs, inpainter) - return upload_result(result_image, metadata) - + if result_image is None: + yield {"status": "error", "message": "Failed to generate image"} + return + + yield { + "status": "processing", + "message": "Inpainting generated successfully", + "completion": f"{metadata['time_taken']:.2f}s" + } + except Exception as e: + logger.error(f"Inpainting error: {e}") + yield {"status": "error", "message": f"Error during inpainting: {str(e)}"} + return + + # Upload result + try: + yield {"status": "processing", "message": "Uploading result"} + response = await upload_result(result_image, metadata, s3_manager) + yield {"status": "processing", "message": "Result uploaded successfully"} + except Exception as e: + logger.error(f"Upload error: {e}") + yield {"status": "error", "message": f"Error uploading result: {str(e)}"} + return + + # Final response + yield { + "status": "completed", + "output": response + } + except Exception as e: - logger.error(f"Error in handler: {e}") - return {"error": str(e)} + logger.error(f"Unexpected error: {e}") + yield { + "status": "error", + "message": f"Unexpected error: {str(e)}" + } + +# Initialize services when the service starts +print("Initializing service...") +asyncio.get_event_loop().run_until_complete(initialize_services()) +print("Service initialization complete") if __name__ == "__main__": - import runpod runpod.serverless.start({ - "handler": handler + "handler": async_generator_handler, + "return_aggregate_stream": True }) \ No newline at end of file diff --git a/serverless/outpainting/run_outpainting.py b/serverless/outpainting/run_outpainting.py index 1a89d89..bef14a0 100644 --- a/serverless/outpainting/run_outpainting.py +++ b/serverless/outpainting/run_outpainting.py @@ -2,12 +2,16 @@ import base64 import time import runpod -from typing import Dict, Any, Tuple +import asyncio +from typing import Dict, Any, AsyncGenerator from PIL import Image from pydantic import BaseModel, Field from scripts.outpainting import Outpainter from scripts.api_utils import pil_to_s3_json +# Global cache for the Outpainter instance +global_outpainter = None + class OutpaintingRequest(BaseModel): """ Pydantic model representing a request for outpainting inference. @@ -26,102 +30,120 @@ class OutpaintingRequest(BaseModel): overlap_top: bool = Field(True, description="Apply overlap on top side") overlap_bottom: bool = Field(True, description="Apply overlap on bottom side") -device = "cuda" -outpainter = Outpainter() +async def initialize_model(): + """Initialize the model if not already loaded""" + global global_outpainter + if global_outpainter is None: + print("Initializing Outpainter model...") + global_outpainter = Outpainter() + print("Model initialized successfully") + return global_outpainter -def decode_request(request: Dict[str, Any]) -> Dict[str, Any]: +async def async_generator_handler(job: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]: """ - Decode and validate the incoming request. - - Args: - request: Raw request data containing image and parameters - - Returns: - Dict containing decoded PIL Image and validated parameters - - Raises: - ValueError: If request validation fails or image decoding fails + Async generator handler for RunPod. + Yields status updates and progress during the outpainting process. """ try: - outpainting_request = OutpaintingRequest(**request) - image_data = base64.b64decode(outpainting_request.image) - image = Image.open(io.BytesIO(image_data)).convert("RGBA") - - return { + # Initial status + yield {"status": "starting", "message": "Initializing outpainting process"} + + # Initialize model + outpainter = await initialize_model() + yield {"status": "processing", "message": "Model loaded successfully"} + + # Decode request + try: + request = OutpaintingRequest(**job['input']) + image_data = base64.b64decode(request.image) + image = Image.open(io.BytesIO(image_data)).convert("RGBA") + yield {"status": "processing", "message": "Request decoded successfully"} + except Exception as e: + yield {"status": "error", "message": f"Error decoding request: {str(e)}"} + return + + # Start timing + start_time = time.time() + + # Prepare outpainting parameters + inputs = { 'image': image, - 'params': outpainting_request.model_dump() + 'params': { + 'width': request.width, + 'height': request.height, + 'overlap_percentage': request.overlap_percentage, + 'num_inference_steps': request.num_inference_steps, + 'resize_option': request.resize_option, + 'custom_resize_percentage': request.custom_resize_percentage, + 'prompt_input': request.prompt_input, + 'alignment': request.alignment, + 'overlap_left': request.overlap_left, + 'overlap_right': request.overlap_right, + 'overlap_top': request.overlap_top, + 'overlap_bottom': request.overlap_bottom + } } - except Exception as e: - raise ValueError(f"Invalid request: {str(e)}") -def perform_outpainting(inputs: Dict[str, Any]) -> Tuple[Image.Image, float]: - """ - Perform outpainting operation on the input image. - - Args: - inputs: Dictionary containing image and outpainting parameters - - Returns: - Tuple containing the outpainted image and processing time in seconds - """ - start_time = time.time() - - result = outpainter.outpaint( - inputs['image'], - inputs['params']['width'], - inputs['params']['height'], - inputs['params']['overlap_percentage'], - inputs['params']['num_inference_steps'], - inputs['params']['resize_option'], - inputs['params']['custom_resize_percentage'], - inputs['params']['prompt_input'], - inputs['params']['alignment'], - inputs['params']['overlap_left'], - inputs['params']['overlap_right'], - inputs['params']['overlap_top'], - inputs['params']['overlap_bottom'] - ) - - completion_time = time.time() - start_time - return result, completion_time - -def format_response(result: Image.Image, completion_time: float) -> Dict[str, Any]: - """ - Format the outpainting result for API response. - - Args: - result: Outpainted PIL Image - completion_time: Processing time in seconds - - Returns: - Dict containing S3 URL, completion time, and image resolution - """ - img_str = pil_to_s3_json(result, "outpainting_image") - - return { - "result": img_str, - "completion_time": round(completion_time, 2), - "image_resolution": f"{result.width}x{result.height}" - } - -def handler(job: Dict[str, Any]) -> Dict[str, Any]: - """ - RunPod handler function for processing outpainting requests. - - Args: - job: RunPod job dictionary containing input data - - Returns: - Dict containing either the processed results or error information - """ - try: - inputs = decode_request(job['input']) - result, completion_time = perform_outpainting(inputs) - return format_response(result, completion_time) + yield {"status": "processing", "message": "Starting outpainting process"} + + # Perform outpainting + try: + result = outpainter.outpaint( + inputs['image'], + inputs['params']['width'], + inputs['params']['height'], + inputs['params']['overlap_percentage'], + inputs['params']['num_inference_steps'], + inputs['params']['resize_option'], + inputs['params']['custom_resize_percentage'], + inputs['params']['prompt_input'], + inputs['params']['alignment'], + inputs['params']['overlap_left'], + inputs['params']['overlap_right'], + inputs['params']['overlap_top'], + inputs['params']['overlap_bottom'] + ) + yield {"status": "processing", "message": "Outpainting completed successfully"} + except Exception as e: + yield {"status": "error", "message": f"Error during outpainting: {str(e)}"} + return + + # Upload to S3 and get URL + try: + img_str = pil_to_s3_json(result, "outpainting_image") + yield {"status": "processing", "message": "Image uploaded successfully"} + except Exception as e: + yield {"status": "error", "message": f"Error uploading image: {str(e)}"} + return + + # Calculate completion time + completion_time = time.time() - start_time + + # Final response + final_response = { + "status": "completed", + "output": { + "result": img_str, + "completion_time": round(completion_time, 2), + "image_resolution": f"{result.width}x{result.height}" + } + } + + yield final_response + except Exception as e: - return {"error": str(e)} + yield { + "status": "error", + "message": f"Unexpected error: {str(e)}" + } + +# Initialize the model when the service starts +print("Initializing service...") +asyncio.get_event_loop().run_until_complete(initialize_model()) +print("Service initialization complete") if __name__ == "__main__": runpod.serverless.start({ - "handler": handler + "handler": async_generator_handler, + "return_aggregate_stream": True }) \ No newline at end of file diff --git a/serverless/text-to-image/run_text-to-image.py b/serverless/text-to-image/run_text-to-image.py index a9ede49..a76cf5e 100644 --- a/serverless/text-to-image/run_text-to-image.py +++ b/serverless/text-to-image/run_text-to-image.py @@ -1,55 +1,63 @@ import runpod import torch -from diffusers import DiffusionPipeline -from typing import Dict, Any, List +import asyncio +import logging +from typing import Dict, Any, List, AsyncGenerator from PIL import Image +from diffusers import DiffusionPipeline from config_settings import settings from configs.tti_settings import tti_settings from scripts.api_utils import pil_to_b64_json, pil_to_s3_json +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Global pipeline instance +global_pipeline = None device = "cuda" if torch.cuda.is_available() else "cpu" -def setup_pipeline(): +async def initialize_pipeline(): """ - Set up and optimize the SDXL pipeline with LoRA for inference. - - Returns: - DiffusionPipeline: Optimized SDXL pipeline ready for inference + Initialize and optimize the SDXL pipeline with LoRA. """ - sdxl_pipeline = DiffusionPipeline.from_pretrained( - tti_settings.MODEL_NAME, - torch_dtype=torch.bfloat16 - ).to(device) - - sdxl_pipeline.load_lora_weights(tti_settings.ADAPTER_NAME) - sdxl_pipeline.fuse_lora() + global global_pipeline - sdxl_pipeline.unet.to(memory_format=torch.channels_last) - if tti_settings.ENABLE_COMPILE: - sdxl_pipeline.unet = torch.compile( - sdxl_pipeline.unet, - mode="max-autotune" - ) - sdxl_pipeline.vae.decode = torch.compile( - sdxl_pipeline.vae.decode, - mode="max-autotune" + if global_pipeline is None: + logger.info("Initializing SDXL pipeline...") + + # Run model loading in thread pool + global_pipeline = await asyncio.to_thread( + DiffusionPipeline.from_pretrained, + tti_settings.MODEL_NAME, + torch_dtype=torch.bfloat16 ) - sdxl_pipeline.fuse_qkv_projections() - return sdxl_pipeline - -pipeline = setup_pipeline() + global_pipeline.to(device) + + logger.info("Loading LoRA weights...") + await asyncio.to_thread(global_pipeline.load_lora_weights, tti_settings.ADAPTER_NAME) + await asyncio.to_thread(global_pipeline.fuse_lora) + + logger.info("Optimizing pipeline...") + global_pipeline.unet.to(memory_format=torch.channels_last) + if tti_settings.ENABLE_COMPILE: + global_pipeline.unet = await asyncio.to_thread( + torch.compile, + global_pipeline.unet, + mode="max-autotune" + ) + global_pipeline.vae.decode = await asyncio.to_thread( + torch.compile, + global_pipeline.vae.decode, + mode="max-autotune" + ) + await asyncio.to_thread(global_pipeline.fuse_qkv_projections) + logger.info("Pipeline initialization complete") + + return global_pipeline def decode_request(request: Dict[str, Any]) -> Dict[str, Any]: - """ - Decode and validate the incoming request. - - Args: - request: Raw request data containing generation parameters - - Returns: - Dict[str, Any]: Processed request parameters including prompt, negative_prompt, - num_images, num_inference_steps, guidance_scale, and mode - """ + """Decode and validate the incoming request.""" return { "prompt": request["prompt"], "negative_prompt": request.get("negative_prompt", ""), @@ -59,74 +67,107 @@ def decode_request(request: Dict[str, Any]) -> Dict[str, Any]: "mode": request.get("mode", "s3_json") } -def generate_images(params: Dict[str, Any]) -> List[Dict[str, Any]]: - """ - Generate images using the SDXL pipeline. - - Args: - params: Generation parameters including prompt, negative_prompt, num_images, - num_inference_steps, and guidance_scale - - Returns: - List[Dict[str, Any]]: List of dictionaries containing generated images and their modes - """ - images = pipeline( +async def generate_images(params: Dict[str, Any], pipeline: DiffusionPipeline) -> List[Dict[str, Any]]: + """Generate images using the SDXL pipeline asynchronously.""" + images = await asyncio.to_thread( + pipeline, prompt=params["prompt"], negative_prompt=params["negative_prompt"], num_images_per_prompt=params["num_images"], num_inference_steps=params["num_inference_steps"], guidance_scale=params["guidance_scale"], - ).images - - return [{"image": img, "mode": params["mode"]} for img in images] - -def encode_response(output: Dict[str, Any]) -> Dict[str, Any]: - """ - Encode the generated image based on the specified mode. + ) - Args: - output: Dictionary containing image and mode - - Returns: - Dict[str, Any]: Encoded response either as S3 URL or base64 string - - Raises: - ValueError: If the specified mode is not supported - """ + return [{"image": img, "mode": params["mode"]} for img in images.images] + +async def encode_response(output: Dict[str, Any]) -> Dict[str, Any]: + """Encode the generated image asynchronously.""" mode = output["mode"] image = output["image"] if mode == "s3_json": - return pil_to_s3_json(image, "sdxl_image") + return await asyncio.to_thread(pil_to_s3_json, image, "sdxl_image") elif mode == "b64_json": - return pil_to_b64_json(image) + return await asyncio.to_thread(pil_to_b64_json, image) else: raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.") -def handler(job: Dict[str, Any]) -> Dict[str, Any]: +async def async_generator_handler(job: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]: """ - RunPod handler function for processing image generation requests. - - Args: - job: RunPod job dictionary containing input parameters - - Returns: - Dict[str, Any]: Generated images in requested format or error message - Returns single result for one image or list of results for multiple images + Async generator handler for RunPod with progress updates. """ try: - params = decode_request(job['input']) - outputs = generate_images(params) - results = [encode_response(output) for output in outputs] - - if len(results) == 1: - return results[0] - return {"results": results} - + # Initial status + yield {"status": "starting", "message": "Initializing image generation process"} + + # Initialize pipeline + pipeline = await initialize_pipeline() + yield {"status": "processing", "message": "Pipeline loaded successfully"} + + # Decode request + try: + params = decode_request(job['input']) + yield { + "status": "processing", + "message": "Request decoded successfully", + "params": { + "prompt": params["prompt"], + "num_images": params["num_images"], + "steps": params["num_inference_steps"] + } + } + except Exception as e: + logger.error(f"Request decode error: {e}") + yield {"status": "error", "message": f"Error decoding request: {str(e)}"} + return + + # Generate images + try: + yield {"status": "processing", "message": "Generating images"} + outputs = await generate_images(params, pipeline) + yield {"status": "processing", "message": f"Generated {len(outputs)} images successfully"} + except Exception as e: + logger.error(f"Generation error: {e}") + yield {"status": "error", "message": f"Error generating images: {str(e)}"} + return + + # Encode responses + try: + yield {"status": "processing", "message": "Encoding and uploading images"} + results = [] + for idx, output in enumerate(outputs, 1): + result = await encode_response(output) + results.append(result) + yield { + "status": "processing", + "message": f"Processed image {idx}/{len(outputs)}" + } + except Exception as e: + logger.error(f"Encoding error: {e}") + yield {"status": "error", "message": f"Error encoding images: {str(e)}"} + return + + # Final response + final_response = results[0] if len(results) == 1 else {"results": results} + yield { + "status": "completed", + "output": final_response + } + except Exception as e: - return {"error": str(e)} + logger.error(f"Unexpected error: {e}") + yield { + "status": "error", + "message": f"Unexpected error: {str(e)}" + } + +# Initialize pipeline at startup +logger.info("Initializing service...") +asyncio.get_event_loop().run_until_complete(initialize_pipeline()) +logger.info("Service initialization complete") if __name__ == "__main__": runpod.serverless.start({ - "handler": handler + "handler": async_generator_handler, + "return_aggregate_stream": True }) \ No newline at end of file