|
| 1 | +import base64 |
1 | 2 | from abc import ABC
|
2 | 3 | from typing import Optional, Union, List, Literal, Type, Any
|
3 | 4 |
|
4 | 5 | from docstring_parser import parse
|
5 | 6 | from json_repair import repair_json
|
6 | 7 | from pydantic import ConfigDict, BaseModel, Field, field_validator, model_validator
|
7 | 8 |
|
| 9 | +from llmkira.openai.utils import resize_openai_image |
| 10 | + |
8 | 11 |
|
9 | 12 | class FunctionChoice(BaseModel):
|
10 | 13 | name: str
|
@@ -170,11 +173,91 @@ class SystemMessage(Message):
|
170 | 173 | name: Optional[str] = None
|
171 | 174 |
|
172 | 175 |
|
| 176 | +class ImageContent(BaseModel): |
| 177 | + url: str |
| 178 | + detail: Optional[str] = "auto" |
| 179 | + |
| 180 | + |
| 181 | +class ContentPart(BaseModel): |
| 182 | + type: Union[str, Literal["text", "image_url"]] |
| 183 | + text: Optional[str] = None |
| 184 | + image_url: Optional[ImageContent] = None |
| 185 | + |
| 186 | + @model_validator(mode="after") |
| 187 | + def check_model(self): |
| 188 | + if self.type == "image_url": |
| 189 | + if self.image_url is None: |
| 190 | + raise ValueError("image_url cannot be None") |
| 191 | + if self.type == "text": |
| 192 | + if self.text is None: |
| 193 | + raise ValueError("text cannot be None") |
| 194 | + return self |
| 195 | + |
| 196 | + @classmethod |
| 197 | + def create_text(cls, text: str): |
| 198 | + """ |
| 199 | + Create a text content part |
| 200 | + :param text: text |
| 201 | + :return: ContentPart |
| 202 | + """ |
| 203 | + assert isinstance(text, str), ValueError("text must be a string") |
| 204 | + return cls(type="text", text=text) |
| 205 | + |
| 206 | + @classmethod |
| 207 | + def create_image( |
| 208 | + cls, url: Union[str, bytes], detail: Literal["low", "high", "auto"] = "auto" |
| 209 | + ): |
| 210 | + """ |
| 211 | + Create an image content part |
| 212 | + :param url: image url or image bytes |
| 213 | + :param detail: image detail |
| 214 | + :return: ContentPart |
| 215 | + """ |
| 216 | + assert detail in ("low", "high", "auto"), ValueError( |
| 217 | + "detail must be low, high or auto" |
| 218 | + ) |
| 219 | + if isinstance(url, bytes): |
| 220 | + url = resize_openai_image(url, mode=detail) |
| 221 | + base64_image = base64.b64encode(url).decode("utf-8") |
| 222 | + url = f"data:image/jpeg;base64,{base64_image}" |
| 223 | + elif isinstance(url, str): |
| 224 | + if not url.startswith("http") or not url.startswith( |
| 225 | + "data:image/jpeg;base64," |
| 226 | + ): |
| 227 | + raise ValueError( |
| 228 | + "url must be a http url or `data:image/jpeg;base64,` as base64 image" |
| 229 | + ) |
| 230 | + else: |
| 231 | + raise ValueError("url must be a http url or bytes") |
| 232 | + return cls(type="image_url", image_url=ImageContent(url=url, detail=detail)) |
| 233 | + |
| 234 | + |
173 | 235 | class UserMessage(Message):
|
174 | 236 | role: Literal["user"] = "user"
|
175 |
| - content: str |
| 237 | + content: Union[str, List[ContentPart]] |
176 | 238 | name: Optional[str] = None
|
177 | 239 |
|
| 240 | + @field_validator("content") |
| 241 | + def check_content(cls, v): |
| 242 | + if isinstance(v, str): |
| 243 | + return [ContentPart.create_text(text=v)] |
| 244 | + elif isinstance(v, list): |
| 245 | + return v |
| 246 | + else: |
| 247 | + raise ValueError("content must be a string or a list of ContentPart") |
| 248 | + |
| 249 | + def add_text(self, text: str): |
| 250 | + self.content.append(ContentPart.create_text(text=text)) |
| 251 | + return self |
| 252 | + |
| 253 | + def add_image( |
| 254 | + self, |
| 255 | + image_url: Union[str, bytes], |
| 256 | + detail: Literal["low", "high", "auto"] = "auto", |
| 257 | + ): |
| 258 | + self.content.append(ContentPart.create_image(url=image_url, detail=detail)) |
| 259 | + return self |
| 260 | + |
178 | 261 |
|
179 | 262 | class ToolMessage(Message):
|
180 | 263 | role: Literal["tool"] = "tool"
|
|
0 commit comments