Skip to content

Commit

Permalink
handle openai sdk objects + fix _mk_img bug
Browse files Browse the repository at this point in the history
  • Loading branch information
comhar committed Oct 15, 2024
1 parent efedcb3 commit 9d8d3f6
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 25 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

### Installation

Install the latest version from the GitHub
[repository](https://github.com/AnswerDotAI/msglm):
Install the latest version from pypi

``` sh
$ pip install git+ssh://[email protected]/AnswerDotAI/msglm.git
$ pip install msglm
```

## Usage
Expand Down
3 changes: 3 additions & 0 deletions msglm/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@
'syms': { 'msglm.core': { 'msglm.core.AnthropicMsg': ('core.html#anthropicmsg', 'msglm/core.py'),
'msglm.core.AnthropicMsg.find_block': ('core.html#anthropicmsg.find_block', 'msglm/core.py'),
'msglm.core.AnthropicMsg.img_msg': ('core.html#anthropicmsg.img_msg', 'msglm/core.py'),
'msglm.core.AnthropicMsg.is_sdk_obj': ('core.html#anthropicmsg.is_sdk_obj', 'msglm/core.py'),
'msglm.core.Msg': ('core.html#msg', 'msglm/core.py'),
'msglm.core.Msg.__call__': ('core.html#msg.__call__', 'msglm/core.py'),
'msglm.core.Msg.find_block': ('core.html#msg.find_block', 'msglm/core.py'),
'msglm.core.Msg.img_msg': ('core.html#msg.img_msg', 'msglm/core.py'),
'msglm.core.Msg.is_sdk_obj': ('core.html#msg.is_sdk_obj', 'msglm/core.py'),
'msglm.core.Msg.mk_content': ('core.html#msg.mk_content', 'msglm/core.py'),
'msglm.core.Msg.text_msg': ('core.html#msg.text_msg', 'msglm/core.py'),
'msglm.core.OpenAiMsg': ('core.html#openaimsg', 'msglm/core.py'),
'msglm.core.OpenAiMsg.find_block': ('core.html#openaimsg.find_block', 'msglm/core.py'),
'msglm.core.OpenAiMsg.img_msg': ('core.html#openaimsg.img_msg', 'msglm/core.py'),
'msglm.core.OpenAiMsg.is_sdk_obj': ('core.html#openaimsg.is_sdk_obj', 'msglm/core.py'),
'msglm.core._add_cache_control': ('core.html#_add_cache_control', 'msglm/core.py'),
'msglm.core._mk_img': ('core.html#_mk_img', 'msglm/core.py'),
'msglm.core.mk_msg': ('core.html#mk_msg', 'msglm/core.py'),
Expand Down
26 changes: 21 additions & 5 deletions msglm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,20 @@ def mk_msgs(msgs: list, *args, api:str="openai", **kw) -> list:
# %% ../nbs/00_core.ipynb
class Msg:
"Helper class to create a message for the OpenAI and Anthropic APIs."
sdk_obj_support=False # is an SDK object a valid message?
def __call__(self, role:str, content:[list, str], text_only:bool=False, **kw)->dict:
"Create an OpenAI/Anthropic compatible message with `role` and `content`."
if self.sdk_obj_support and self.is_sdk_obj(content): return self.find_block(content)
if hasattr(content, "content"): content, role = content.content, content.role
content = self.find_block(content)
if content is not None and not isinstance(content, list): content = [content]
content = [self.mk_content(o, text_only=text_only) for o in content] if content else ''
return dict(role=role, content=content[0] if text_only else content, **kw)

def is_sdk_obj(self, r)-> bool:
"Check if `r` is an SDK object."
raise NotImplemented

def find_block(self, r)->dict:
"Find the message in `r`."
raise NotImplemented
Expand All @@ -66,30 +72,40 @@ def mk_content(self, content:[str, bytes], text_only:bool=False) -> dict:

# %% ../nbs/00_core.ipynb
class AnthropicMsg(Msg):
sdk_obj_support=False
def img_msg(self, data: bytes) -> dict:
"Convert `data` to an image message"
img, mtype = mk_img(data)
img, mtype = _mk_img(data)
r = {"type": "base64", "media_type": mtype, "data":img}
return {"type": "image", "source": r}

def is_sdk_obj(self, r)-> bool:
"Check if `r` is an SDK object."
return isinstance(r, abc.Mapping)

def find_block(self, r):
"Find the message in `r`."
return r.get('content', r) if isinstance(r, abc.Mapping) else r
return r.get('content', r) if self.is_sdk_obj(r) else r

# %% ../nbs/00_core.ipynb
class OpenAiMsg(Msg):
sdk_obj_support=True
def img_msg(self, data: bytes) -> dict:
"Convert `data` to an image message"
img, mtype = mk_img(data)
img, mtype = _mk_img(data)
r = {"url": f"data:{mtype};base64,{img}"}
return {"type": "image_url", "image_url": r}

def is_sdk_obj(self, r)-> bool:
"Check if `r` is an SDK object."
return type(r).__module__ != "builtins"

def find_block(self, r):
"Find the message in `r`."
if type(r).__module__ == "builtins": return r
if not self.is_sdk_obj(r): return r
m = nested_idx(r, "choices", 0)
if not m: return m
if hasattr(m, "message"): return m.message.content
if hasattr(m, "message"): return m.message
return m.delta

# %% ../nbs/00_core.ipynb
Expand Down
36 changes: 27 additions & 9 deletions nbs/00_core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -895,7 +895,7 @@
" \"Helper class to create a message for the OpenAI API.\"\n",
" def img_msg(self, data)->dict:\n",
" \"Convert `data` to an image message\"\n",
" img, mtype = mk_img(data)\n",
" img, mtype = _mk_img(data)\n",
" r = {\"url\": f\"data:{mtype};base64,{img}\"}\n",
" return {\"type\": \"image_url\", \"image_url\": r}"
]
Expand All @@ -910,7 +910,7 @@
" \"Helper class to create a message for the Anthropic API.\"\n",
" def img_msg(self, data)->dict:\n",
" \"Convert `data` to an image message\"\n",
" img, mtype = mk_img(data)\n",
" img, mtype = _mk_img(data)\n",
" r = dict(type='base64', media_type=mtype, data=img)\n",
" return {\"type\": \"image\", \"source\": r}"
]
Expand Down Expand Up @@ -1027,14 +1027,16 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### API Response"
"### SDK Objects"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make our lives even easier, it would be nice if `mk_msg` could format the raw API response from a previous API call so that we can pass it straight to `mk_msgs`."
"To make our lives even easier, it would be nice if `mk_msg` could format the SDK objects returned from a previous chat so that we can pass them straight to `mk_msgs`.\n",
"\n",
"The OpenAI SDK accepts objects like `ChatCompletion` as messages. Anthropic is different and expects every message to have the `role`, `content` format that we've seen so far."
]
},
{
Expand All @@ -1046,14 +1048,20 @@
"#| export\n",
"class Msg:\n",
" \"Helper class to create a message for the OpenAI and Anthropic APIs.\"\n",
" sdk_obj_support=False # is an SDK object a valid message?\n",
" def __call__(self, role:str, content:[list, str], text_only:bool=False, **kw)->dict:\n",
" \"Create an OpenAI/Anthropic compatible message with `role` and `content`.\"\n",
" if self.sdk_obj_support and self.is_sdk_obj(content): return self.find_block(content)\n",
" if hasattr(content, \"content\"): content, role = content.content, content.role\n",
" content = self.find_block(content)\n",
" if content is not None and not isinstance(content, list): content = [content]\n",
" content = [self.mk_content(o, text_only=text_only) for o in content] if content else ''\n",
" return dict(role=role, content=content[0] if text_only else content, **kw)\n",
"\n",
" def is_sdk_obj(self, r)-> bool:\n",
" \"Check if `r` is an SDK object.\"\n",
" raise NotImplemented\n",
" \n",
" def find_block(self, r)->dict:\n",
" \"Find the message in `r`.\"\n",
" raise NotImplemented\n",
Expand Down Expand Up @@ -1081,15 +1089,20 @@
"source": [
"#| export\n",
"class AnthropicMsg(Msg):\n",
" sdk_obj_support=False\n",
" def img_msg(self, data: bytes) -> dict:\n",
" \"Convert `data` to an image message\"\n",
" img, mtype = mk_img(data)\n",
" img, mtype = _mk_img(data)\n",
" r = {\"type\": \"base64\", \"media_type\": mtype, \"data\":img}\n",
" return {\"type\": \"image\", \"source\": r}\n",
"\n",
" def is_sdk_obj(self, r)-> bool:\n",
" \"Check if `r` is an SDK object.\"\n",
" return isinstance(r, abc.Mapping)\n",
"\n",
" def find_block(self, r):\n",
" \"Find the message in `r`.\"\n",
" return r.get('content', r) if isinstance(r, abc.Mapping) else r"
" return r.get('content', r) if self.is_sdk_obj(r) else r"
]
},
{
Expand All @@ -1100,18 +1113,23 @@
"source": [
"#| export\n",
"class OpenAiMsg(Msg):\n",
" sdk_obj_support=True\n",
" def img_msg(self, data: bytes) -> dict:\n",
" \"Convert `data` to an image message\"\n",
" img, mtype = mk_img(data)\n",
" img, mtype = _mk_img(data)\n",
" r = {\"url\": f\"data:{mtype};base64,{img}\"}\n",
" return {\"type\": \"image_url\", \"image_url\": r}\n",
"\n",
" def is_sdk_obj(self, r)-> bool:\n",
" \"Check if `r` is an SDK object.\"\n",
" return type(r).__module__ != \"builtins\"\n",
"\n",
" def find_block(self, r):\n",
" \"Find the message in `r`.\"\n",
" if type(r).__module__ == \"builtins\": return r\n",
" if not self.is_sdk_obj(r): return r\n",
" m = nested_idx(r, \"choices\", 0)\n",
" if not m: return m\n",
" if hasattr(m, \"message\"): return m.message.content\n",
" if hasattr(m, \"message\"): return m.message\n",
" return m.delta"
]
},
Expand Down
12 changes: 4 additions & 8 deletions nbs/index.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,17 @@
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Installation"
]
"source": "### Installation"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Install the latest version from the GitHub [repository][repo]:\n",
"Install the latest version from pypi\n",
"\n",
"```sh\n",
"$ pip install git+ssh://[email protected]/AnswerDotAI/msglm.git\n",
"```\n",
"\n",
"[repo]: https://github.com/AnswerDotAI/msglm"
"$ pip install msglm\n",
"```\n"
]
},
{
Expand Down

0 comments on commit 9d8d3f6

Please sign in to comment.