From 238d9826c54c8da07ec52ab105891320bfa260f6 Mon Sep 17 00:00:00 2001 From: ncoop57 Date: Mon, 30 Sep 2024 17:13:52 -0500 Subject: [PATCH] Fix a few bugs --- 00_core.ipynb | 4 ++-- claudette/core.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/00_core.ipynb b/00_core.ipynb index 7fad079..db7e0c5 100644 --- a/00_core.ipynb +++ b/00_core.ipynb @@ -1863,7 +1863,7 @@ " **kwargs):\n", " \"Make a call to Claude.\"\n", " if tools: kwargs['tools'] = [get_schema(o) for o in listify(tools)]\n", - " if tool_choice and pr: kwargs['tool_choice'] = mk_tool_choice(tool_choice)\n", + " if tool_choice: kwargs['tool_choice'] = mk_tool_choice(tool_choice)\n", " msgs = self._precall(msgs, prefill, stop, kwargs)\n", " if stream: return self._stream(msgs, prefill=prefill, max_tokens=maxtok, system=sp, temperature=temp, **kwargs)\n", " res = self.c.messages.create(model=self.model, messages=msgs, max_tokens=maxtok, system=sp, temperature=temp, **kwargs)\n", @@ -1953,7 +1953,7 @@ " \"Return the value of all tool calls (generally used for structured outputs)\"\n", " res = self(msgs, **kwargs)\n", " if ns is None: ns=globals()\n", - " cts = getattr(r, 'content', [])\n", + " cts = getattr(res, 'content', [])\n", " tcs = [call_func(o, ns=ns, obj=obj) for o in cts if isinstance(o,ToolUseBlock)]\n", " return tcs" ] diff --git a/claudette/core.py b/claudette/core.py index 8283f41..e36b598 100644 --- a/claudette/core.py +++ b/claudette/core.py @@ -211,7 +211,7 @@ def __call__(self:Client, **kwargs): "Make a call to Claude." if tools: kwargs['tools'] = [get_schema(o) for o in listify(tools)] - if tool_choice and pr: kwargs['tool_choice'] = mk_tool_choice(tool_choice) + if tool_choice: kwargs['tool_choice'] = mk_tool_choice(tool_choice) msgs = self._precall(msgs, prefill, stop, kwargs) if stream: return self._stream(msgs, prefill=prefill, max_tokens=maxtok, system=sp, temperature=temp, **kwargs) res = self.c.messages.create(model=self.model, messages=msgs, max_tokens=maxtok, system=sp, temperature=temp, **kwargs) @@ -228,7 +228,7 @@ def structured(self:Client, "Return the value of all tool calls (generally used for structured outputs)" res = self(msgs, **kwargs) if ns is None: ns=globals() - cts = getattr(r, 'content', []) + cts = getattr(res, 'content', []) tcs = [call_func(o, ns=ns, obj=obj) for o in cts if isinstance(o,ToolUseBlock)] return tcs