From 4d573ae03b9b55fbf4e41fa1d030c87be13e2d96 Mon Sep 17 00:00:00 2001
From: Alejandro Ponce
Date: Tue, 3 Dec 2024 10:38:12 +0200
Subject: [PATCH] Record the output of the LLM in the DB

We are going to store the complete output of the LLMs.
- If the response was streamed, a list of JSON objects will be stored.
- If it was not streamed, a single JSON object will be stored.

The PR also changes the initial schema. The schema will now also store the
complete input request to the LLM.

Storing the complete input and output should help debug any problems with
CodeGate, as well as faithfully reproduce all the conversations in the
dashboard.
---
 pyproject.toml                 |   3 +
 sql/queries/queries.sql        |   5 +-
 sql/schema/schema.sql          |   5 +-
 src/codegate/db/connection.py  | 143 ++++++++++++++++++++++-----------
 src/codegate/db/models.py      |   3 +-
 src/codegate/db/queries.py     | 140 ++++++++++++++------------------
 src/codegate/providers/base.py |   4 +-
 7 files changed, 168 insertions(+), 135 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 31060293..d7a5cff3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,6 +51,9 @@ target-version = ["py310"]
 line-length = 100
 target-version = "py310"
 fix = true
+exclude = [
+    "src/codegate/db/queries.py", # Ignore auto-generated file from sqlc
+]

 [tool.ruff.lint]
 select = ["E", "F", "I", "N", "W"]
diff --git a/sql/queries/queries.sql b/sql/queries/queries.sql
index 9c1e319e..60319e63 100644
--- a/sql/queries/queries.sql
+++ b/sql/queries/queries.sql
@@ -3,10 +3,9 @@ INSERT INTO prompts (
     id,
     timestamp,
     provider,
-    system_prompt,
-    user_prompt,
+    request,
     type
-) VALUES (?, ?, ?, ?, ?, ?) RETURNING *;
+) VALUES (?, ?, ?, ?, ?) RETURNING *;

 -- name: GetPrompt :one
 SELECT * FROM prompts WHERE id = ?;
diff --git a/sql/schema/schema.sql b/sql/schema/schema.sql
index 059e8ef2..6d8114f8 100644
--- a/sql/schema/schema.sql
+++ b/sql/schema/schema.sql
@@ -5,8 +5,7 @@ CREATE TABLE prompts (
     id TEXT PRIMARY KEY,  -- UUID stored as TEXT
     timestamp DATETIME NOT NULL,
     provider TEXT,  -- VARCHAR(255)
-    system_prompt TEXT,
-    user_prompt TEXT NOT NULL,
+    request TEXT NOT NULL,  -- Record the full request that arrived at the server
     type TEXT NOT NULL  -- VARCHAR(50) (e.g. "fim", "chat")
 );

@@ -15,7 +14,7 @@ CREATE TABLE outputs (
     id TEXT PRIMARY KEY,  -- UUID stored as TEXT
     prompt_id TEXT NOT NULL,
     timestamp DATETIME NOT NULL,
-    output TEXT NOT NULL,
+    output TEXT NOT NULL,  -- Record the full response. If it was streamed, this will be a list of objects.
     FOREIGN KEY (prompt_id) REFERENCES prompts(id)
 );
diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py
index 0e8d5aad..34cb67e5 100644
--- a/src/codegate/db/connection.py
+++ b/src/codegate/db/connection.py
@@ -1,15 +1,18 @@
 import asyncio
+import copy
 import datetime
+import json
 import uuid
 from pathlib import Path
-from typing import Optional
+from typing import AsyncGenerator, AsyncIterator, Optional

 import structlog
-from litellm import ChatCompletionRequest
+from litellm import ChatCompletionRequest, ModelResponse
+from pydantic import BaseModel
 from sqlalchemy import create_engine, text
 from sqlalchemy.ext.asyncio import create_async_engine

-from codegate.db.models import Prompt
+from codegate.db.models import Output, Prompt

 logger = structlog.get_logger("codegate")

@@ -68,30 +71,37 @@ async def init_db(self):
         finally:
             await self._async_db_engine.dispose()

+    async def _insert_pydantic_model(
+        self, model: BaseModel, sql_insert: text
+    ) -> Optional[BaseModel]:
+        # There are create methods in queries.py automatically generated by sqlc.
+        # However, the methods are buggy for Pydantic models and don't work as expected.
+        # Manually writing the SQL query to insert Pydantic models.
+        async with self._async_db_engine.begin() as conn:
+            result = await conn.execute(sql_insert, model.model_dump())
+            row = result.first()
+            if row is None:
+                return None
+
+            # Get the class of the Pydantic object to create a new object
+            model_class = model.__class__
+            return model_class(**row._asdict())
+
     async def record_request(
         self, normalized_request: ChatCompletionRequest, is_fim_request: bool, provider_str: str
     ) -> Optional[Prompt]:
-        # Extract system prompt and user prompt from the messages
-        messages = normalized_request.get("messages", [])
-        system_prompt = []
-        user_prompt = []
-
-        for msg in messages:
-            if msg.get("role") == "system":
-                system_prompt.append(msg.get("content"))
-            elif msg.get("role") == "user":
-                user_prompt.append(msg.get("content"))
-
-        # If no user prompt found in messages, try to get from the prompt field
-        # (for non-chat completions)
-        if not user_prompt:
-            prompt = normalized_request.get("prompt")
-            if prompt:
-                user_prompt.append(prompt)
-
-        if not user_prompt:
-            logger.warning("No user prompt found in request.")
-            return None
+        request_str = None
+        if isinstance(normalized_request, BaseModel):
+            request_str = normalized_request.model_dump_json(exclude_none=True, exclude_unset=True)
+        else:
+            try:
+                request_str = json.dumps(normalized_request)
+            except Exception as e:
+                logger.error(f"Failed to serialize request: {normalized_request}", error=str(e))
+
+        if request_str is None:
+            logger.warning("No request found to record.")
+            return

         # Create a new prompt record
         prompt_params = Prompt(
@@ -99,33 +109,74 @@ async def record_request(
             timestamp=datetime.datetime.now(datetime.timezone.utc),
             provider=provider_str,
             type="fim" if is_fim_request else "chat",
-            user_prompt="<|>".join(user_prompt),
-            system_prompt="<|>".join(system_prompt),
+            request=request_str,
         )
-        # There is a `create_prompt` method in queries.py automatically generated by sqlc
-        # However, the method is is buggy and doesn't work as expected.
-        # Manually writing the SQL query to insert the prompt record.
-        async with self._async_db_engine.begin() as conn:
-            sql = text(
-                """
-                INSERT INTO prompts (id, timestamp, provider, system_prompt, user_prompt, type)
-                VALUES (:id, :timestamp, :provider, :system_prompt, :user_prompt, :type)
-                RETURNING *
-                """
-            )
-            result = await conn.execute(sql, prompt_params.model_dump())
-            row = result.first()
-            if row is None:
-                return None
-
-            return Prompt(
-                id=row.id,
-                timestamp=row.timestamp,
-                provider=row.provider,
-                system_prompt=row.system_prompt,
-                user_prompt=row.user_prompt,
-                type=row.type,
-            )
+        sql = text(
+            """
+            INSERT INTO prompts (id, timestamp, provider, request, type)
+            VALUES (:id, :timestamp, :provider, :request, :type)
+            RETURNING *
+            """
+        )
+        return await self._insert_pydantic_model(prompt_params, sql)
+
+    async def _record_output(self, prompt: Prompt, output_str: str) -> Optional[Output]:
+        output_params = Output(
+            id=str(uuid.uuid4()),
+            prompt_id=prompt.id,
+            timestamp=datetime.datetime.now(datetime.timezone.utc),
+            output=output_str,
+        )
+        sql = text(
+            """
+            INSERT INTO outputs (id, prompt_id, timestamp, output)
+            VALUES (:id, :prompt_id, :timestamp, :output)
+            RETURNING *
+            """
+        )
+        return await self._insert_pydantic_model(output_params, sql)
+
+    async def record_output_stream(
+        self, prompt: Prompt, model_response: AsyncIterator
+    ) -> AsyncGenerator:
+        output_chunks = []
+        async for chunk in model_response:
+            if isinstance(chunk, BaseModel):
+                chunk_to_record = chunk.model_dump(exclude_none=True, exclude_unset=True)
+                output_chunks.append(chunk_to_record)
+            elif isinstance(chunk, dict):
+                output_chunks.append(copy.deepcopy(chunk))
+            else:
+                output_chunks.append({"chunk": str(chunk)})
+            yield chunk
+
+        if output_chunks:
+            # Record the output chunks
+            output_str = json.dumps(output_chunks)
+            logger.info(f"Recorded chunks: {output_chunks}. Str: {output_str}")
+            await self._record_output(prompt, output_str)
+
+    async def record_output_non_stream(
+        self, prompt: Optional[Prompt], model_response: ModelResponse
+    ) -> Optional[Output]:
+        if prompt is None:
+            logger.warning("No prompt found to record output.")
+            return
+
+        output_str = None
+        if isinstance(model_response, BaseModel):
+            output_str = model_response.model_dump_json(exclude_none=True, exclude_unset=True)
+        else:
+            try:
+                output_str = json.dumps(model_response)
+            except Exception as e:
+                logger.error(f"Failed to serialize output: {model_response}", error=str(e))
+
+        if output_str is None:
+            logger.warning("No output found to record.")
+            return
+
+        return await self._record_output(prompt, output_str)


 def init_db_sync():
diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py
index b5edc0d6..d8b7040f 100644
--- a/src/codegate/db/models.py
+++ b/src/codegate/db/models.py
@@ -28,8 +28,7 @@ class Prompt(pydantic.BaseModel):
     id: Any
     timestamp: Any
     provider: Optional[Any]
-    system_prompt: Optional[Any]
-    user_prompt: Any
+    request: Any
     type: Any


diff --git a/src/codegate/db/queries.py b/src/codegate/db/queries.py
index fdfd0693..eef79462 100644
--- a/src/codegate/db/queries.py
+++ b/src/codegate/db/queries.py
@@ -20,8 +20,7 @@
     trigger_type,
     trigger_category,
     timestamp
-) VALUES (?, ?, ?, ?, ?, ?, ?, ?) RETURNING id, prompt_id, output_id, code_snippet,
-trigger_string, trigger_type, trigger_category, timestamp
+) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+RETURNING id, prompt_id, output_id, code_snippet, trigger_string, trigger_type, trigger_category, timestamp
 """


@@ -51,25 +50,14 @@ class CreateAlertParams(pydantic.BaseModel):
     id,
     timestamp,
     provider,
-    system_prompt,
-    user_prompt,
+    request,
     type
-) VALUES (?, ?, ?, ?, ?, ?) RETURNING id, timestamp, provider, system_prompt, user_prompt, type
+) VALUES (?, ?, ?, ?, ?) RETURNING id, timestamp, provider, request, type
 """


-class CreatePromptParams(pydantic.BaseModel):
-    id: Any
-    timestamp: Any
-    provider: Optional[Any]
-    system_prompt: Optional[Any]
-    user_prompt: Any
-    type: Any
-
-
 GET_ALERT = """-- name: get_alert \\:one
-SELECT id, prompt_id, output_id, code_snippet, trigger_string, trigger_type, trigger_category,
-timestamp FROM alerts WHERE id = ?
+SELECT id, prompt_id, output_id, code_snippet, trigger_string, trigger_type, trigger_category, timestamp FROM alerts WHERE id = ?
 """


@@ -79,20 +67,20 @@ class CreatePromptParams(pydantic.BaseModel):
 GET_OUTPUTS_BY_PROMPT_ID = """-- name: get_outputs_by_prompt_id \\:many
-SELECT id, prompt_id, timestamp, output FROM outputs
-WHERE prompt_id = ?
+SELECT id, prompt_id, timestamp, output FROM outputs
+WHERE prompt_id = ?
 ORDER BY timestamp DESC
 """


 GET_PROMPT = """-- name: get_prompt \\:one
-SELECT id, timestamp, provider, system_prompt, user_prompt, type FROM prompts WHERE id = ?
+SELECT id, timestamp, provider, request, type FROM prompts WHERE id = ?
 """


 GET_PROMPT_WITH_OUTPUTS_AND_ALERTS = """-- name: get_prompt_with_outputs_and_alerts \\:many
-SELECT
-    p.id, p.timestamp, p.provider, p.system_prompt, p.user_prompt, p.type,
+SELECT
+    p.id, p.timestamp, p.provider, p.request, p.type,
     o.id as output_id,
     o.output,
     a.id as alert_id,
@@ -112,8 +100,7 @@ class GetPromptWithOutputsAndAlertsRow(pydantic.BaseModel):
     id: Any
     timestamp: Any
     provider: Optional[Any]
-    system_prompt: Optional[Any]
-    user_prompt: Any
+    request: Any
     type: Any
     output_id: Optional[Any]
     output: Optional[Any]
@@ -130,16 +117,15 @@ class GetPromptWithOutputsAndAlertsRow(pydantic.BaseModel):


 LIST_ALERTS_BY_PROMPT = """-- name: list_alerts_by_prompt \\:many
-SELECT id, prompt_id, output_id, code_snippet, trigger_string, trigger_type, trigger_category,
-timestamp FROM alerts
-WHERE prompt_id = ?
+SELECT id, prompt_id, output_id, code_snippet, trigger_string, trigger_type, trigger_category, timestamp FROM alerts
+WHERE prompt_id = ?
 ORDER BY timestamp DESC
 """


 LIST_PROMPTS = """-- name: list_prompts \\:many
-SELECT id, timestamp, provider, system_prompt, user_prompt, type FROM prompts
-ORDER BY timestamp DESC
+SELECT id, timestamp, provider, request, type FROM prompts
+ORDER BY timestamp DESC
 LIMIT ? OFFSET ?
""" @@ -224,16 +210,17 @@ def create_output( output=row[3], ) - def create_prompt(self, arg: CreatePromptParams) -> Optional[models.Prompt]: + def create_prompt( + self, *, id: Any, timestamp: Any, provider: Optional[Any], request: Any, type: Any + ) -> Optional[models.Prompt]: row = self._conn.execute( sqlalchemy.text(CREATE_PROMPT), { - "p1": arg.id, - "p2": arg.timestamp, - "p3": arg.provider, - "p4": arg.system_prompt, - "p5": arg.user_prompt, - "p6": arg.type, + "p1": id, + "p2": timestamp, + "p3": provider, + "p4": request, + "p5": type, }, ).first() if row is None: @@ -242,9 +229,8 @@ def create_prompt(self, arg: CreatePromptParams) -> Optional[models.Prompt]: id=row[0], timestamp=row[1], provider=row[2], - system_prompt=row[3], - user_prompt=row[4], - type=row[5], + request=row[3], + type=row[4], ) def get_alert(self, *, id: Any) -> Optional[models.Alert]: @@ -291,9 +277,8 @@ def get_prompt(self, *, id: Any) -> Optional[models.Prompt]: id=row[0], timestamp=row[1], provider=row[2], - system_prompt=row[3], - user_prompt=row[4], - type=row[5], + request=row[3], + type=row[4], ) def get_prompt_with_outputs_and_alerts( @@ -305,16 +290,15 @@ def get_prompt_with_outputs_and_alerts( id=row[0], timestamp=row[1], provider=row[2], - system_prompt=row[3], - user_prompt=row[4], - type=row[5], - output_id=row[6], - output=row[7], - alert_id=row[8], - code_snippet=row[9], - trigger_string=row[10], - trigger_type=row[11], - trigger_category=row[12], + request=row[3], + type=row[4], + output_id=row[5], + output=row[6], + alert_id=row[7], + code_snippet=row[8], + trigger_string=row[9], + trigger_type=row[10], + trigger_category=row[11], ) def get_settings(self) -> Optional[models.Setting]: @@ -351,9 +335,8 @@ def list_prompts(self, *, limit: Any, offset: Any) -> Iterator[models.Prompt]: id=row[0], timestamp=row[1], provider=row[2], - system_prompt=row[3], - user_prompt=row[4], - type=row[5], + request=row[3], + type=row[4], ) def upsert_settings(self, arg: UpsertSettingsParams) -> Optional[models.Setting]: @@ -436,17 +419,18 @@ async def create_output( output=row[3], ) - async def create_prompt(self, arg: CreatePromptParams) -> Optional[models.Prompt]: + async def create_prompt( + self, *, id: Any, timestamp: Any, provider: Optional[Any], request: Any, type: Any + ) -> Optional[models.Prompt]: row = ( await self._conn.execute( sqlalchemy.text(CREATE_PROMPT), { - "p1": arg.id, - "p2": arg.timestamp, - "p3": arg.provider, - "p4": arg.system_prompt, - "p5": arg.user_prompt, - "p6": arg.type, + "p1": id, + "p2": timestamp, + "p3": provider, + "p4": request, + "p5": type, }, ) ).first() @@ -456,9 +440,8 @@ async def create_prompt(self, arg: CreatePromptParams) -> Optional[models.Prompt id=row[0], timestamp=row[1], provider=row[2], - system_prompt=row[3], - user_prompt=row[4], - type=row[5], + request=row[3], + type=row[4], ) async def get_alert(self, *, id: Any) -> Optional[models.Alert]: @@ -507,9 +490,8 @@ async def get_prompt(self, *, id: Any) -> Optional[models.Prompt]: id=row[0], timestamp=row[1], provider=row[2], - system_prompt=row[3], - user_prompt=row[4], - type=row[5], + request=row[3], + type=row[4], ) async def get_prompt_with_outputs_and_alerts( @@ -523,16 +505,15 @@ async def get_prompt_with_outputs_and_alerts( id=row[0], timestamp=row[1], provider=row[2], - system_prompt=row[3], - user_prompt=row[4], - type=row[5], - output_id=row[6], - output=row[7], - alert_id=row[8], - code_snippet=row[9], - trigger_string=row[10], - trigger_type=row[11], - trigger_category=row[12], + request=row[3], + 
+                type=row[4],
+                output_id=row[5],
+                output=row[6],
+                alert_id=row[7],
+                code_snippet=row[8],
+                trigger_string=row[9],
+                trigger_type=row[10],
+                trigger_category=row[11],
             )

     async def get_settings(self) -> Optional[models.Setting]:
@@ -569,9 +550,8 @@ async def list_prompts(self, *, limit: Any, offset: Any) -> AsyncIterator[models
                 id=row[0],
                 timestamp=row[1],
                 provider=row[2],
-                system_prompt=row[3],
-                user_prompt=row[4],
-                type=row[5],
+                request=row[3],
+                type=row[4],
             )

     async def upsert_settings(self, arg: UpsertSettingsParams) -> Optional[models.Setting]:
diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py
index 509f8e9c..c56bc4a9 100644
--- a/src/codegate/providers/base.py
+++ b/src/codegate/providers/base.py
@@ -175,7 +175,7 @@ async def complete(
         """
         normalized_request = self._input_normalizer.normalize(data)
         streaming = data.get("stream", False)
-        await self.db_recorder.record_request(
+        prompt_db = await self.db_recorder.record_request(
             normalized_request, is_fim_request, self.provider_route_name
         )

@@ -194,10 +194,12 @@ async def complete(
             provider_request, api_key=api_key, stream=streaming, is_fim_request=is_fim_request
         )
         if not streaming:
+            await self.db_recorder.record_output_non_stream(prompt_db, model_response)
             normalized_response = self._output_normalizer.normalize(model_response)
             pipeline_output = self._run_output_pipeline(normalized_response)
             return self._output_normalizer.denormalize(pipeline_output)

+        model_response = self.db_recorder.record_output_stream(prompt_db, model_response)
         normalized_stream = self._output_normalizer.normalize_streaming(model_response)
         pipeline_output_stream = await self._run_output_stream_pipeline(
             input_pipeline_result.context,
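
Note (illustration, not part of the patch): a minimal sketch of how the
recorded rows could be read back, e.g. for the dashboard, assuming the schema
above and Python's stdlib sqlite3. The helper name `fetch_conversation` is
hypothetical, not an API introduced by this change.

    import json
    import sqlite3


    def fetch_conversation(db_path: str, prompt_id: str) -> dict:
        """Load one prompt and its recorded outputs, decoding the stored JSON."""
        conn = sqlite3.connect(db_path)
        try:
            prompt_row = conn.execute(
                "SELECT request, type FROM prompts WHERE id = ?", (prompt_id,)
            ).fetchone()
            output_rows = conn.execute(
                "SELECT output FROM outputs WHERE prompt_id = ? ORDER BY timestamp DESC",
                (prompt_id,),
            ).fetchall()
        finally:
            conn.close()

        if prompt_row is None:
            raise KeyError(f"prompt {prompt_id} not found")

        request = json.loads(prompt_row[0])  # the full request sent to the LLM
        outputs = []
        for (output_str,) in output_rows:
            decoded = json.loads(output_str)
            # Streamed responses are stored as a JSON list of chunk objects;
            # non-streamed responses as a single JSON object.
            outputs.append(decoded if isinstance(decoded, list) else [decoded])
        return {"type": prompt_row[1], "request": request, "outputs": outputs}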
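
For reference, a toy, self-contained demonstration of the pass-through pattern
that `record_output_stream` relies on: the wrapper yields every chunk to the
caller unchanged and only serializes the collected list once the stream is
exhausted, so recording happens after the stream ends rather than delaying it.
`record_stream` and `fake_llm` are hypothetical stand-ins, not names from this
codebase.

    import asyncio
    import json
    from typing import AsyncIterator


    async def record_stream(
        chunks: AsyncIterator[dict], recorded: list
    ) -> AsyncIterator[dict]:
        collected = []
        async for chunk in chunks:
            collected.append(chunk)
            yield chunk  # the consumer keeps streaming; nothing is withheld
        if collected:
            # One JSON list per streamed response, matching outputs.output
            recorded.append(json.dumps(collected))


    async def main() -> None:
        async def fake_llm() -> AsyncIterator[dict]:
            for part in ("Hel", "lo"):
                yield {"choices": [{"delta": {"content": part}}]}

        recorded: list = []
        async for chunk in record_stream(fake_llm(), recorded):
            print(chunk["choices"][0]["delta"]["content"], end="")
        print()
        print(recorded[0])  # a JSON list with one object per chunk

    asyncio.run(main())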