refactor(runner): add InferenceError to all pipelines (#188)
This commit adds the inference error logic from the SAM2 pipeline to all pipelines, so that users receive a clear warning when they supply incorrect arguments. It also improves the overall error handling behavior.

Co-authored-by: gioelecerati <[email protected]>
rickstaa and gioelecerati authored Oct 15, 2024
1 parent 0d96ac0 commit 40fa0c2
Showing 27 changed files with 604 additions and 261 deletions.
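Note: the pipelines below raise a new InferenceError exception defined in app/utils/errors.py (moved there from app.routes.util, where the SAM2 pipeline originally kept it). That module is not part of this excerpt; a minimal sketch consistent with its call sites (raise InferenceError(original_exception=e)) could look like the following, where the default message and the chaining behavior are assumptions:

# Hypothetical sketch of app/utils/errors.py; only the original_exception
# keyword is confirmed by the call sites in this diff.
class InferenceError(Exception):
    """Raised when an error occurs during model inference."""

    def __init__(self, message="Error during model execution.", original_exception=None):
        # Keep the root cause so routes can log or surface it.
        if original_exception:
            message = f"{message}: {original_exception}"
        super().__init__(message)
        self.original_exception = original_exception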
8 changes: 7 additions & 1 deletion runner/app/pipelines/audio_to_text.py
@@ -6,6 +6,7 @@
 from app.pipelines.base import Pipeline
 from app.pipelines.utils import get_model_dir, get_torch_device
 from app.pipelines.utils.audio import AudioConverter
+from app.utils.errors import InferenceError
 from fastapi import File, UploadFile
 from huggingface_hub import file_download
 from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
@@ -76,7 +77,12 @@ def __call__(self, audio: UploadFile, **kwargs) -> List[File]:
         converted_bytes = audio_converter.convert(audio, "mp3")
         audio_converter.write_bytes_to_file(converted_bytes, audio)
 
-        return self.tm(audio.file.read(), **kwargs)
+        try:
+            outputs = self.tm(audio.file.read(), **kwargs)
+        except Exception as e:
+            raise InferenceError(original_exception=e)
+
+        return outputs
 
     def __str__(self) -> str:
         return f"AudioToTextPipeline model_id={self.model_id}"
12 changes: 8 additions & 4 deletions runner/app/pipelines/image_to_image.py
@@ -14,6 +14,7 @@
     is_lightning_model,
     is_turbo_model,
 )
+from app.utils.errors import InferenceError
 from diffusers import (
     AutoPipelineForImage2Image,
     EulerAncestralDiscreteScheduler,
@@ -233,14 +234,17 @@ def __call__(
         # Default to 8step
         kwargs["num_inference_steps"] = 8
 
-        output = self.ldm(prompt, image=image, **kwargs)
+        try:
+            outputs = self.ldm(prompt, image=image, **kwargs)
+        except Exception as e:
+            raise InferenceError(original_exception=e)
 
         if safety_check:
-            _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
+            _, has_nsfw_concept = self._safety_checker.check_nsfw_images(outputs.images)
         else:
-            has_nsfw_concept = [None] * len(output.images)
+            has_nsfw_concept = [None] * len(outputs.images)
 
-        return output.images, has_nsfw_concept
+        return outputs.images, has_nsfw_concept
 
     def __str__(self) -> str:
         return f"ImageToImagePipeline model_id={self.model_id}"
8 changes: 7 additions & 1 deletion runner/app/pipelines/image_to_video.py
@@ -7,6 +7,7 @@
 import torch
 from app.pipelines.base import Pipeline
 from app.pipelines.utils import SafetyChecker, get_model_dir, get_torch_device
+from app.utils.errors import InferenceError
 from diffusers import StableVideoDiffusionPipeline
 from huggingface_hub import file_download
 from PIL import ImageFile
@@ -135,7 +136,12 @@ def __call__(
         else:
             has_nsfw_concept = [None]
 
-        return self.ldm(image, **kwargs).frames, has_nsfw_concept
+        try:
+            outputs = self.ldm(image, **kwargs)
+        except Exception as e:
+            raise InferenceError(original_exception=e)
+
+        return outputs.frames, has_nsfw_concept
 
     def __str__(self) -> str:
         return f"ImageToVideoPipeline model_id={self.model_id}"
2 changes: 1 addition & 1 deletion runner/app/pipelines/optim/sfast.py
@@ -31,7 +31,7 @@ def compile_model(pipe):
     except ImportError:
         logger.info("xformers not installed, skip")
     try:
-        import triton # noqa: F401
+        import triton  # noqa: F401
 
         config.enable_triton = True
     except ImportError:
4 changes: 2 additions & 2 deletions runner/app/pipelines/segment_anything_2.py
@@ -3,8 +3,8 @@
 
 import PIL
 from app.pipelines.base import Pipeline
-from app.pipelines.utils import get_torch_device, get_model_dir
-from app.routes.util import InferenceError
+from app.pipelines.utils import get_model_dir, get_torch_device
+from app.utils.errors import InferenceError
 from PIL import ImageFile
 from sam2.sam2_image_predictor import SAM2ImagePredictor
 
12 changes: 8 additions & 4 deletions runner/app/pipelines/text_to_image.py
@@ -15,6 +15,7 @@
     is_turbo_model,
     split_prompt,
 )
+from app.utils.errors import InferenceError
 from diffusers import (
     AutoPipelineForText2Image,
     EulerDiscreteScheduler,
@@ -274,14 +275,17 @@ def __call__(
         )
         kwargs.update(neg_prompts)
 
-        output = self.ldm(prompt=prompt, **kwargs)
+        try:
+            outputs = self.ldm(prompt=prompt, **kwargs)
+        except Exception as e:
+            raise InferenceError(original_exception=e)
 
         if safety_check:
-            _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
+            _, has_nsfw_concept = self._safety_checker.check_nsfw_images(outputs.images)
         else:
-            has_nsfw_concept = [None] * len(output.images)
+            has_nsfw_concept = [None] * len(outputs.images)
 
-        return output.images, has_nsfw_concept
+        return outputs.images, has_nsfw_concept
 
     def __str__(self) -> str:
         return f"TextToImagePipeline model_id={self.model_id}"
14 changes: 10 additions & 4 deletions runner/app/pipelines/upscale.py
@@ -12,6 +12,7 @@
     is_lightning_model,
     is_turbo_model,
 )
+from app.utils.errors import InferenceError
 from diffusers import StableDiffusionUpscalePipeline
 from huggingface_hub import file_download
 from PIL import ImageFile
@@ -114,14 +115,19 @@ def __call__(
         ):
             del kwargs["num_inference_steps"]
 
-        output = self.ldm(prompt, image=image, **kwargs)
+        try:
+            outputs = self.ldm(prompt, image=image, **kwargs)
+        except torch.cuda.OutOfMemoryError as e:
+            raise e
+        except Exception as e:
+            raise InferenceError(original_exception=e)
 
         if safety_check:
-            _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
+            _, has_nsfw_concept = self._safety_checker.check_nsfw_images(outputs.images)
         else:
-            has_nsfw_concept = [None] * len(output.images)
+            has_nsfw_concept = [None] * len(outputs.images)
 
-        return output.images, has_nsfw_concept
+        return outputs.images, has_nsfw_concept
 
     def __str__(self) -> str:
         return f"UpscalePipeline model_id={self.model_id}"
2 changes: 2 additions & 0 deletions runner/app/pipelines/utils/__init__.py
@@ -4,12 +4,14 @@
 
 from app.pipelines.utils.utils import (
     LoraLoader,
+    LoraLoadingError,
     SafetyChecker,
     get_model_dir,
     get_model_path,
     get_torch_device,
     is_lightning_model,
     is_turbo_model,
+    is_numeric,
     split_prompt,
     validate_torch_device,
 )
2 changes: 1 addition & 1 deletion runner/app/pipelines/utils/audio.py
@@ -11,7 +11,7 @@
 class AudioConversionError(Exception):
     """Raised when an audio file cannot be converted."""
 
-    def __init__(self, message="Audio conversion failed."):
+    def __init__(self, message="Audio conversion failed"):
         self.message = message
         super().__init__(self.message)
 
67 changes: 35 additions & 32 deletions runner/app/routes/audio_to_text.py
@@ -1,11 +1,17 @@
 import logging
 import os
-from typing import Annotated
+from typing import Annotated, Dict, Tuple, Union
 
+import torch
 from app.dependencies import get_pipeline
 from app.pipelines.base import Pipeline
 from app.pipelines.utils.audio import AudioConversionError
-from app.routes.util import HTTPError, TextResponse, file_exceeds_max_size, http_error
+from app.routes.utils import (
+    HTTPError,
+    TextResponse,
+    file_exceeds_max_size,
+    http_error,
+    handle_pipeline_exception,
+)
 from fastapi import APIRouter, Depends, File, Form, UploadFile, status
 from fastapi.responses import JSONResponse
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
@@ -14,6 +20,20 @@
 
 logger = logging.getLogger(__name__)
 
+# Pipeline specific error handling configuration.
+AUDIO_FORMAT_ERROR_MESSAGE = "Unsupported audio format or malformed file."
+PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = {
+    # Specific error types.
+    "AudioConversionError": (
+        AUDIO_FORMAT_ERROR_MESSAGE,
+        status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
+    ),
+    "Soundfile is either not in the correct format or is malformed": (
+        AUDIO_FORMAT_ERROR_MESSAGE,
+        status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
+    ),
+}
+
 RESPONSES = {
     status.HTTP_200_OK: {
         "content": {
@@ -27,35 +47,11 @@
     status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
     status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
     status.HTTP_413_REQUEST_ENTITY_TOO_LARGE: {"model": HTTPError},
+    status.HTTP_415_UNSUPPORTED_MEDIA_TYPE: {"model": HTTPError},
     status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
 }
 
 
-def handle_pipeline_error(e: Exception) -> JSONResponse:
-    """Handles exceptions raised during audio processing.
-
-    Args:
-        e: The exception raised during audio processing.
-
-    Returns:
-        A JSONResponse with the appropriate error message and status code.
-    """
-    logger.error(f"Audio processing error: {str(e)}")  # Log the detailed error
-    if "Soundfile is either not in the correct format or is malformed" in str(
-        e
-    ) or isinstance(e, AudioConversionError):
-        status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE
-        error_message = "Unsupported audio format or malformed file."
-    else:
-        status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
-        error_message = "Internal server error during audio processing."
-
-    return JSONResponse(
-        status_code=status_code,
-        content=http_error(error_message),
-    )
-
-
 @router.post(
     "/audio-to-text",
     response_model=TextResponse,
@@ -89,25 +85,32 @@ async def audio_to_text(
         return JSONResponse(
             status_code=status.HTTP_401_UNAUTHORIZED,
             headers={"WWW-Authenticate": "Bearer"},
-            content=http_error("Invalid bearer token"),
+            content=http_error("Invalid bearer token."),
         )
 
     if model_id != "" and model_id != pipeline.model_id:
         return JSONResponse(
             status_code=status.HTTP_400_BAD_REQUEST,
             content=http_error(
                 f"pipeline configured with {pipeline.model_id} but called with "
-                f"{model_id}"
+                f"{model_id}."
             ),
         )
 
     if file_exceeds_max_size(audio, 50 * 1024 * 1024):
         return JSONResponse(
             status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
-            content=http_error("File size exceeds limit"),
+            content=http_error("File size exceeds limit."),
         )
 
     try:
         return pipeline(audio=audio)
     except Exception as e:
-        return handle_pipeline_error(e)
+        if isinstance(e, torch.cuda.OutOfMemoryError):
+            torch.cuda.empty_cache()
+        logger.error(f"AudioToText pipeline error: {e}")
+        return handle_pipeline_exception(
+            e,
+            default_error_message="Audio-to-text pipeline error.",
+            custom_error_config=PIPELINE_ERROR_CONFIG,
+        )
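The routes in this commit delegate error mapping to a shared handle_pipeline_exception helper in app/routes/utils.py (the renamed app/routes/util.py), which is not shown in this excerpt. Judging from its call sites and the PIPELINE_ERROR_CONFIG format, which maps an exception class name or a message substring to a (message, status_code) pair, a sketch could look like this; the lookup order, the default_status_code parameter, and the None-keeps-default convention are assumptions:

# Hypothetical sketch of handle_pipeline_exception in app/routes/utils.py.
# The signature matches the call sites in this diff; the matching rules are
# inferred from the PIPELINE_ERROR_CONFIG entries registered by each route.
from typing import Dict, Optional, Tuple, Union

from fastapi import status
from fastapi.responses import JSONResponse


def http_error(msg: str) -> dict:
    # Minimal stand-in for the module's existing http_error helper.
    return {"detail": {"msg": msg}}


def handle_pipeline_exception(
    e: Exception,
    default_error_message: str = "Pipeline error.",
    default_status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR,
    custom_error_config: Optional[Dict[str, Tuple[Union[str, None], int]]] = None,
) -> JSONResponse:
    """Map a pipeline exception onto a JSON error response."""
    error_message = default_error_message
    status_code = default_status_code
    if custom_error_config:
        for key, (message, code) in custom_error_config.items():
            # Match the exception class name (e.g. "AudioConversionError") or
            # a substring of the exception message (e.g. the soundfile error).
            if key == type(e).__name__ or key in str(e):
                error_message = message if message is not None else default_error_message
                status_code = code
                break
    return JSONResponse(status_code=status_code, content=http_error(error_message))

With the audio-to-text configuration above, an AudioConversionError raised by the pipeline would, under these assumptions, produce a 415 response carrying the "Unsupported audio format or malformed file." message instead of a generic 500.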
44 changes: 26 additions & 18 deletions runner/app/routes/image_to_image.py
@@ -1,13 +1,18 @@
 import logging
 import os
 import random
-from typing import Annotated
+from typing import Annotated, Dict, Tuple, Union
 
 import torch
 from app.dependencies import get_pipeline
 from app.pipelines.base import Pipeline
-from app.pipelines.utils.utils import LoraLoadingError
-from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url
+from app.routes.utils import (
+    HTTPError,
+    ImageResponse,
+    http_error,
+    image_to_data_url,
+    handle_pipeline_exception,
+)
 from fastapi import APIRouter, Depends, File, Form, UploadFile, status
 from fastapi.responses import JSONResponse
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
@@ -20,6 +25,15 @@
 logger = logging.getLogger(__name__)
 
 
+# Pipeline specific error handling configuration.
+PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = {
+    # Specific error types.
+    "OutOfMemoryError": (
+        "Out of memory error. Try reducing input image resolution.",
+        status.HTTP_500_INTERNAL_SERVER_ERROR,
+    )
+}
+
 RESPONSES = {
     status.HTTP_200_OK: {
         "content": {
@@ -144,15 +158,15 @@ async def image_to_image(
         return JSONResponse(
             status_code=status.HTTP_401_UNAUTHORIZED,
             headers={"WWW-Authenticate": "Bearer"},
-            content=http_error("Invalid bearer token"),
+            content=http_error("Invalid bearer token."),
         )
 
     if model_id != "" and model_id != pipeline.model_id:
         return JSONResponse(
             status_code=status.HTTP_400_BAD_REQUEST,
             content=http_error(
                 f"pipeline configured with {pipeline.model_id} but called with "
-                f"{model_id}"
+                f"{model_id}."
             ),
         )
 
@@ -180,23 +194,17 @@ async def image_to_image(
                 num_images_per_prompt=1,
                 num_inference_steps=num_inference_steps,
             )
-            images.extend(imgs)
-            has_nsfw_concept.extend(nsfw_checks)
-        except LoraLoadingError as e:
-            logger.error(f"ImageToImagePipeline error: {e}")
-            return JSONResponse(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                content=http_error(str(e)),
-            )
         except Exception as e:
             if isinstance(e, torch.cuda.OutOfMemoryError):
                 torch.cuda.empty_cache()
-            logger.error(f"ImageToImagePipeline error: {e}")
-            logger.exception(e)
-            return JSONResponse(
-                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                content=http_error("ImageToImagePipeline error"),
+            logger.error(f"ImageToImagePipeline pipeline error: {e}")
+            return handle_pipeline_exception(
+                e,
+                default_error_message="Image-to-image pipeline error.",
+                custom_error_config=PIPELINE_ERROR_CONFIG,
             )
+        images.extend(imgs)
+        has_nsfw_concept.extend(nsfw_checks)
 
     # TODO: Return None once Go codegen tool supports optional properties
     # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
