Skip to content

Commit

Permalink
bugfix: fix passing Image type in messages for chat (#390)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Aarni Koskela <[email protected]>
  • Loading branch information
ParthSareen and akx authored Dec 29, 2024
1 parent 7d1e002 commit ee349ec
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 124 deletions.
11 changes: 8 additions & 3 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def generate(
stream=stream,
raw=raw,
format=format,
images=[Image(value=image) for image in images] if images else None,
images=[image for image in _copy_images(images)] if images else None,
options=options,
keep_alive=keep_alive,
).model_dump(exclude_none=True),
Expand Down Expand Up @@ -753,7 +753,7 @@ async def generate(
stream=stream,
raw=raw,
format=format,
images=[Image(value=image) for image in images] if images else None,
images=[image for image in _copy_images(images)] if images else None,
options=options,
keep_alive=keep_alive,
).model_dump(exclude_none=True),
Expand Down Expand Up @@ -1121,10 +1121,15 @@ async def ps(self) -> ProcessResponse:
)


def _copy_images(images: Optional[Sequence[Union[Image, Any]]]) -> Iterator[Image]:
for image in images or []:
yield image if isinstance(image, Image) else Image(value=image)


def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]]]) -> Iterator[Message]:
for message in messages or []:
yield Message.model_validate(
{k: [Image(value=image) for image in v] if k == 'images' else v for k, v in dict(message).items() if v},
{k: [image for image in _copy_images(v)] if k == 'images' else v for k, v in dict(message).items() if v},
)


Expand Down
101 changes: 2 additions & 99 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ pytest = ">=7.4.3,<9.0.0"
pytest-asyncio = ">=0.23.2,<0.25.0"
pytest-cov = ">=4.1,<6.0"
pytest-httpserver = "^1.0.8"
pillow = "^10.2.0"
ruff = ">=0.1.8,<0.8.0"

[build-system]
Expand Down
Loading

0 comments on commit ee349ec

Please sign in to comment.