From ed30992565c10bdf7d34ca71b76a6652d3d79c44 Mon Sep 17 00:00:00 2001
From: Vikramjeet
Date: Fri, 25 Oct 2024 19:02:51 +0530
Subject: [PATCH] Refactor OutpaintingService for improved performance and
 readability

---
 serverless/outpainting/run_outpainting.py | 179 +++++++++++++---------
 1 file changed, 108 insertions(+), 71 deletions(-)

diff --git a/serverless/outpainting/run_outpainting.py b/serverless/outpainting/run_outpainting.py
index 3a266fb..9362751 100644
--- a/serverless/outpainting/run_outpainting.py
+++ b/serverless/outpainting/run_outpainting.py
@@ -11,6 +11,9 @@
 class OutpaintingRequest(BaseModel):
     """
     Pydantic model representing a request for outpainting inference.
+
+    This model defines the structure and validation rules for incoming API requests.
+    All fields are required unless otherwise specified.
     """
     image: str = Field(..., description="Base64 encoded input image")
     width: int = Field(1024, description="Target width")
@@ -26,82 +29,115 @@ class OutpaintingRequest(BaseModel):
     overlap_top: bool = Field(True, description="Apply overlap on top side")
     overlap_bottom: bool = Field(True, description="Apply overlap on bottom side")
 
-async def decode_request(request: Dict[str, Any]) -> Dict[str, Any]:
+class OutpaintingService:
     """
-    Decode and validate the incoming request asynchronously.
+    Service class for handling outpainting operations.
+    Based on LitAPI implementation but adapted for RunPod.
     """
-    try:
-        outpainting_request = OutpaintingRequest(**request)
-        image_data = await asyncio.to_thread(
-            base64.b64decode, outpainting_request.image
-        )
-        image = await asyncio.to_thread(
-            lambda: Image.open(io.BytesIO(image_data)).convert("RGBA")
+
+    def __init__(self, device: str = "cuda"):
+        """Initialize the outpainting service."""
+        self.device = device
+        self.outpainter = Outpainter()
+
+    async def decode_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Decode the incoming request and prepare inputs for the model.
+
+        Args:
+            request: The raw request data.
+
+        Returns:
+            Dict containing decoded image and request parameters.
+
+        Raises:
+            ValueError: If request is invalid or cannot be processed.
+        """
+        try:
+            outpainting_request = OutpaintingRequest(**request)
+            # Run decode in thread pool
+            image_data = await asyncio.to_thread(
+                base64.b64decode, outpainting_request.image
+            )
+            image = await asyncio.to_thread(
+                lambda: Image.open(io.BytesIO(image_data)).convert("RGBA")
+            )
+
+            return {
+                'image': image,
+                'params': outpainting_request.model_dump()
+            }
+        except Exception as e:
+            raise ValueError(f"Invalid request: {str(e)}")
+
+    async def predict(self, inputs: Dict[str, Any]) -> Tuple[Image.Image, float]:
+        """
+        Run predictions on the input.
+
+        Args:
+            inputs: Dict containing image and outpainting parameters.
+
+        Returns:
+            Tuple containing the resulting image and completion time.
+        """
+        image = inputs['image']
+        params = inputs['params']
+
+        start_time = time.time()
+
+        # Run outpainting in thread pool
+        result = await asyncio.to_thread(
+            self.outpainter.outpaint,
+            image,
+            params['width'],
+            params['height'],
+            params['overlap_percentage'],
+            params['num_inference_steps'],
+            params['resize_option'],
+            params['custom_resize_percentage'],
+            params['prompt_input'],
+            params['alignment'],
+            params['overlap_left'],
+            params['overlap_right'],
+            params['overlap_top'],
+            params['overlap_bottom']
         )
+
+        completion_time = time.time() - start_time
+        return result, completion_time
+
+    async def encode_response(self, output: Tuple[Image.Image, float]) -> Dict[str, Any]:
+        """
+        Encode the model output into a response payload.
+
+        Args:
+            output: Tuple containing outpainted image and completion time.
+
+        Returns:
+            Dict containing S3 URL and metadata.
+        """
+        image, completion_time = output
+        # Run S3 upload in thread pool
+        img_str = await asyncio.to_thread(pil_to_s3_json, image, "outpainting_image")
 
         return {
-            'image': image,
-            'params': outpainting_request.model_dump()
+            "result": img_str,
+            "completion_time": round(completion_time, 2),
+            "image_resolution": f"{image.width}x{image.height}"
         }
-    except Exception as e:
-        raise ValueError(f"Invalid request: {str(e)}")
-
-async def generate_outpainting(inputs: Dict[str, Any]) -> Tuple[Image.Image, float]:
-    """
-    Perform outpainting operation asynchronously.
-    """
-    start_time = time.time()
-
-    # Initialize Outpainter for each request
-    outpainter = Outpainter()
-
-    result = await asyncio.to_thread(
-        outpainter.outpaint,
-        inputs['image'],
-        inputs['params']['width'],
-        inputs['params']['height'],
-        inputs['params']['overlap_percentage'],
-        inputs['params']['num_inference_steps'],
-        inputs['params']['resize_option'],
-        inputs['params']['custom_resize_percentage'],
-        inputs['params']['prompt_input'],
-        inputs['params']['alignment'],
-        inputs['params']['overlap_left'],
-        inputs['params']['overlap_right'],
-        inputs['params']['overlap_top'],
-        inputs['params']['overlap_bottom']
-    )
-
-    completion_time = time.time() - start_time
-    return result, completion_time
-
-async def format_response(result: Image.Image, completion_time: float) -> Dict[str, Any]:
-    """
-    Format the outpainting result for API response asynchronously.
-    """
-    img_str = await asyncio.to_thread(
-        pil_to_s3_json,
-        result,
-        "outpainting_image"
-    )
-
-    return {
-        "result": img_str,
-        "completion_time": round(completion_time, 2),
-        "image_resolution": f"{result.width}x{result.height}"
-    }
 
 async def async_generator_handler(job: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]:
     """
     Async generator handler for RunPod with progress updates.
     """
     try:
-        # Initial status
-        yield {"status": "starting", "message": "Starting outpainting process"}
+        # Create service instance
+        service = OutpaintingService(device="cuda")
+        yield {"status": "starting", "message": "Service initialized"}
 
         # Decode request
         try:
-            inputs = await decode_request(job['input'])
+            inputs = await service.decode_request(job['input'])
             yield {
                 "status": "processing",
                 "message": "Request decoded successfully",
@@ -111,26 +147,26 @@ async def async_generator_handler(job: Dict[str, Any]) -> AsyncGenerator[Dict[st
             yield {"status": "error", "message": f"Error decoding request: {str(e)}"}
             return
 
-        # Generate outpainting
+        # Generate prediction
         try:
-            yield {"status": "processing", "message": "Initializing outpainting model"}
-            result, completion_time = await generate_outpainting(inputs)
+            yield {"status": "processing", "message": "Starting outpainting"}
+            result = await service.predict(inputs)
             yield {
-                "status": "processing",
+                "status": "processing", 
                 "message": "Outpainting completed",
-                "completion_time": f"{completion_time:.2f}s"
+                "completion_time": f"{result[1]:.2f}s"
             }
         except Exception as e:
            yield {"status": "error", "message": f"Error during outpainting: {str(e)}"}
            return
 
-        # Format and upload result
+        # Encode response
         try:
-            yield {"status": "processing", "message": "Uploading result"}
-            response = await format_response(result, completion_time)
-            yield {"status": "processing", "message": "Result uploaded successfully"}
+            yield {"status": "processing", "message": "Encoding result"}
+            response = await service.encode_response(result)
+            yield {"status": "processing", "message": "Result encoded successfully"}
         except Exception as e:
-            yield {"status": "error", "message": f"Error uploading result: {str(e)}"}
+            yield {"status": "error", "message": f"Error encoding result: {str(e)}"}
             return
 
         # Final response
@@ -146,6 +182,7 @@ async def async_generator_handler(job: Dict[str, Any]) -> AsyncGenerator[Dict[st
         }
 
 if __name__ == "__main__":
+    import runpod
     runpod.serverless.start({
         "handler": async_generator_handler,
         "return_aggregate_stream": True