Skip to content

Commit

Permalink
asynchronous runpod handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
VikramxD committed Oct 25, 2024
1 parent 80955cb commit 749331c
Show file tree
Hide file tree
Showing 4 changed files with 470 additions and 290 deletions.
162 changes: 105 additions & 57 deletions serverless/image-to-video/run_image-to-video.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
import runpod
import tempfile
import time
from typing import Dict, Any, List, Union, Tuple
import asyncio
from typing import Dict, Any, List, Union, Tuple, AsyncGenerator
from PIL import Image
import base64
from pydantic import BaseModel, Field
from diffusers.utils import export_to_video
from scripts.api_utils import mp4_to_s3_json
from scripts.image_to_video import ImageToVideoPipeline

# Global pipeline instance
global_pipeline = None

class ImageToVideoRequest(BaseModel):
"""
Pydantic model representing a request for image-to-video generation.
Expand All @@ -25,26 +29,26 @@ class ImageToVideoRequest(BaseModel):
use_dynamic_cfg: bool = Field(True, description="Use dynamic CFG")
fps: int = Field(30, description="Frames per second for the output video")

device = "cuda"
pipeline = ImageToVideoPipeline(device=device)
async def initialize_pipeline():
"""Initialize the pipeline if not already loaded"""
global global_pipeline
if global_pipeline is None:
print("Initializing Image to Video pipeline...")
global_pipeline = ImageToVideoPipeline(device="cuda")
print("Pipeline initialized successfully")
return global_pipeline

def decode_request(request: Dict[str, Any]) -> Dict[str, Any]:
async def decode_request(request: Dict[str, Any]) -> Dict[str, Any]:
"""
Decode and validate the incoming video generation request.
Args:
request: Raw request data containing base64 image and parameters
Returns:
Dict containing decoded PIL Image and validated parameters
Raises:
ValueError: If request validation or image decoding fails
Decode and validate the incoming video generation request asynchronously.
"""
try:
video_request = ImageToVideoRequest(**request)
image_data = base64.b64decode(video_request.image)
image = Image.open(io.BytesIO(image_data)).convert("RGB")
# Run decode in thread pool
image_data = await asyncio.to_thread(base64.b64decode, video_request.image)
image = await asyncio.to_thread(
lambda: Image.open(io.BytesIO(image_data)).convert("RGB")
)

return {
'image': image,
Expand All @@ -53,19 +57,15 @@ def decode_request(request: Dict[str, Any]) -> Dict[str, Any]:
except Exception as e:
raise ValueError(f"Invalid request: {str(e)}")

def generate_frames(inputs: Dict[str, Any]) -> Tuple[Union[List[Image.Image], Image.Image], float]:
async def generate_frames(inputs: Dict[str, Any], pipeline: ImageToVideoPipeline) -> Tuple[List[Image.Image], float]:
"""
Generate video frames using the pipeline.
Args:
inputs: Dictionary containing input image and generation parameters
Returns:
Tuple containing generated frames and completion time
Generate video frames using the pipeline asynchronously.
"""
start_time = time.time()

frames = pipeline.generate(
# Run generation in thread pool
frames = await asyncio.to_thread(
pipeline.generate,
prompt=inputs['params']['prompt'],
image=inputs['image'],
num_frames=inputs['params']['num_frames'],
Expand All @@ -84,27 +84,23 @@ def generate_frames(inputs: Dict[str, Any]) -> Tuple[Union[List[Image.Image], Im
completion_time = time.time() - start_time
return frames, completion_time

def create_video_response(frames: List[Image.Image], completion_time: float, fps: int) -> Dict[str, Any]:
async def create_video_response(frames: List[Image.Image], completion_time: float, fps: int) -> Dict[str, Any]:
"""
Create video file and generate response with S3 URL.
Args:
frames: List of generated video frames
completion_time: Time taken for generation
fps: Frames per second for the output video
Returns:
Dict containing S3 URL, completion time, and video metadata
Create video file and generate response with S3 URL asynchronously.
"""
with tempfile.TemporaryDirectory() as temp_dir:
temp_video_path = os.path.join(temp_dir, "generated_video.mp4")
export_to_video(frames, temp_video_path, fps=fps)

with open(temp_video_path, "rb") as video_file:
s3_response = mp4_to_s3_json(
video_file,
f"generated_video_{int(time.time())}.mp4"
)
def create_video():
with tempfile.TemporaryDirectory() as temp_dir:
temp_video_path = os.path.join(temp_dir, "generated_video.mp4")
export_to_video(frames, temp_video_path, fps=fps)

with open(temp_video_path, "rb") as video_file:
return mp4_to_s3_json(
video_file,
f"generated_video_{int(time.time())}.mp4"
)

# Run video creation and upload in thread pool
s3_response = await asyncio.to_thread(create_video)

return {
"result": s3_response,
Expand All @@ -113,24 +109,76 @@ def create_video_response(frames: List[Image.Image], completion_time: float, fps
"fps": fps
}

def handler(job: Dict[str, Any]) -> Dict[str, Any]:
async def async_generator_handler(job: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]:
"""
RunPod handler function for processing video generation requests.
Args:
job: RunPod job dictionary containing input data
Returns:
Dict containing either the processed video result or error information
Async generator handler for RunPod with progress updates.
"""
try:
inputs = decode_request(job['input'])
frames, completion_time = generate_frames(inputs)
return create_video_response(frames, completion_time, inputs['params']['fps'])
# Initial status
yield {"status": "starting", "message": "Initializing video generation process"}

# Initialize pipeline
pipeline = await initialize_pipeline()
yield {"status": "processing", "message": "Pipeline loaded successfully"}

# Decode request
try:
inputs = await decode_request(job['input'])
yield {"status": "processing", "message": "Request decoded successfully"}
except Exception as e:
yield {"status": "error", "message": f"Error decoding request: {str(e)}"}
return

# Generate frames with progress updates
try:
yield {"status": "processing", "message": "Generating video frames"}
frames, completion_time = await generate_frames(inputs, pipeline)
yield {"status": "processing", "message": f"Generated {len(frames)} frames successfully"}
except Exception as e:
yield {"status": "error", "message": f"Error generating frames: {str(e)}"}
return

# Create and upload video
try:
yield {"status": "processing", "message": "Creating and uploading video"}
response = await create_video_response(
frames,
completion_time,
inputs['params']['fps']
)
yield {"status": "processing", "message": "Video uploaded successfully"}
except Exception as e:
yield {"status": "error", "message": f"Error creating video: {str(e)}"}
return

# Final response
yield {
"status": "completed",
"output": response
}

except Exception as e:
return {"error": str(e)}
yield {
"status": "error",
"message": f"Unexpected error: {str(e)}"
}

def calculate_progress(current_frame: int, total_frames: int) -> dict:
"""Calculate progress percentage and create status update."""
progress = (current_frame / total_frames) * 100
return {
"status": "processing",
"progress": round(progress, 2),
"message": f"Generating frame {current_frame}/{total_frames}"
}

# Initialize the pipeline when the service starts
print("Initializing service...")
asyncio.get_event_loop().run_until_complete(initialize_pipeline())
print("Service initialization complete")

if __name__ == "__main__":
runpod.serverless.start({
"handler": handler
"handler": async_generator_handler,
"return_aggregate_stream": True
})
Loading

0 comments on commit 749331c

Please sign in to comment.