Skip to content

Commit

Permalink
Refactor outpainting service for improved performance and readability
Browse files Browse the repository at this point in the history
  • Loading branch information
VikramxD committed Oct 25, 2024
1 parent 749331c commit 8a8e640
Showing 1 changed file with 88 additions and 85 deletions.
173 changes: 88 additions & 85 deletions serverless/outpainting/run_outpainting.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import io
import base64
import time
import runpod
import asyncio
from typing import Dict, Any, AsyncGenerator
from typing import Dict, Any, Tuple, AsyncGenerator
from PIL import Image
from pydantic import BaseModel, Field
from scripts.outpainting import Outpainter
from scripts.api_utils import pil_to_s3_json

# Global cache for the Outpainter instance
global_outpainter = None

class OutpaintingRequest(BaseModel):
"""
Pydantic model representing a request for outpainting inference.
Expand All @@ -30,118 +26,125 @@ class OutpaintingRequest(BaseModel):
overlap_top: bool = Field(True, description="Apply overlap on top side")
overlap_bottom: bool = Field(True, description="Apply overlap on bottom side")

async def initialize_model():
"""Initialize the model if not already loaded"""
global global_outpainter
if global_outpainter is None:
print("Initializing Outpainter model...")
global_outpainter = Outpainter()
print("Model initialized successfully")
return global_outpainter
async def decode_request(request: Dict[str, Any]) -> Dict[str, Any]:
"""
Decode and validate the incoming request asynchronously.
"""
try:
outpainting_request = OutpaintingRequest(**request)
image_data = await asyncio.to_thread(
base64.b64decode, outpainting_request.image
)
image = await asyncio.to_thread(
lambda: Image.open(io.BytesIO(image_data)).convert("RGBA")
)

return {
'image': image,
'params': outpainting_request.model_dump()
}
except Exception as e:
raise ValueError(f"Invalid request: {str(e)}")

async def generate_outpainting(inputs: Dict[str, Any]) -> Tuple[Image.Image, float]:
"""
Perform outpainting operation asynchronously.
"""
start_time = time.time()

# Initialize Outpainter for each request
outpainter = Outpainter()

result = await asyncio.to_thread(
outpainter.outpaint,
inputs['image'],
inputs['params']['width'],
inputs['params']['height'],
inputs['params']['overlap_percentage'],
inputs['params']['num_inference_steps'],
inputs['params']['resize_option'],
inputs['params']['custom_resize_percentage'],
inputs['params']['prompt_input'],
inputs['params']['alignment'],
inputs['params']['overlap_left'],
inputs['params']['overlap_right'],
inputs['params']['overlap_top'],
inputs['params']['overlap_bottom']
)

completion_time = time.time() - start_time
return result, completion_time

async def format_response(result: Image.Image, completion_time: float) -> Dict[str, Any]:
"""
Format the outpainting result for API response asynchronously.
"""
img_str = await asyncio.to_thread(
pil_to_s3_json,
result,
"outpainting_image"
)

return {
"result": img_str,
"completion_time": round(completion_time, 2),
"image_resolution": f"{result.width}x{result.height}"
}

async def async_generator_handler(job: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]:
"""
Async generator handler for RunPod.
Yields status updates and progress during the outpainting process.
Async generator handler for RunPod with progress updates.
"""
try:
# Initial status
yield {"status": "starting", "message": "Initializing outpainting process"}

# Initialize model
outpainter = await initialize_model()
yield {"status": "processing", "message": "Model loaded successfully"}
yield {"status": "starting", "message": "Starting outpainting process"}

# Decode request
try:
request = OutpaintingRequest(**job['input'])
image_data = base64.b64decode(request.image)
image = Image.open(io.BytesIO(image_data)).convert("RGBA")
yield {"status": "processing", "message": "Request decoded successfully"}
inputs = await decode_request(job['input'])
yield {
"status": "processing",
"message": "Request decoded successfully",
"input_resolution": f"{inputs['image'].width}x{inputs['image'].height}"
}
except Exception as e:
yield {"status": "error", "message": f"Error decoding request: {str(e)}"}
return

# Start timing
start_time = time.time()

# Prepare outpainting parameters
inputs = {
'image': image,
'params': {
'width': request.width,
'height': request.height,
'overlap_percentage': request.overlap_percentage,
'num_inference_steps': request.num_inference_steps,
'resize_option': request.resize_option,
'custom_resize_percentage': request.custom_resize_percentage,
'prompt_input': request.prompt_input,
'alignment': request.alignment,
'overlap_left': request.overlap_left,
'overlap_right': request.overlap_right,
'overlap_top': request.overlap_top,
'overlap_bottom': request.overlap_bottom
}
}

yield {"status": "processing", "message": "Starting outpainting process"}

# Perform outpainting
# Generate outpainting
try:
result = outpainter.outpaint(
inputs['image'],
inputs['params']['width'],
inputs['params']['height'],
inputs['params']['overlap_percentage'],
inputs['params']['num_inference_steps'],
inputs['params']['resize_option'],
inputs['params']['custom_resize_percentage'],
inputs['params']['prompt_input'],
inputs['params']['alignment'],
inputs['params']['overlap_left'],
inputs['params']['overlap_right'],
inputs['params']['overlap_top'],
inputs['params']['overlap_bottom']
)
yield {"status": "processing", "message": "Outpainting completed successfully"}
yield {"status": "processing", "message": "Initializing outpainting model"}
result, completion_time = await generate_outpainting(inputs)
yield {
"status": "processing",
"message": "Outpainting completed",
"completion_time": f"{completion_time:.2f}s"
}
except Exception as e:
yield {"status": "error", "message": f"Error during outpainting: {str(e)}"}
return

# Upload to S3 and get URL
# Format and upload result
try:
img_str = pil_to_s3_json(result, "outpainting_image")
yield {"status": "processing", "message": "Image uploaded successfully"}
yield {"status": "processing", "message": "Uploading result"}
response = await format_response(result, completion_time)
yield {"status": "processing", "message": "Result uploaded successfully"}
except Exception as e:
yield {"status": "error", "message": f"Error uploading image: {str(e)}"}
yield {"status": "error", "message": f"Error uploading result: {str(e)}"}
return

# Calculate completion time
completion_time = time.time() - start_time

# Final response
final_response = {
yield {
"status": "completed",
"output": {
"result": img_str,
"completion_time": round(completion_time, 2),
"image_resolution": f"{result.width}x{result.height}"
}
"output": response
}

yield final_response

except Exception as e:
yield {
"status": "error",
"message": f"Unexpected error: {str(e)}"
}

# Initialize the model when the service starts
print("Initializing service...")
asyncio.get_event_loop().run_until_complete(initialize_model())
print("Service initialization complete")

if __name__ == "__main__":
runpod.serverless.start({
"handler": async_generator_handler,
Expand Down

0 comments on commit 8a8e640

Please sign in to comment.