Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: gemini: improve error messages and PDF validation #187

Merged
merged 4 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 35 additions & 18 deletions aidial_adapter_vertexai/chat/attachment_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from aidial_adapter_vertexai.chat.conversation.factory import (
ConversationFactoryBase,
)
from aidial_adapter_vertexai.chat.errors import ValidationError
from aidial_adapter_vertexai.chat.errors import UserError
from aidial_adapter_vertexai.chat.gemini.conversation_factory import PartT
from aidial_adapter_vertexai.dial_api.request import get_attachments
from aidial_adapter_vertexai.dial_api.resource import (
Expand Down Expand Up @@ -92,11 +92,18 @@ async def process(

except Exception as e:
log.error(
f"Failed to download {dial_resource.entity_name}: {str(e)}"
f"Failed to process {dial_resource.entity_name}: {str(e)}"
)
if isinstance(e, ResourceValidationError):
# Errors specific to a particular resource
return e.message
return f"Failed to download {dial_resource.entity_name}"
elif isinstance(e, UserError):
# Errors not specific to any particular resource
# typically raised by validators
raise e
else:
# Unexpected runtime exceptions
return f"Failed to process {dial_resource.entity_name}"


@dataclass(order=True, frozen=True)
Expand Down Expand Up @@ -153,7 +160,7 @@ async def process_resource(
self, dial_resource: DialResource
) -> Resource | None:
if not self.processors:
raise ValidationError("The attachments aren't supported")
raise UserError("The attachments aren't supported")

for processor in self.processors:
resource = await processor.process(self.file_storage, dial_resource)
Expand Down Expand Up @@ -218,36 +225,47 @@ class AttachmentProcessorsGenAI(AttachmentProcessorsBase[GenAIPart]):
pass


def max_count_validator(limit: int) -> InitValidator:
def max_count_validator(category: str, limit: int) -> InitValidator:
count = 0

async def validator():
nonlocal count
count += 1
if count > limit:
raise ValidationError(
f"The number of files exceeds the limit ({limit})"
raise UserError(
f"The number of {category} files exceeds the limit ({limit})"
)

return validator


def max_pdf_page_count_validator(limit: int) -> PostValidator:
count = 0
def max_pdf_page_count_validator(
limit_per_request: int | None,
limit_per_document: int | None,
) -> PostValidator:
total_pages = 0

async def validator(resource: Resource):
nonlocal count
if limit_per_document is None and limit_per_request is None:
return

nonlocal total_pages
try:
pages = await get_pdf_page_count(resource.data)
log.debug(f"PDF page count: {pages}")
count += pages
total_pages += pages
except Exception:
log.exception("Failed to get PDF page count")
raise ValidationError("Failed to get PDF page count")
raise ResourceValidationError("Failed to get PDF page count")

if count > limit:
raise ValidationError(
f"The total number of PDF pages exceeds the limit ({limit})"
if limit_per_document is not None and pages > limit_per_document:
raise ResourceValidationError(
f"The number of pages in the document ({pages}) exceeds the limit ({limit_per_document})"
)

if limit_per_request is not None and total_pages > limit_per_request:
raise UserError(
f"The total number of pages in PDF documents exceeds the limit ({limit_per_request})"
)

return validator
Expand All @@ -274,9 +292,8 @@ async def validator():
if first is None:
first = name
elif first != name:
raise ValidationError(
f"The document type is {name!r}. "
f"However, one of the documents processed earlier was of {first!r} type. "
raise UserError(
f"Found documents of types {name!r} and {first!r}. "
"Only one type of document is supported at a time."
)

Expand Down
2 changes: 1 addition & 1 deletion aidial_adapter_vertexai/chat/claude/prompt/claude_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _create_image_processor(max_count: int) -> AttachmentProcessor:
# NOTE: not checked condition: The maximum allowed image file size is 5 MB
return AttachmentProcessor(
file_types=SUPPORTED_IMAGE_TYPES,
init_validator=max_count_validator(max_count),
init_validator=max_count_validator("image", max_count),
)


Expand Down
16 changes: 11 additions & 5 deletions aidial_adapter_vertexai/chat/gemini/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_image_processor(
"image/heif": "heif",
},
init_validator=seq_validators(
init_validator, max_count_validator(max_count)
init_validator, max_count_validator("image", max_count)
),
)

Expand Down Expand Up @@ -91,16 +91,22 @@ def get_audio_processor(

# PDF processing
# 1.0: max number of PDF pages: 16
# 1.5: max number of PDF pages: 300
# 1.5: max number of PDF pages: 1000
# The maximum file size for a PDF is 50MB (not checked).
# PDF pages are treated as individual images.
def get_pdf_processor(
max_page_count: int, init_validator: InitValidator | None = None
*,
page_limit_per_document: int | None = None,
page_limit_per_request: int | None = None,
init_validator: InitValidator | None = None,
) -> AttachmentProcessor:
return AttachmentProcessor(
file_types={"application/pdf": "pdf"},
init_validator=init_validator,
post_validator=max_pdf_page_count_validator(max_page_count),
post_validator=max_pdf_page_count_validator(
limit_per_request=page_limit_per_request,
limit_per_document=page_limit_per_document,
),
)


Expand Down Expand Up @@ -134,6 +140,6 @@ def get_video_processor(
"video/3gpp": "3gpp",
},
init_validator=seq_validators(
init_validator, max_count_validator(max_count)
init_validator, max_count_validator("video", max_count)
),
)
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ async def parse(
processors=[
get_plain_text_processor(),
get_image_processor(16, exclusive("image")),
get_pdf_processor(16, exclusive("pdf")),
get_pdf_processor(
page_limit_per_request=16,
page_limit_per_document=16,
init_validator=exclusive("pdf"),
),
get_video_processor(1, exclusive("video")),
],
file_storage=file_storage,
Expand Down
5 changes: 4 additions & 1 deletion aidial_adapter_vertexai/chat/gemini/prompt/gemini_1_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ async def parse(
processors=[
get_plain_text_processor(),
get_image_processor(3000),
get_pdf_processor(300),
get_pdf_processor(
page_limit_per_request=3000,
page_limit_per_document=1000,
),
get_video_processor(10),
get_audio_processor(),
],
Expand Down
5 changes: 4 additions & 1 deletion aidial_adapter_vertexai/chat/gemini/prompt/gemini_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ async def parse(
processors=[
get_plain_text_processor(),
get_image_processor(3000),
get_pdf_processor(300),
get_pdf_processor(
page_limit_per_request=3000,
page_limit_per_document=1000,
),
get_video_processor(10),
get_audio_processor(),
],
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ nox = "^2023.4.22"
google-auth-oauthlib = "1.0.0"

[tool.pytest.ini_options]
addopts = "-n=auto --asyncio-mode=auto"
addopts = "-n=auto --asyncio-mode=auto --self-contained-html --html=pytest.html"
env_override_existing_values = 1
filterwarnings = [
# muting warnings from google.rpc https://github.com/googleapis/google-cloud-python/issues/11184 and opentelemetry packages
Expand Down
49 changes: 49 additions & 0 deletions tests/integration_tests/test_chat_completion_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
from openai.types.chat import ChatCompletionMessageParam

from aidial_adapter_vertexai.deployments import ChatCompletionDeployment
from aidial_adapter_vertexai.utils.resource import Resource
from tests.utils.openai import (
ChatCompletionResult,
ai,
chat_completion,
sanitize_test_name,
sys,
user,
user_with_attachment_data,
)
from tests.utils.pdf import gen_pdf

deployments = [
ChatCompletionDeployment.CHAT_BISON_1,
Expand Down Expand Up @@ -167,3 +170,49 @@ async def test_imagen_content_filtering(get_openai_client):
resp["error"]["message"]
== "The response is blocked, as it may violate our policies."
)


async def test_gemini_pdf_page_overflow_for_document(get_openai_client):
    """Send one 2000-page PDF and expect a per-file page-limit error."""
    deployment_id = ChatCompletionDeployment.GEMINI_PRO_1_5_V2.value
    client = get_openai_client(deployment_id)

    # One generated page per list item, so this document has 2000 pages.
    oversized_pdf = Resource(
        type="application/pdf", data=gen_pdf(["a"] * 2_000)
    )
    messages: List[ChatCompletionMessageParam] = [
        user_with_attachment_data("test", oversized_pdf)
    ]

    with pytest.raises(Exception) as exc_info:
        await chat_completion(
            client, messages, False, None, None, None, None, None, None
        )

    raised = exc_info.value
    assert isinstance(raised, UnprocessableEntityError)

    payload = raised.response.json()["error"]
    expected = "The following files failed to process:\n1. data attachment: the number of pages in the document (2000) exceeds the limit (1000)"
    # Both the technical and the user-facing message must carry the same text.
    assert payload["message"] == expected
    assert payload["display_message"] == expected


async def test_gemini_pdf_page_overflow_for_request(get_openai_client):
    """Send four 1000-page PDFs and expect a request-wide page-limit error."""
    deployment_id = ChatCompletionDeployment.GEMINI_PRO_1_5_V2.value
    client = get_openai_client(deployment_id)

    # One generated page per list item, so each copy has 1000 pages;
    # four copies are attached to a single user message.
    thousand_page_pdf = Resource(
        type="application/pdf", data=gen_pdf(["a"] * 1_000)
    )
    attachments = [thousand_page_pdf] * 4
    messages: List[ChatCompletionMessageParam] = [
        user_with_attachment_data("test", *attachments)
    ]

    with pytest.raises(Exception) as exc_info:
        await chat_completion(
            client, messages, False, None, None, None, None, None, None
        )

    raised = exc_info.value
    assert isinstance(raised, UnprocessableEntityError)

    payload = raised.response.json()["error"]
    expected = (
        "The total number of pages in PDF documents exceeds the limit (3000)"
    )
    # Both the technical and the user-facing message must carry the same text.
    assert payload["message"] == expected
    assert payload["display_message"] == expected
4 changes: 2 additions & 2 deletions tests/utils/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ def user(


def user_with_attachment_data(
content: str, resource: Resource
content: str, *resource: Resource
) -> ChatCompletionUserMessageParam:
return {
"role": "user",
"content": content,
"custom_content": { # type: ignore
"attachments": [
{"type": resource.type, "data": resource.data_base64}
{"type": r.type, "data": r.data_base64} for r in resource
]
},
}
Expand Down
64 changes: 64 additions & 0 deletions tests/utils/pdf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from io import BytesIO
from typing import List

from pypdf import PageObject, PdfWriter
from pypdf.generic import DecodedStreamObject, DictionaryObject, NameObject


def _escape_pdf_string(s: str) -> str:
    """Escape *s* so it can be placed inside a PDF literal string ``(...)``.

    A backslash is put in front of every backslash, opening parenthesis
    and closing parenthesis; all other characters are passed through.
    """
    escape_table = str.maketrans({"\\": "\\\\", "(": "\\(", ")": "\\)"})
    return s.translate(escape_table)


def _create_page_with_text(
    text: str, width: int = 612, height: int = 792
) -> PageObject:
    """Build a blank page of the given size and draw *text* on it.

    The text is split with ``str.splitlines`` and each line is shown with
    the ``/F1`` font (declared below as Type1 Helvetica) at size 12,
    starting 72 units from the left and top edges and stepping down 14
    units per line.

    Args:
        text: The text to draw; may be empty or contain several lines.
        width: Page width passed to ``PageObject.create_blank_page``.
        height: Page height passed to ``PageObject.create_blank_page``.

    Returns:
        The page with its ``/Contents`` stream and ``/Resources`` font
        dictionary set.
    """
    page = PageObject.create_blank_page(width=width, height=height)

    left_margin = 72
    top_margin = 72
    line_height = 14

    operators = ["BT", "/F1 12 Tf"]
    for index, line in enumerate(text.splitlines()):
        if index == 0:
            # The text matrix is the identity right after BT, so this first
            # Td places the line at an absolute position on the page.
            operators.append(f"{left_margin} {height - top_margin} Td")
        else:
            # Td offsets are relative to the start of the current line, so
            # later lines only move down by one line height. Repeating the
            # absolute coordinates here would push them off the page.
            operators.append(f"0 -{line_height} Td")
        operators.append(f"({_escape_pdf_string(line)}) Tj")
    operators.append("ET")
    content_stream = "\n".join(operators) + "\n"

    stream = DecodedStreamObject()
    stream.set_data(content_stream.encode("utf-8"))
    page[NameObject("/Contents")] = stream

    # Declare the /F1 font referenced by the Tf operator above.
    font = DictionaryObject(
        {
            NameObject("/Type"): NameObject("/Font"),
            NameObject("/Subtype"): NameObject("/Type1"),
            NameObject("/BaseFont"): NameObject("/Helvetica"),
        }
    )
    font_dict = DictionaryObject()
    font_dict[NameObject("/F1")] = font
    resources = DictionaryObject()
    resources[NameObject("/Font")] = font_dict
    page[NameObject("/Resources")] = resources

    return page


def gen_pdf(pages: List[str]) -> bytes:
    """
    Serialize a PDF document that has one page per string in *pages*.

    Each string becomes the text of its own page, in list order; the
    finished document is returned as raw bytes.
    """
    pdf_writer = PdfWriter()
    for page_text in pages:
        pdf_writer.add_page(_create_page_with_text(page_text))

    # Write into an in-memory buffer and hand back its contents.
    with BytesIO() as buffer:
        pdf_writer.write(buffer)
        return buffer.getvalue()