Skip to content

Commit

Permalink
Merge pull request #12 from xl0/stop_reason
Browse files Browse the repository at this point in the history
Better support for stop sequences
  • Loading branch information
jph00 authored Jul 30, 2024
2 parents 854e256 + 959fcbc commit c65e4ec
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 45 deletions.
85 changes: 82 additions & 3 deletions 00_core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "be54737f",
"id": "8015d3f3",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -888,6 +888,8 @@
" blk.text = prefill + (blk.text or '')\n",
" self.result = r\n",
" self.use += r.usage\n",
" self.stop_reason = r.stop_reason\n",
" self.stop_sequence = r.stop_sequence\n",
" return r"
]
},
Expand Down Expand Up @@ -958,7 +960,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "5fb96c43",
"id": "835638bb",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -972,10 +974,14 @@
" maxtok=4096, # Maximum tokens\n",
" prefill='', # Optional prefill to pass to Claude as start of its response\n",
" stream:bool=False, # Stream response?\n",
" stop=None, # Stop sequence\n",
" **kwargs):\n",
" \"Make a call to Claude.\"\n",
" pref = [prefill.strip()] if prefill else []\n",
" if not isinstance(msgs,list): msgs = [msgs]\n",
" if stop is not None:\n",
" if not isinstance(stop, (list)): stop = [stop]\n",
" kwargs[\"stop_sequences\"] = stop\n",
" msgs = mk_msgs(msgs+pref)\n",
" if stream: return self._stream(msgs, prefill=prefill, max_tokens=maxtok, system=sp, temperature=temp, **kwargs)\n",
" res = self.c.messages.create(\n",
Expand Down Expand Up @@ -1224,6 +1230,79 @@
"c.use"
]
},
{
"cell_type": "markdown",
"id": "bd556e49",
"metadata": {},
"source": [
"Pass a stop seauence if you want claude to stop generating text when it encounters it.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12b17591",
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"1, 2, 3, 4,\n",
"\n",
"<details>\n",
"\n",
"- id: msg_01YGeu3qQNoNwafG66kB8V51\n",
"- content: [{'text': '1, 2, 3, 4, ', 'type': 'text'}]\n",
"- model: claude-3-haiku-20240307\n",
"- role: assistant\n",
"- stop_reason: stop_sequence\n",
"- stop_sequence: 5\n",
"- type: message\n",
"- usage: {'input_tokens': 15, 'output_tokens': 14}\n",
"\n",
"</details>"
],
"text/plain": [
"Message(id='msg_01YGeu3qQNoNwafG66kB8V51', content=[TextBlock(text='1, 2, 3, 4, ', type='text')], model='claude-3-haiku-20240307', role='assistant', stop_reason='stop_sequence', stop_sequence='5', type='message', usage=In: 15; Out: 14; Total: 29)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c(\"Count from 1 to 10\", stop=\"5\")"
]
},
{
"cell_type": "markdown",
"id": "cdbd3df3",
"metadata": {},
"source": [
"This also works with streaming, and you can pass more than one stop sequence:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff50577d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1, stop_sequence 2\n"
]
}
],
"source": [
"for o in c(\"Count from 1 to 10\", stop=[\"2\", \"yellow\"], stream=True):\n",
" print(o, end='')\n",
"print(c.stop_reason, c.stop_sequence)"
]
},
{
"cell_type": "markdown",
"id": "1a7cdbc6",
Expand Down Expand Up @@ -1442,7 +1521,7 @@
" func = getattr(obj, fc.name, None)\n",
" if not func: func = ns[fc.name]\n",
" res = func(**fc.input)\n",
" return dict(type=\"tool_result\", tool_use_id=fc.id, content=str(res)) "
" return dict(type=\"tool_result\", tool_use_id=fc.id, content=str(res))"
]
},
{
Expand Down
26 changes: 0 additions & 26 deletions claudette/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,4 @@
'claudette.core.mk_toolres': ('core.html#mk_toolres', 'claudette/core.py'),
'claudette.core.text_msg': ('core.html#text_msg', 'claudette/core.py'),
'claudette.core.usage': ('core.html#usage', 'claudette/core.py')},
'claudette.core2': { 'claudette.core2.Chat': ('core.html#chat', 'claudette/core2.py'),
'claudette.core2.Chat.__call__': ('core.html#chat.__call__', 'claudette/core2.py'),
'claudette.core2.Chat.__init__': ('core.html#chat.__init__', 'claudette/core2.py'),
'claudette.core2.Chat._stream': ('core.html#chat._stream', 'claudette/core2.py'),
'claudette.core2.Chat.use': ('core.html#chat.use', 'claudette/core2.py'),
'claudette.core2.Client': ('core.html#client', 'claudette/core2.py'),
'claudette.core2.Client.__call__': ('core.html#client.__call__', 'claudette/core2.py'),
'claudette.core2.Client.__init__': ('core.html#client.__init__', 'claudette/core2.py'),
'claudette.core2.Client._r': ('core.html#client._r', 'claudette/core2.py'),
'claudette.core2.Client._stream': ('core.html#client._stream', 'claudette/core2.py'),
'claudette.core2.Message._repr_markdown_': ('core.html#message._repr_markdown_', 'claudette/core2.py'),
'claudette.core2.Usage.__add__': ('core.html#usage.__add__', 'claudette/core2.py'),
'claudette.core2.Usage.__repr__': ('core.html#usage.__repr__', 'claudette/core2.py'),
'claudette.core2.Usage.total': ('core.html#usage.total', 'claudette/core2.py'),
'claudette.core2._mk_content': ('core.html#_mk_content', 'claudette/core2.py'),
'claudette.core2._mk_ns': ('core.html#_mk_ns', 'claudette/core2.py'),
'claudette.core2.call_func': ('core.html#call_func', 'claudette/core2.py'),
'claudette.core2.contents': ('core.html#contents', 'claudette/core2.py'),
'claudette.core2.find_block': ('core.html#find_block', 'claudette/core2.py'),
'claudette.core2.img_msg': ('core.html#img_msg', 'claudette/core2.py'),
'claudette.core2.mk_msg': ('core.html#mk_msg', 'claudette/core2.py'),
'claudette.core2.mk_msgs': ('core.html#mk_msgs', 'claudette/core2.py'),
'claudette.core2.mk_tool_choice': ('core.html#mk_tool_choice', 'claudette/core2.py'),
'claudette.core2.mk_toolres': ('core.html#mk_toolres', 'claudette/core2.py'),
'claudette.core2.text_msg': ('core.html#text_msg', 'claudette/core2.py'),
'claudette.core2.usage': ('core.html#usage', 'claudette/core2.py')},
'claudette.toolloop': {'claudette.toolloop.Chat.toolloop': ('toolloop.html#chat.toolloop', 'claudette/toolloop.py')}}}
34 changes: 20 additions & 14 deletions claudette/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def _r(self:Client, r:Message, prefill=''):
blk.text = prefill + (blk.text or '')
self.result = r
self.use += r.usage
self.stop_reason = r.stop_reason
self.stop_sequence = r.stop_sequence
return r

# %% ../00_core.ipynb 66
Expand All @@ -116,28 +118,32 @@ def __call__(self:Client,
maxtok=4096, # Maximum tokens
prefill='', # Optional prefill to pass to Claude as start of its response
stream:bool=False, # Stream response?
stop=None, # Stop sequence
**kwargs):
"Make a call to Claude."
pref = [prefill.strip()] if prefill else []
if not isinstance(msgs,list): msgs = [msgs]
if stop is not None:
if not isinstance(stop, (list)): stop = [stop]
kwargs["stop_sequences"] = stop
msgs = mk_msgs(msgs+pref)
if stream: return self._stream(msgs, prefill=prefill, max_tokens=maxtok, system=sp, temperature=temp, **kwargs)
res = self.c.messages.create(
model=self.model, messages=msgs, max_tokens=maxtok, system=sp, temperature=temp, **kwargs)
self._r(res, prefill)
return self.result

# %% ../00_core.ipynb 84
# %% ../00_core.ipynb 88
def mk_tool_choice(choose:Union[str,bool,None])->dict:
"Create a `tool_choice` dict that's 'auto' if `choose` is `None`, 'any' if it is True, or 'tool' otherwise"
return {"type": "tool", "name": choose} if isinstance(choose,str) else {'type':'any'} if choose else {'type':'auto'}

# %% ../00_core.ipynb 94
# %% ../00_core.ipynb 98
def _mk_ns(*funcs:list[callable]) -> dict[str,callable]:
"Create a `dict` of name to function in `funcs`, to use as a namespace"
return {f.__name__:f for f in funcs}

# %% ../00_core.ipynb 96
# %% ../00_core.ipynb 100
def call_func(fc:ToolUseBlock, # Tool use block from Claude's message
ns:Optional[abc.Mapping]=None, # Namespace to search for tools, defaults to `globals()`
obj:Optional=None # Object to search for tools
Expand All @@ -148,9 +154,9 @@ def call_func(fc:ToolUseBlock, # Tool use block from Claude's message
func = getattr(obj, fc.name, None)
if not func: func = ns[fc.name]
res = func(**fc.input)
return dict(type="tool_result", tool_use_id=fc.id, content=str(res))
return dict(type="tool_result", tool_use_id=fc.id, content=str(res))

# %% ../00_core.ipynb 99
# %% ../00_core.ipynb 103
def mk_toolres(
r:abc.Mapping, # Tool use request response from Claude
ns:Optional[abc.Mapping]=None, # Namespace to search for tools
Expand All @@ -163,7 +169,7 @@ def mk_toolres(
if tcs: res.append(mk_msg(tcs))
return res

# %% ../00_core.ipynb 109
# %% ../00_core.ipynb 113
class Chat:
def __init__(self,
model:Optional[str]=None, # Model to use (leave empty if passing `cli`)
Expand All @@ -179,13 +185,13 @@ def __init__(self,
@property
def use(self): return self.c.use

# %% ../00_core.ipynb 112
# %% ../00_core.ipynb 116
@patch
def _stream(self:Chat, res):
yield from res
self.h += mk_toolres(self.c.result, ns=self.tools, obj=self)

# %% ../00_core.ipynb 113
# %% ../00_core.ipynb 117
@patch
def __call__(self:Chat,
pr=None, # Prompt / message
Expand All @@ -204,27 +210,27 @@ def __call__(self:Chat,
self.h += mk_toolres(self.c.result, ns=self.tools, obj=self)
return res

# %% ../00_core.ipynb 132
# %% ../00_core.ipynb 136
def img_msg(data:bytes)->dict:
"Convert image `data` into an encoded `dict`"
img = base64.b64encode(data).decode("utf-8")
mtype = mimetypes.types_map['.'+imghdr.what(None, h=data)]
r = dict(type="base64", media_type=mtype, data=img)
return {"type": "image", "source": r}

# %% ../00_core.ipynb 134
# %% ../00_core.ipynb 138
def text_msg(s:str)->dict:
"Convert `s` to a text message"
return {"type": "text", "text": s}

# %% ../00_core.ipynb 138
# %% ../00_core.ipynb 142
def _mk_content(src):
"Create appropriate content data structure based on type of content"
if isinstance(src,str): return text_msg(src)
if isinstance(src,bytes): return img_msg(src)
return src

# %% ../00_core.ipynb 141
# %% ../00_core.ipynb 145
def mk_msg(content, # A string, list, or dict containing the contents of the message
role='user', # Must be 'user' or 'assistant'
**kw):
Expand All @@ -235,9 +241,9 @@ def mk_msg(content, # A string, list, or dict containing the contents of the mes
content = [_mk_content(o) for o in content] if content else '.'
return dict(role=role, content=content, **kw)

# %% ../00_core.ipynb 148
# %% ../00_core.ipynb 152
models_aws = ('anthropic.claude-3-haiku-20240307-v1:0', 'anthropic.claude-3-sonnet-20240229-v1:0',
'anthropic.claude-3-opus-20240229-v1:0', 'anthropic.claude-3-5-sonnet-20240620-v1:0')

# %% ../00_core.ipynb 154
# %% ../00_core.ipynb 158
models_goog = 'claude-3-haiku@20240307', 'claude-3-sonnet@20240229', 'claude-3-opus@20240229', 'claude-3-5-sonnet@20240620'
4 changes: 2 additions & 2 deletions settings.ini
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ language = English
status = 3
user = AnswerDotAI
readme_nb = index.ipynb
allowed_metadata_keys =
allowed_cell_metadata_keys =
allowed_metadata_keys =
allowed_cell_metadata_keys =
jupyter_hooks = True
clean_ids = True
clear_all = False
Expand Down

0 comments on commit c65e4ec

Please sign in to comment.