Skip to content

Commit

Permalink
merge update and replace into one function
Browse files Browse the repository at this point in the history
  • Loading branch information
rgbkrk committed Aug 27, 2023
1 parent 3c8fd42 commit e67191a
Showing 1 changed file with 40 additions and 20 deletions.
60 changes: 40 additions & 20 deletions chatlab/builtins/noteable.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,21 +174,46 @@ async def create_cell(
except Exception as e:
return f"Cell created successfully. An error happened during run: {e}"

async def update_cell_content(self, cell_id: str, patch: str) -> CodeCell | str:
"""Update a cell's content with a `diff-match-patch` formatted patch string."""
async def update_cell(
self,
cell_id: str,
source: Optional[str] = None,
cell_type: Optional[Literal["code", "markdown", "sql"]] = None,
and_run: Optional[bool] = False,
db_connection: Optional[str] = None,
assign_results_to: Optional[str] = None,
) -> str:
"""Update properties of a cell, including source, cell_type"""
rtu_client: RTUClient = await self.get_or_create_rtu_client()
try:
return await rtu_client.update_cell_content(cell_id, patch)
if source is not None:
await rtu_client.replace_cell_content(cell_id, source)
_, cell = rtu_client.builder.get_cell(cell_id)

# Pull the old cell type, as long as it's not a "raw" cell
if cell_type is None and cell.cell_type != 'raw':
cell_type = cell.cell_type

if cell_type is not None:
conn = db_connection
if conn is None:
# HACK: match the default in origami
conn = "@noteable"
if conn is not None and not conn.startswith("@"):
conn = f"@{conn}"

await rtu_client.change_cell_type(
cell_id, cell_type, db_connection=conn, assign_results_to=assign_results_to
)
except Exception as e:
return f"Error updating cell content: {e}"
return f"Error replacing cell content: {e}"
if cell.cell_type != "code" or not and_run:
return f"Cell ID `{cell.id}` updated successfully."

async def replace_cell_content(self, cell_id: str, source: str) -> CodeCell | str:
"""Replace a cell's content with a string."""
rtu_client: RTUClient = await self.get_or_create_rtu_client()
try:
return await rtu_client.replace_cell_content(cell_id, source)
return await self.run_cell(cell.id)
except Exception as e:
return f"Error replacing cell content: {e}"
return f"Cell updated successfully. An error happened during run: {e}"

async def _get_llm_friendly_outputs(self, output_collection_id: uuid.UUID):
"""Get the outputs for a given output collection ID."""
Expand All @@ -212,9 +237,7 @@ async def _get_llm_friendly_outputs(self, output_collection_id: uuid.UUID):
return llm_friendly_outputs

async def _extract_llm_plain(self, output: KernelOutput):
resp = await self.api_client.client.get(
f"/outputs/{output.id}", params={"mimetype": "text/llm+plain"}
)
resp = await self.api_client.client.get(f"/outputs/{output.id}", params={"mimetype": "text/llm+plain"})
resp.raise_for_status()

output_for_llm = KernelOutput.parse_obj(resp.json())
Expand All @@ -225,9 +248,7 @@ async def _extract_llm_plain(self, output: KernelOutput):
return output_for_llm.content.raw

async def _extract_specific_mediatype(self, output: KernelOutput, mimetype: str):
resp = await self.api_client.client.get(
f"/outputs/{output.id}", params={"mimetype": mimetype}
)
resp = await self.api_client.client.get(f"/outputs/{output.id}", params={"mimetype": mimetype})
resp.raise_for_status()

output_for_llm = KernelOutput.parse_obj(resp.json())
Expand Down Expand Up @@ -303,7 +324,7 @@ async def _get_llm_friendly_output(self, output: KernelOutput):
if next_best_output.content.raw is not None:
return next_best_output.content.raw

async def run_cell(self, cell_id: str):
async def run_cell(self, cell_id: str) -> str:
"""Run a Cell within a Notebook by ID."""
# Queue up the execution
rtu_client = await self.get_or_create_rtu_client()
Expand All @@ -313,7 +334,7 @@ async def run_cell(self, cell_id: str):
if cell.output_collection_id is None:
# Hypothesis: if the output collection ID is None, we're in a bad
# state. When the LLM sees this cell they will think its fine.
return cell
return "Cell possibly queued for execution."

output_collection_id = cell.output_collection_id

Expand All @@ -322,7 +343,7 @@ async def run_cell(self, cell_id: str):
output_collection_id = uuid.UUID(output_collection_id)
except ValueError:
logger.exception("Invalid UUID", exc_info=True)
return cell
return "Unable to get outputs."

outputs = await self._get_llm_friendly_outputs(output_collection_id)
response = ""
Expand Down Expand Up @@ -457,8 +478,7 @@ def chat_functions(self):
"""Functions to expose for LLMs."""
return [
self.create_cell,
self.update_cell_content,
self.replace_cell_content,
self.update_cell,
self.run_cell,
self.get_cell,
self.get_cell_ids,
Expand Down

0 comments on commit e67191a

Please sign in to comment.