Add object detection pipeline #243

Status: Open. Wants to merge 35 commits into base: main.

Commits (35)
3905b84
feat:initial implementation of object-detection pipeline
RUFFY-369 Oct 27, 2024
9f1268d
style:ruff format
RUFFY-369 Oct 27, 2024
a760e05
Merge remote-tracking branch 'upstream/main' into feature/video-to-text
RUFFY-369 Oct 27, 2024
fe8caa0
fix:bug in pipeline inference
RUFFY-369 Oct 28, 2024
5e15705
chore:regenerate openapi specification
RUFFY-369 Oct 28, 2024
5131e2a
fix:resolve merge conflicts
RUFFY-369 Oct 28, 2024
1c0b66e
chore:make codegen
RUFFY-369 Oct 29, 2024
5097806
fix:resolve merge conflicts
RUFFY-369 Oct 29, 2024
b8f6c44
chore:make codegen
RUFFY-369 Oct 29, 2024
9ee510e
fix:internal server error
RUFFY-369 Oct 30, 2024
38d3dc3
Merge remote-tracking branch 'upstream/main' into feature/object-dete…
RUFFY-369 Oct 30, 2024
bc331ad
fix:resolve merge conflicts
RUFFY-369 Oct 31, 2024
63328bc
chore:make codegen
RUFFY-369 Oct 31, 2024
7ec7bb0
fix:bug in go-livepeer local build
RUFFY-369 Nov 2, 2024
bbbed92
chore:fix merge conflicts
RUFFY-369 Nov 2, 2024
60e5d35
chore:make codegen
RUFFY-369 Nov 2, 2024
f4439dc
chore:fix merge conflicts
RUFFY-369 Nov 14, 2024
444c32b
chore:make codegen
RUFFY-369 Nov 14, 2024
0b48da6
chore:suggested update in object detection runner
RUFFY-369 Nov 22, 2024
927eaaf
chore:fix merge conflicts
RUFFY-369 Nov 22, 2024
5785333
chore:make codegen
RUFFY-369 Nov 22, 2024
cceb64f
chore:return a base64 encoded video file instead of url for each frames
RUFFY-369 Nov 30, 2024
fc423ee
chore:add new parameter for optionality of annotated video frames
RUFFY-369 Nov 30, 2024
dc61b5f
chore:add detection boxes to ObjecDetectionResponse
RUFFY-369 Nov 30, 2024
41d658a
chore:add frames pts to ObjectDetectionResponse
RUFFY-369 Nov 30, 2024
fcd3125
fix:suggested solution for possible test videos failure
RUFFY-369 Nov 30, 2024
dd77858
chore:make codegen
RUFFY-369 Dec 2, 2024
905dcb8
fix:build error for go livepeer side
RUFFY-369 Dec 2, 2024
2b92bc2
fix:logical error in detection boxes return for all frames
RUFFY-369 Dec 3, 2024
fc65457
updates to ai-runner object detection pipeline
ad-astra-video Jan 2, 2025
728e15d
fix annotate
ad-astra-video Jan 2, 2025
8e5f4ee
Merge pull request #1 from ad-astra-video/object-detection
RUFFY-369 Jan 5, 2025
450ddf4
chore:fix merge conflicts
RUFFY-369 Jan 5, 2025
f6903c9
chore:fix previous commit wrong code for merge conflict resolution
RUFFY-369 Jan 5, 2025
f94933d
chore:make codegen
RUFFY-369 Jan 5, 2025
7 changes: 7 additions & 0 deletions runner/app/main.py
@@ -78,6 +78,10 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
            from app.pipelines.text_to_speech import TextToSpeechPipeline

            return TextToSpeechPipeline(model_id)
        case "object-detection":
            from app.pipelines.object_detection import ObjectDetectionPipeline

            return ObjectDetectionPipeline(model_id)
        case _:
            raise EnvironmentError(
                f"{pipeline} is not a valid pipeline for model {model_id}"
@@ -128,6 +132,9 @@ def load_route(pipeline: str) -> any:
            from app.routes import text_to_speech

            return text_to_speech.router
        case "object-detection":
            from app.routes import object_detection
            return object_detection.router
        case _:
            raise EnvironmentError(f"{pipeline} is not a valid pipeline")

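For context, a minimal sketch of how this dispatch is exercised at runner startup; the environment variable names and the model ID below are assumptions for illustration, not taken from this diff:

import os

from app.main import load_pipeline, load_route

# Hypothetical boot-time wiring; env var names and the model ID are
# placeholders, not part of this PR.
pipeline_name = os.environ.get("PIPELINE", "object-detection")
model_id = os.environ.get("MODEL_ID", "facebook/detr-resnet-50")

pipeline = load_pipeline(pipeline_name, model_id)  # ObjectDetectionPipeline
router = load_route(pipeline_name)                 # object_detection.router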
133 changes: 133 additions & 0 deletions runner/app/pipelines/object_detection.py
@@ -0,0 +1,133 @@
import logging
import os

import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device, DetectionFrame
from huggingface_hub import file_download
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from typing import List, Tuple
from PIL import Image, ImageDraw, ImageFont
from app.utils.errors import InferenceError

logger = logging.getLogger(__name__)


def annotate_image(input_image, detections, labels, font_size, font):
    draw = ImageDraw.Draw(input_image)
    bounding_box_color = (255, 255, 0)  # Bright yellow for bounding boxes.
    text_color = (0, 0, 0)  # Black for text.
    for box, label in zip(detections["boxes"], labels):
        x1, y1, x2, y2 = map(int, box)
        draw.rectangle([x1, y1, x2, y2], outline=bounding_box_color, width=3)
        # Place the label just above the bounding box.
        draw.text((x1, y1 - font_size - 5), label, fill=text_color, font=font)
    return input_image


class ObjectDetectionPipeline(Pipeline):
    def __init__(self, model_id: str):
        self.model_id = model_id
        kwargs = {}

        self.torch_device = get_torch_device()
        folder_name = file_download.repo_folder_name(
            repo_id=model_id, repo_type="model"
        )
        folder_path = os.path.join(get_model_dir(), folder_name)
        # Load the fp16 variant if fp16 safetensors files are found in the cache.
        has_fp16_variant = any(
            ".fp16.safetensors" in fname
            for _, _, files in os.walk(folder_path)
            for fname in files
        )
        if self.torch_device != "cpu" and has_fp16_variant:
            logger.info(
                "ObjectDetectionPipeline loading fp16 variant for %s", model_id
            )

            kwargs["torch_dtype"] = torch.float16
            kwargs["variant"] = "fp16"

        if os.environ.get("BFLOAT16"):
            logger.info(
                "ObjectDetectionPipeline using bfloat16 precision for %s", model_id
            )
            kwargs["torch_dtype"] = torch.bfloat16

        self.object_detection_model = AutoModelForObjectDetection.from_pretrained(
            model_id,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            cache_dir=get_model_dir(),
            **kwargs,
        ).to(self.torch_device)

        # Fast image processors are still being rolled out across transformers;
        # try to use one if available and fall back to the default otherwise.
        try:
            self.image_processor = AutoImageProcessor.from_pretrained(
                model_id, cache_dir=get_model_dir(), use_fast=True
            )
        except Exception:
            self.image_processor = AutoImageProcessor.from_pretrained(
                model_id, cache_dir=get_model_dir()
            )

        # Load a font (the default font is used here; a path to a TTF file can
        # be specified instead).
        self.font_size = 24
        self.font = ImageFont.load_default(size=self.font_size)

    def __call__(
        self,
        frames: List[DetectionFrame],
        confidence_threshold: float = 0.6,
        return_annotated_video: bool = False,
        **kwargs,
    ) -> Tuple[list, list, list, list, list]:
        try:
            annotated_frames = []
            confidence_scores_all_frames = []
            labels_all_frames = []
            detection_boxes_all_frames = []
            pts_all_frames = []

            for frame in frames:
                # Run the detector on the frame.
                inputs = self.image_processor(
                    images=frame.image, return_tensors="pt"
                ).to(self.torch_device)
                with torch.no_grad():
                    outputs = self.object_detection_model(**inputs)

                target_sizes = torch.tensor([frame.image.size[::-1]])
                results = self.image_processor.post_process_object_detection(
                    outputs=outputs,
                    threshold=confidence_threshold,
                    target_sizes=target_sizes,
                )[0]

                final_labels = []
                confidence_scores = []

                detections = {"boxes": results["boxes"].cpu().numpy()}

                for label_id, score in zip(
                    results["labels"].cpu().numpy(), results["scores"].cpu().numpy()
                ):
                    final_labels.append(
                        self.object_detection_model.config.id2label[label_id]
                    )
                    confidence_scores.append(round(score, 3))

                if return_annotated_video:
                    annotated_frame = annotate_image(
                        input_image=frame.image,
                        detections=detections,
                        labels=final_labels,
                        font_size=self.font_size,
                        font=self.font,
                    )
                    annotated_frames.append(annotated_frame)

                # Per-frame detections, confidence scores, labels, and pts.
                confidence_scores_all_frames.append(confidence_scores)
                labels_all_frames.append(final_labels)
                detection_boxes_all_frames.append(detections["boxes"].tolist())
                pts_all_frames.append(float(frame.pts * frame.time_base))

            return (
                annotated_frames,
                confidence_scores_all_frames,
                labels_all_frames,
                detection_boxes_all_frames,
                pts_all_frames,
            )

        except Exception as e:
            raise InferenceError(original_exception=e)

    def __str__(self) -> str:
        return f"ObjectDetectionPipeline model_id={self.model_id}"
1 change: 1 addition & 0 deletions runner/app/pipelines/utils/__init__.py
@@ -15,4 +15,5 @@
    split_prompt,
    validate_torch_device,
    get_max_memory,
    DetectionFrame,
)
8 changes: 8 additions & 0 deletions runner/app/pipelines/utils/utils.py
@@ -14,6 +14,7 @@
from PIL import Image
from torch import dtype as TorchDtype
from transformers import CLIPImageProcessor
from dataclasses import dataclass

logger = logging.getLogger(__name__)

@@ -388,3 +389,10 @@ def get_max_memory() -> MemoryInfo:
        cpu_memory=cpu_memory, num_gpus=num_gpus)

    return memory_info


@dataclass
class DetectionFrame:
    pts: float
    time_base: float
    image: Image.Image
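A short sketch of how DetectionFrame instances are built from PyAV decode output, mirroring the route below; the sample file name is a placeholder:

import av

from app.pipelines.utils import DetectionFrame

# "sample.mp4" is a placeholder input.
container = av.open("sample.mp4")
stream = container.streams.video[0]
frames = [
    DetectionFrame(pts=f.pts, time_base=stream.time_base, image=f.to_image())
    for f in container.decode(video=0)
]
container.close()
# Presentation time of the first frame in seconds:
print(float(frames[0].pts * frames[0].time_base))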
163 changes: 163 additions & 0 deletions runner/app/routes/object_detection.py
@@ -0,0 +1,163 @@
import av
import logging
import os
from typing import Annotated, Dict, Tuple, Union
import time

import torch

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.utils import (
    HTTPError,
    ObjectDetectionResponse,
    file_exceeds_max_size,
    handle_pipeline_exception,
    http_error,
    frames_to_video_data_url,
)
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from PIL import ImageFile
from app.pipelines.utils import DetectionFrame

ImageFile.LOAD_TRUNCATED_IMAGES = True

router = APIRouter()

logger = logging.getLogger(__name__)

# Pipeline-specific error handling configuration.
PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = {
    # Specific error types.
    "OutOfMemoryError": (
        "Out of memory error. Try reducing input image resolution.",
        status.HTTP_500_INTERNAL_SERVER_ERROR,
    )
}

RESPONSES = {
    status.HTTP_200_OK: {
        "content": {
            "application/json": {
                "schema": {
                    "x-speakeasy-name-override": "data",
                }
            }
        },
    },
    status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
    status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
    status.HTTP_413_REQUEST_ENTITY_TOO_LARGE: {"model": HTTPError},
    status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


@router.post(
    "/object-detection",
    response_model=ObjectDetectionResponse,
    responses=RESPONSES,
    description="Generate annotated video(s) for object detection from the input video(s)",
    operation_id="genObjectDetection",
    summary="Object Detection",
    tags=["generate"],
    openapi_extra={"x-speakeasy-name-override": "objectDetection"},
)
@router.post(
    "/object-detection/",
    response_model=ObjectDetectionResponse,
    responses=RESPONSES,
    include_in_schema=False,
)
async def object_detection(
    video: Annotated[
        UploadFile, File(description="Uploaded video to transform with the pipeline.")
    ],
    confidence_threshold: Annotated[
        float,
        Form(description="Score threshold to keep object detection predictions."),
    ] = 0.6,
    model_id: Annotated[
        str,
        Form(description="Hugging Face model ID used for transformation."),
    ] = "",
    return_annotated_video: Annotated[
        bool,
        Form(description="If true, returns the annotated video as a data URL."),
    ] = False,
    pipeline: Pipeline = Depends(get_pipeline),
    token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
    auth_token = os.environ.get("AUTH_TOKEN")
    if auth_token:
        if not token or token.credentials != auth_token:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                headers={"WWW-Authenticate": "Bearer"},
                content=http_error("Invalid bearer token"),
            )

    if model_id != "" and model_id != pipeline.model_id:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=http_error(
                f"pipeline configured with {pipeline.model_id} but called with "
                f"{model_id}"
            ),
        )

    if file_exceeds_max_size(video, 50 * 1024 * 1024):
        return JSONResponse(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            content=http_error("File size exceeds limit"),
        )

    frames = []
    try:
        container = av.open(video.file, mode="r")
        stream = container.streams.video[0]
        fps = stream.average_rate

        start = time.time()
        for frame in container.decode(video=0):
            # Convert each decoded frame to a PIL image and collect it.
            frames.append(
                DetectionFrame(
                    pts=frame.pts,
                    time_base=stream.time_base,
                    image=frame.to_image(),
                )
            )

        container.close()
        logger.info(f"Decoded video in {time.time() - start:.2f} seconds")

        start = time.time()
        (
            annotated_frames,
            confidence_scores_all_frames,
            labels_all_frames,
            detection_boxes,
            pts_of_detections,
        ) = pipeline(
            frames=frames,
            confidence_threshold=confidence_threshold,
            return_annotated_video=return_annotated_video,
        )
        logger.info(f"Detections processed in {time.time() - start:.2f} seconds")
    except Exception as e:
        if isinstance(e, torch.cuda.OutOfMemoryError):
            torch.cuda.empty_cache()
        logger.error(f"ObjectDetectionPipeline error: {e}")
        return handle_pipeline_exception(
            e,
            default_error_message="Object-detection pipeline error.",
            custom_error_config=PIPELINE_ERROR_CONFIG,
        )

    # Convert the annotated frames to a data URL.
    if return_annotated_video:
        start = time.time()
        encoded_frames_url = frames_to_video_data_url(annotated_frames, fps=fps)
        logger.info(
            f"Annotated frames converted to data URL in {time.time() - start:.2f} "
            f"seconds, frame count: {len(annotated_frames)}"
        )
    else:
        encoded_frames_url = ""

    return {
        "video": {"url": encoded_frames_url},
        "confidence_scores": str(confidence_scores_all_frames),
        "labels": str(labels_all_frames),
        "detection_boxes": str(detection_boxes),
        "detection_pts": str(pts_of_detections),
    }
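To round this off, a hedged client-side sketch of calling the new endpoint with requests; the host, port, bearer token, and file name are assumptions, while the route path and form fields come from this diff:

import requests

# Host, port, token, and file name are placeholders.
with open("sample.mp4", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/object-detection",
        headers={"Authorization": "Bearer <token>"},
        files={"video": f},
        data={"confidence_threshold": 0.6, "return_annotated_video": True},
    )
resp.raise_for_status()
body = resp.json()
print(body["labels"], body["detection_pts"])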