diff --git a/01_toolloop.ipynb b/01_toolloop.ipynb index 63024cb..e5de910 100644 --- a/01_toolloop.ipynb +++ b/01_toolloop.ipynb @@ -277,13 +277,14 @@ " cont_func:Optional[callable]=noop, # Function that stops loop if returns False\n", " **kwargs):\n", " \"Add prompt `pr` to dialog and get a response from Claude, automatically following up with `tool_use` messages\"\n", + " n_msgs = len(self.h)\n", " r = self(pr, **kwargs)\n", " for i in range(max_steps):\n", " if r.stop_reason!='tool_use': break\n", - " if trace_func: trace_func(r)\n", + " if trace_func: trace_func(self.h[n_msgs:]); n_msgs = len(self.h)\n", " r = self(**kwargs)\n", " if not (cont_func or noop)(self.h[-2]): break\n", - " if trace_func: trace_func(r)\n", + " if trace_func: trace_func(self.h[n_msgs:])\n", " return r" ] }, @@ -631,12 +632,13 @@ "metadata": {}, "outputs": [], "source": [ - "def _show_cts(r):\n", - " for o in r.content:\n", - " if hasattr(o,'text'): print(o.text)\n", - " nm = getattr(o, 'name', None)\n", - " if nm=='run_cell': print(o.input['code'])\n", - " elif nm: print(f'{o.name}({o.input})')" + "def _show_cts(h):\n", + " for r in h:\n", + " for o in r.get('content'):\n", + " if hasattr(o,'text'): print(o.text)\n", + " nm = getattr(o, 'name', None)\n", + " if nm=='run_cell': print(o.input['code'])\n", + " elif nm: print(f'{o.name}({o.input})')" ] }, { diff --git a/claudette/toolloop.py b/claudette/toolloop.py index 9897906..71edb07 100644 --- a/claudette/toolloop.py +++ b/claudette/toolloop.py @@ -20,11 +20,12 @@ def toolloop(self:Chat, cont_func:Optional[callable]=noop, # Function that stops loop if returns False **kwargs): "Add prompt `pr` to dialog and get a response from Claude, automatically following up with `tool_use` messages" + n_msgs = len(self.h) r = self(pr, **kwargs) for i in range(max_steps): if r.stop_reason!='tool_use': break - if trace_func: trace_func(r) + if trace_func: trace_func(self.h[n_msgs:]); n_msgs = len(self.h) r = self(**kwargs) if not (cont_func or noop)(self.h[-2]): break - if trace_func: trace_func(r) + if trace_func: trace_func(self.h[n_msgs:]) return r