Skip to content

Commit

Permalink
split runpod service for improved model loading and response handing
Browse files Browse the repository at this point in the history
  • Loading branch information
VikramxD committed Oct 24, 2024
1 parent 38c2f87 commit 80955cb
Show file tree
Hide file tree
Showing 4 changed files with 392 additions and 350 deletions.
160 changes: 97 additions & 63 deletions serverless/image-to-video/run_image-to-video.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import runpod
import tempfile
import time
from typing import Dict, Any
from typing import Dict, Any, List, Union, Tuple
from PIL import Image
import base64
from pydantic import BaseModel, Field
Expand All @@ -25,76 +25,110 @@ class ImageToVideoRequest(BaseModel):
use_dynamic_cfg: bool = Field(True, description="Use dynamic CFG")
fps: int = Field(30, description="Frames per second for the output video")

class RunPodHandler:
def __init__(self):
self.device = "cuda"
self.pipeline = ImageToVideoPipeline(device=self.device)
device = "cuda"
pipeline = ImageToVideoPipeline(device=device)

def decode_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
try:
video_request = ImageToVideoRequest(**request)
image_data = base64.b64decode(video_request.image)
image = Image.open(io.BytesIO(image_data)).convert("RGB")

return {
'image': image,
'params': video_request.model_dump()
}
except Exception as e:
raise ValueError(f"Invalid request: {str(e)}")

def generate_video(self, job: Dict[str, Any]) -> Dict[str, Any]:
try:
inputs = self.decode_request(job['input'])
image = inputs['image']
params = inputs['params']
def decode_request(request: Dict[str, Any]) -> Dict[str, Any]:
"""
Decode and validate the incoming video generation request.
Args:
request: Raw request data containing base64 image and parameters
Returns:
Dict containing decoded PIL Image and validated parameters
Raises:
ValueError: If request validation or image decoding fails
"""
try:
video_request = ImageToVideoRequest(**request)
image_data = base64.b64decode(video_request.image)
image = Image.open(io.BytesIO(image_data)).convert("RGB")

return {
'image': image,
'params': video_request.model_dump()
}
except Exception as e:
raise ValueError(f"Invalid request: {str(e)}")

start_time = time.time()
def generate_frames(inputs: Dict[str, Any]) -> Tuple[Union[List[Image.Image], Image.Image], float]:
"""
Generate video frames using the pipeline.
Args:
inputs: Dictionary containing input image and generation parameters
Returns:
Tuple containing generated frames and completion time
"""
start_time = time.time()

frames = pipeline.generate(
prompt=inputs['params']['prompt'],
image=inputs['image'],
num_frames=inputs['params']['num_frames'],
num_inference_steps=inputs['params']['num_inference_steps'],
guidance_scale=inputs['params']['guidance_scale'],
height=inputs['params']['height'],
width=inputs['params']['width'],
use_dynamic_cfg=inputs['params']['use_dynamic_cfg']
)

if isinstance(frames, tuple):
frames = frames[0]
elif hasattr(frames, 'frames'):
frames = frames.frames[0]

completion_time = time.time() - start_time
return frames, completion_time

# Generate frames using the pipeline
frames = self.pipeline.generate(
prompt=params['prompt'],
image=image,
num_frames=params['num_frames'],
num_inference_steps=params['num_inference_steps'],
guidance_scale=params['guidance_scale'],
height=params['height'],
width=params['width'],
use_dynamic_cfg=params['use_dynamic_cfg']
def create_video_response(frames: List[Image.Image], completion_time: float, fps: int) -> Dict[str, Any]:
"""
Create video file and generate response with S3 URL.
Args:
frames: List of generated video frames
completion_time: Time taken for generation
fps: Frames per second for the output video
Returns:
Dict containing S3 URL, completion time, and video metadata
"""
with tempfile.TemporaryDirectory() as temp_dir:
temp_video_path = os.path.join(temp_dir, "generated_video.mp4")
export_to_video(frames, temp_video_path, fps=fps)

with open(temp_video_path, "rb") as video_file:
s3_response = mp4_to_s3_json(
video_file,
f"generated_video_{int(time.time())}.mp4"
)

if isinstance(frames, tuple):
frames = frames[0]
elif hasattr(frames, 'frames'):
frames = frames.frames[0]

completion_time = time.time() - start_time
fps = params['fps']

# Create temporary video file and upload to S3
with tempfile.TemporaryDirectory() as temp_dir:
temp_video_path = os.path.join(temp_dir, "generated_video.mp4")
export_to_video(frames, temp_video_path, fps=fps)

with open(temp_video_path, "rb") as video_file:
s3_response = mp4_to_s3_json(video_file, f"generated_video_{int(time.time())}.mp4")

return {
"result": s3_response,
"completion_time": round(completion_time, 2),
"video_resolution": f"{frames[0].width}x{frames[0].height}",
"fps": fps
}

except Exception as e:
return {"error": str(e)}
return {
"result": s3_response,
"completion_time": round(completion_time, 2),
"video_resolution": f"{frames[0].width}x{frames[0].height}",
"fps": fps
}

def handler(job):
def handler(job: Dict[str, Any]) -> Dict[str, Any]:
"""
RunPod handler function.
RunPod handler function for processing video generation requests.
Args:
job: RunPod job dictionary containing input data
Returns:
Dict containing either the processed video result or error information
"""
handler = RunPodHandler()
return handler.generate_video(job)
try:
inputs = decode_request(job['input'])
frames, completion_time = generate_frames(inputs)
return create_video_response(frames, completion_time, inputs['params']['fps'])
except Exception as e:
return {"error": str(e)}

if __name__ == "__main__":
runpod.serverless.start({
Expand Down
Loading

0 comments on commit 80955cb

Please sign in to comment.