From e67191a4b443d11081f682305e003eda176c3ae4 Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Sun, 27 Aug 2023 14:08:21 -0700 Subject: [PATCH] merge update and replace into one function --- chatlab/builtins/noteable.py | 60 ++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 20 deletions(-) diff --git a/chatlab/builtins/noteable.py b/chatlab/builtins/noteable.py index 5d799d4..01aa2c9 100644 --- a/chatlab/builtins/noteable.py +++ b/chatlab/builtins/noteable.py @@ -174,21 +174,46 @@ async def create_cell( except Exception as e: return f"Cell created successfully. An error happened during run: {e}" - async def update_cell_content(self, cell_id: str, patch: str) -> CodeCell | str: - """Update a cell's content with a `diff-match-patch` formatted patch string.""" + async def update_cell( + self, + cell_id: str, + source: Optional[str] = None, + cell_type: Optional[Literal["code", "markdown", "sql"]] = None, + and_run: Optional[bool] = False, + db_connection: Optional[str] = None, + assign_results_to: Optional[str] = None, + ) -> str: + """Update properties of a cell, including source, cell_type""" rtu_client: RTUClient = await self.get_or_create_rtu_client() try: - return await rtu_client.update_cell_content(cell_id, patch) + if source is not None: + await rtu_client.replace_cell_content(cell_id, source) + _, cell = rtu_client.builder.get_cell(cell_id) + + # Pull the old cell type, as long as it's not a "raw" cell + if cell_type is None and cell.cell_type != 'raw': + cell_type = cell.cell_type + + if cell_type is not None: + conn = db_connection + if conn is None: + # HACK: match the default in origami + conn = "@noteable" + if conn is not None and not conn.startswith("@"): + conn = f"@{conn}" + + await rtu_client.change_cell_type( + cell_id, cell_type, db_connection=conn, assign_results_to=assign_results_to + ) except Exception as e: - return f"Error updating cell content: {e}" + return f"Error replacing cell content: {e}" + if cell.cell_type != "code" or not and_run: + return f"Cell ID `{cell.id}` updated successfully." - async def replace_cell_content(self, cell_id: str, source: str) -> CodeCell | str: - """Replace a cell's content with a string.""" - rtu_client: RTUClient = await self.get_or_create_rtu_client() try: - return await rtu_client.replace_cell_content(cell_id, source) + return await self.run_cell(cell.id) except Exception as e: - return f"Error replacing cell content: {e}" + return f"Cell updated successfully. An error happened during run: {e}" async def _get_llm_friendly_outputs(self, output_collection_id: uuid.UUID): """Get the outputs for a given output collection ID.""" @@ -212,9 +237,7 @@ async def _get_llm_friendly_outputs(self, output_collection_id: uuid.UUID): return llm_friendly_outputs async def _extract_llm_plain(self, output: KernelOutput): - resp = await self.api_client.client.get( - f"/outputs/{output.id}", params={"mimetype": "text/llm+plain"} - ) + resp = await self.api_client.client.get(f"/outputs/{output.id}", params={"mimetype": "text/llm+plain"}) resp.raise_for_status() output_for_llm = KernelOutput.parse_obj(resp.json()) @@ -225,9 +248,7 @@ async def _extract_llm_plain(self, output: KernelOutput): return output_for_llm.content.raw async def _extract_specific_mediatype(self, output: KernelOutput, mimetype: str): - resp = await self.api_client.client.get( - f"/outputs/{output.id}", params={"mimetype": mimetype} - ) + resp = await self.api_client.client.get(f"/outputs/{output.id}", params={"mimetype": mimetype}) resp.raise_for_status() output_for_llm = KernelOutput.parse_obj(resp.json()) @@ -303,7 +324,7 @@ async def _get_llm_friendly_output(self, output: KernelOutput): if next_best_output.content.raw is not None: return next_best_output.content.raw - async def run_cell(self, cell_id: str): + async def run_cell(self, cell_id: str) -> str: """Run a Cell within a Notebook by ID.""" # Queue up the execution rtu_client = await self.get_or_create_rtu_client() @@ -313,7 +334,7 @@ async def run_cell(self, cell_id: str): if cell.output_collection_id is None: # Hypothesis: if the output collection ID is None, we're in a bad # state. When the LLM sees this cell they will think its fine. - return cell + return "Cell possibly queued for execution." output_collection_id = cell.output_collection_id @@ -322,7 +343,7 @@ async def run_cell(self, cell_id: str): output_collection_id = uuid.UUID(output_collection_id) except ValueError: logger.exception("Invalid UUID", exc_info=True) - return cell + return "Unable to get outputs." outputs = await self._get_llm_friendly_outputs(output_collection_id) response = "" @@ -457,8 +478,7 @@ def chat_functions(self): """Functions to expose for LLMs.""" return [ self.create_cell, - self.update_cell_content, - self.replace_cell_content, + self.update_cell, self.run_cell, self.get_cell, self.get_cell_ids,