Skip to content

Commit

Permalink
Refactor OutpaintingService for improved performance and readability
Browse files Browse the repository at this point in the history
  • Loading branch information
VikramxD committed Oct 25, 2024
1 parent 8a8e640 commit ed30992
Showing 1 changed file with 108 additions and 71 deletions.
179 changes: 108 additions & 71 deletions serverless/outpainting/run_outpainting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
class OutpaintingRequest(BaseModel):
"""
Pydantic model representing a request for outpainting inference.
This model defines the structure and validation rules for incoming API requests.
All fields are required unless otherwise specified.
"""
image: str = Field(..., description="Base64 encoded input image")
width: int = Field(1024, description="Target width")
Expand All @@ -26,82 +29,115 @@ class OutpaintingRequest(BaseModel):
overlap_top: bool = Field(True, description="Apply overlap on top side")
overlap_bottom: bool = Field(True, description="Apply overlap on bottom side")

async def decode_request(request: Dict[str, Any]) -> Dict[str, Any]:
class OutpaintingService:
"""
Decode and validate the incoming request asynchronously.
Service class for handling outpainting operations.
Based on LitAPI implementation but adapted for RunPod.
"""
try:
outpainting_request = OutpaintingRequest(**request)
image_data = await asyncio.to_thread(
base64.b64decode, outpainting_request.image
)
image = await asyncio.to_thread(
lambda: Image.open(io.BytesIO(image_data)).convert("RGBA")

def __init__(self, device: str = "cuda"):
"""Initialize the outpainting service."""
self.device = device
self.outpainter = Outpainter()

async def decode_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""
Decode the incoming request and prepare inputs for the model.
Args:
request: The raw request data.
Returns:
Dict containing decoded image and request parameters.
Raises:
ValueError: If request is invalid or cannot be processed.
"""
try:
outpainting_request = OutpaintingRequest(**request)
# Run decode in thread pool
image_data = await asyncio.to_thread(
base64.b64decode, outpainting_request.image
)
image = await asyncio.to_thread(
lambda: Image.open(io.BytesIO(image_data)).convert("RGBA")
)

return {
'image': image,
'params': outpainting_request.model_dump()
}
except Exception as e:
raise ValueError(f"Invalid request: {str(e)}")

async def predict(self, inputs: Dict[str, Any]) -> Tuple[Image.Image, float]:
"""
Run predictions on the input.
Args:
inputs: Dict containing image and outpainting parameters.
Returns:
Tuple containing the resulting image and completion time.
"""
image = inputs['image']
params = inputs['params']

start_time = time.time()

# Run outpainting in thread pool
result = await asyncio.to_thread(
self.outpainter.outpaint,
image,
params['width'],
params['height'],
params['overlap_percentage'],
params['num_inference_steps'],
params['resize_option'],
params['custom_resize_percentage'],
params['prompt_input'],
params['alignment'],
params['overlap_left'],
params['overlap_right'],
params['overlap_top'],
params['overlap_bottom']
)

completion_time = time.time() - start_time
return result, completion_time

async def encode_response(self, output: Tuple[Image.Image, float]) -> Dict[str, Any]:
"""
Encode the model output into a response payload.
Args:
output: Tuple containing outpainted image and completion time.
Returns:
Dict containing S3 URL and metadata.
"""
image, completion_time = output
# Run S3 upload in thread pool
img_str = await asyncio.to_thread(pil_to_s3_json, image, "outpainting_image")

return {
'image': image,
'params': outpainting_request.model_dump()
"result": img_str,
"completion_time": round(completion_time, 2),
"image_resolution": f"{image.width}x{image.height}"
}
except Exception as e:
raise ValueError(f"Invalid request: {str(e)}")

async def generate_outpainting(inputs: Dict[str, Any]) -> Tuple[Image.Image, float]:
"""
Perform outpainting operation asynchronously.
"""
start_time = time.time()

# Initialize Outpainter for each request
outpainter = Outpainter()

result = await asyncio.to_thread(
outpainter.outpaint,
inputs['image'],
inputs['params']['width'],
inputs['params']['height'],
inputs['params']['overlap_percentage'],
inputs['params']['num_inference_steps'],
inputs['params']['resize_option'],
inputs['params']['custom_resize_percentage'],
inputs['params']['prompt_input'],
inputs['params']['alignment'],
inputs['params']['overlap_left'],
inputs['params']['overlap_right'],
inputs['params']['overlap_top'],
inputs['params']['overlap_bottom']
)

completion_time = time.time() - start_time
return result, completion_time

async def format_response(result: Image.Image, completion_time: float) -> Dict[str, Any]:
"""
Format the outpainting result for API response asynchronously.
"""
img_str = await asyncio.to_thread(
pil_to_s3_json,
result,
"outpainting_image"
)

return {
"result": img_str,
"completion_time": round(completion_time, 2),
"image_resolution": f"{result.width}x{result.height}"
}

async def async_generator_handler(job: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]:
"""
Async generator handler for RunPod with progress updates.
"""
try:
# Initial status
yield {"status": "starting", "message": "Starting outpainting process"}
# Create service instance
service = OutpaintingService(device="cuda")
yield {"status": "starting", "message": "Service initialized"}

# Decode request
try:
inputs = await decode_request(job['input'])
inputs = await service.decode_request(job['input'])
yield {
"status": "processing",
"message": "Request decoded successfully",
Expand All @@ -111,26 +147,26 @@ async def async_generator_handler(job: Dict[str, Any]) -> AsyncGenerator[Dict[st
yield {"status": "error", "message": f"Error decoding request: {str(e)}"}
return

# Generate outpainting
# Generate prediction
try:
yield {"status": "processing", "message": "Initializing outpainting model"}
result, completion_time = await generate_outpainting(inputs)
yield {"status": "processing", "message": "Starting outpainting"}
result = await service.predict(inputs)
yield {
"status": "processing",
"status": "processing",
"message": "Outpainting completed",
"completion_time": f"{completion_time:.2f}s"
"completion_time": f"{result[1]:.2f}s"
}
except Exception as e:
yield {"status": "error", "message": f"Error during outpainting: {str(e)}"}
return

# Format and upload result
# Encode response
try:
yield {"status": "processing", "message": "Uploading result"}
response = await format_response(result, completion_time)
yield {"status": "processing", "message": "Result uploaded successfully"}
yield {"status": "processing", "message": "Encoding result"}
response = await service.encode_response(result)
yield {"status": "processing", "message": "Result encoded successfully"}
except Exception as e:
yield {"status": "error", "message": f"Error uploading result: {str(e)}"}
yield {"status": "error", "message": f"Error encoding result: {str(e)}"}
return

# Final response
Expand All @@ -146,6 +182,7 @@ async def async_generator_handler(job: Dict[str, Any]) -> AsyncGenerator[Dict[st
}

if __name__ == "__main__":
import runpod
runpod.serverless.start({
"handler": async_generator_handler,
"return_aggregate_stream": True
Expand Down

0 comments on commit ed30992

Please sign in to comment.