diff --git a/pyproject.toml b/pyproject.toml index 31060293..d7a5cff3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,9 @@ target-version = ["py310"] line-length = 100 target-version = "py310" fix = true +exclude = [ + "src/codegate/db/queries.py", # Ignore auto-generated file from sqlc +] [tool.ruff.lint] select = ["E", "F", "I", "N", "W"] diff --git a/sql/queries/queries.sql b/sql/queries/queries.sql index 9c1e319e..60319e63 100644 --- a/sql/queries/queries.sql +++ b/sql/queries/queries.sql @@ -3,10 +3,9 @@ INSERT INTO prompts ( id, timestamp, provider, - system_prompt, - user_prompt, + request, type -) VALUES (?, ?, ?, ?, ?, ?) RETURNING *; +) VALUES (?, ?, ?, ?, ?) RETURNING *; -- name: GetPrompt :one SELECT * FROM prompts WHERE id = ?; diff --git a/sql/schema/schema.sql b/sql/schema/schema.sql index 059e8ef2..6d8114f8 100644 --- a/sql/schema/schema.sql +++ b/sql/schema/schema.sql @@ -5,8 +5,7 @@ CREATE TABLE prompts ( id TEXT PRIMARY KEY, -- UUID stored as TEXT timestamp DATETIME NOT NULL, provider TEXT, -- VARCHAR(255) - system_prompt TEXT, - user_prompt TEXT NOT NULL, + request TEXT NOT NULL, -- Record the full request that arrived to the server type TEXT NOT NULL -- VARCHAR(50) (e.g. "fim", "chat") ); @@ -15,7 +14,7 @@ CREATE TABLE outputs ( id TEXT PRIMARY KEY, -- UUID stored as TEXT prompt_id TEXT NOT NULL, timestamp DATETIME NOT NULL, - output TEXT NOT NULL, + output TEXT NOT NULL, -- Record the full response. If it was stream will be a list of objects. 
FOREIGN KEY (prompt_id) REFERENCES prompts(id) ); diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 0e8d5aad..34cb67e5 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -1,15 +1,18 @@ import asyncio +import copy import datetime +import json import uuid from pathlib import Path -from typing import Optional +from typing import AsyncGenerator, AsyncIterator, Optional import structlog -from litellm import ChatCompletionRequest +from litellm import ChatCompletionRequest, ModelResponse +from pydantic import BaseModel from sqlalchemy import create_engine, text from sqlalchemy.ext.asyncio import create_async_engine -from codegate.db.models import Prompt +from codegate.db.models import Output, Prompt logger = structlog.get_logger("codegate") @@ -68,30 +71,37 @@ async def init_db(self): finally: await self._async_db_engine.dispose() + async def _insert_pydantic_model( + self, model: BaseModel, sql_insert: text + ) -> Optional[BaseModel]: + # There are create methods in queries.py automatically generated by sqlc + # However, the methods are buggy for Pydantic and don't work as expected. + # Manually writing the SQL query to insert Pydantic models. 
+ async with self._async_db_engine.begin() as conn: + result = await conn.execute(sql_insert, model.model_dump()) + row = result.first() + if row is None: + return None + + # Get the class of the Pydantic object to create a new object + model_class = model.__class__ + return model_class(**row._asdict()) + async def record_request( self, normalized_request: ChatCompletionRequest, is_fim_request: bool, provider_str: str ) -> Optional[Prompt]: - # Extract system prompt and user prompt from the messages - messages = normalized_request.get("messages", []) - system_prompt = [] - user_prompt = [] - - for msg in messages: - if msg.get("role") == "system": - system_prompt.append(msg.get("content")) - elif msg.get("role") == "user": - user_prompt.append(msg.get("content")) - - # If no user prompt found in messages, try to get from the prompt field - # (for non-chat completions) - if not user_prompt: - prompt = normalized_request.get("prompt") - if prompt: - user_prompt.append(prompt) - - if not user_prompt: - logger.warning("No user prompt found in request.") - return None + request_str = None + if isinstance(normalized_request, BaseModel): + request_str = normalized_request.model_dump_json(exclude_none=True, exclude_unset=True) + else: + try: + request_str = json.dumps(normalized_request) + except Exception as e: + logger.error(f"Failed to serialize output: {normalized_request}", error=str(e)) + + if request_str is None: + logger.warning("No request found to record.") + return # Create a new prompt record prompt_params = Prompt( @@ -99,33 +109,74 @@ async def record_request( timestamp=datetime.datetime.now(datetime.timezone.utc), provider=provider_str, type="fim" if is_fim_request else "chat", - user_prompt="<|>".join(user_prompt), - system_prompt="<|>".join(system_prompt), + request=request_str, ) - # There is a `create_prompt` method in queries.py automatically generated by sqlc - # However, the method is is buggy and doesn't work as expected. 
- # Manually writing the SQL query to insert the prompt record. - async with self._async_db_engine.begin() as conn: - sql = text( + sql = text( + """ + INSERT INTO prompts (id, timestamp, provider, request, type) + VALUES (:id, :timestamp, :provider, :request, :type) + RETURNING * """ - INSERT INTO prompts (id, timestamp, provider, system_prompt, user_prompt, type) - VALUES (:id, :timestamp, :provider, :system_prompt, :user_prompt, :type) + ) + return await self._insert_pydantic_model(prompt_params, sql) + + async def _record_output(self, prompt: Prompt, output_str: str) -> Optional[Output]: + output_params = Output( + id=str(uuid.uuid4()), + prompt_id=prompt.id, + timestamp=datetime.datetime.now(datetime.timezone.utc), + output=output_str, + ) + sql = text( + """ + INSERT INTO outputs (id, prompt_id, timestamp, output) + VALUES (:id, :prompt_id, :timestamp, :output) RETURNING * """ - ) - result = await conn.execute(sql, prompt_params.model_dump()) - row = result.first() - if row is None: - return None + ) + return await self._insert_pydantic_model(output_params, sql) + + async def record_output_stream( + self, prompt: Prompt, model_response: AsyncIterator + ) -> AsyncGenerator: + output_chunks = [] + async for chunk in model_response: + if isinstance(chunk, BaseModel): + chunk_to_record = chunk.model_dump(exclude_none=True, exclude_unset=True) + output_chunks.append(chunk_to_record) + elif isinstance(chunk, dict): + output_chunks.append(copy.deepcopy(chunk)) + else: + output_chunks.append({"chunk": str(chunk)}) + yield chunk + + if output_chunks: + # Record the output chunks + output_str = json.dumps(output_chunks) + logger.info(f"Recorded chunks: {output_chunks}. 
Str: {output_str}") + await self._record_output(prompt, output_str) + + async def record_output_non_stream( + self, prompt: Optional[Prompt], model_response: ModelResponse + ) -> Optional[Output]: + if prompt is None: + logger.warning("No prompt found to record output.") + return + + output_str = None + if isinstance(model_response, BaseModel): + output_str = model_response.model_dump_json(exclude_none=True, exclude_unset=True) + else: + try: + output_str = json.dumps(model_response) + except Exception as e: + logger.error(f"Failed to serialize output: {model_response}", error=str(e)) + + if output_str is None: + logger.warning("No output found to record.") + return - return Prompt( - id=row.id, - timestamp=row.timestamp, - provider=row.provider, - system_prompt=row.system_prompt, - user_prompt=row.user_prompt, - type=row.type, - ) + return await self._record_output(prompt, output_str) def init_db_sync(): diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index b5edc0d6..d8b7040f 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -28,8 +28,7 @@ class Prompt(pydantic.BaseModel): id: Any timestamp: Any provider: Optional[Any] - system_prompt: Optional[Any] - user_prompt: Any + request: Any type: Any diff --git a/src/codegate/db/queries.py b/src/codegate/db/queries.py index fdfd0693..eef79462 100644 --- a/src/codegate/db/queries.py +++ b/src/codegate/db/queries.py @@ -20,8 +20,7 @@ trigger_type, trigger_category, timestamp -) VALUES (?, ?, ?, ?, ?, ?, ?, ?) RETURNING id, prompt_id, output_id, code_snippet, -trigger_string, trigger_type, trigger_category, timestamp +) VALUES (?, ?, ?, ?, ?, ?, ?, ?) RETURNING id, prompt_id, output_id, code_snippet, trigger_string, trigger_type, trigger_category, timestamp """ @@ -51,25 +50,14 @@ class CreateAlertParams(pydantic.BaseModel): id, timestamp, provider, - system_prompt, - user_prompt, + request, type -) VALUES (?, ?, ?, ?, ?, ?) 
RETURNING id, timestamp, provider, system_prompt, user_prompt, type +) VALUES (?, ?, ?, ?, ?) RETURNING id, timestamp, provider, request, type """ -class CreatePromptParams(pydantic.BaseModel): - id: Any - timestamp: Any - provider: Optional[Any] - system_prompt: Optional[Any] - user_prompt: Any - type: Any - - GET_ALERT = """-- name: get_alert \\:one -SELECT id, prompt_id, output_id, code_snippet, trigger_string, trigger_type, trigger_category, -timestamp FROM alerts WHERE id = ? +SELECT id, prompt_id, output_id, code_snippet, trigger_string, trigger_type, trigger_category, timestamp FROM alerts WHERE id = ? """ @@ -79,20 +67,20 @@ class CreatePromptParams(pydantic.BaseModel): GET_OUTPUTS_BY_PROMPT_ID = """-- name: get_outputs_by_prompt_id \\:many -SELECT id, prompt_id, timestamp, output FROM outputs -WHERE prompt_id = ? +SELECT id, prompt_id, timestamp, output FROM outputs +WHERE prompt_id = ? ORDER BY timestamp DESC """ GET_PROMPT = """-- name: get_prompt \\:one -SELECT id, timestamp, provider, system_prompt, user_prompt, type FROM prompts WHERE id = ? +SELECT id, timestamp, provider, request, type FROM prompts WHERE id = ? """ GET_PROMPT_WITH_OUTPUTS_AND_ALERTS = """-- name: get_prompt_with_outputs_and_alerts \\:many -SELECT - p.id, p.timestamp, p.provider, p.system_prompt, p.user_prompt, p.type, +SELECT + p.id, p.timestamp, p.provider, p.request, p.type, o.id as output_id, o.output, a.id as alert_id, @@ -112,8 +100,7 @@ class GetPromptWithOutputsAndAlertsRow(pydantic.BaseModel): id: Any timestamp: Any provider: Optional[Any] - system_prompt: Optional[Any] - user_prompt: Any + request: Any type: Any output_id: Optional[Any] output: Optional[Any] @@ -130,16 +117,15 @@ class GetPromptWithOutputsAndAlertsRow(pydantic.BaseModel): LIST_ALERTS_BY_PROMPT = """-- name: list_alerts_by_prompt \\:many -SELECT id, prompt_id, output_id, code_snippet, trigger_string, trigger_type, trigger_category, -timestamp FROM alerts -WHERE prompt_id = ? 
+SELECT id, prompt_id, output_id, code_snippet, trigger_string, trigger_type, trigger_category, timestamp FROM alerts +WHERE prompt_id = ? ORDER BY timestamp DESC """ LIST_PROMPTS = """-- name: list_prompts \\:many -SELECT id, timestamp, provider, system_prompt, user_prompt, type FROM prompts -ORDER BY timestamp DESC +SELECT id, timestamp, provider, request, type FROM prompts +ORDER BY timestamp DESC LIMIT ? OFFSET ? """ @@ -224,16 +210,17 @@ def create_output( output=row[3], ) - def create_prompt(self, arg: CreatePromptParams) -> Optional[models.Prompt]: + def create_prompt( + self, *, id: Any, timestamp: Any, provider: Optional[Any], request: Any, type: Any + ) -> Optional[models.Prompt]: row = self._conn.execute( sqlalchemy.text(CREATE_PROMPT), { - "p1": arg.id, - "p2": arg.timestamp, - "p3": arg.provider, - "p4": arg.system_prompt, - "p5": arg.user_prompt, - "p6": arg.type, + "p1": id, + "p2": timestamp, + "p3": provider, + "p4": request, + "p5": type, }, ).first() if row is None: @@ -242,9 +229,8 @@ def create_prompt(self, arg: CreatePromptParams) -> Optional[models.Prompt]: id=row[0], timestamp=row[1], provider=row[2], - system_prompt=row[3], - user_prompt=row[4], - type=row[5], + request=row[3], + type=row[4], ) def get_alert(self, *, id: Any) -> Optional[models.Alert]: @@ -291,9 +277,8 @@ def get_prompt(self, *, id: Any) -> Optional[models.Prompt]: id=row[0], timestamp=row[1], provider=row[2], - system_prompt=row[3], - user_prompt=row[4], - type=row[5], + request=row[3], + type=row[4], ) def get_prompt_with_outputs_and_alerts( @@ -305,16 +290,15 @@ def get_prompt_with_outputs_and_alerts( id=row[0], timestamp=row[1], provider=row[2], - system_prompt=row[3], - user_prompt=row[4], - type=row[5], - output_id=row[6], - output=row[7], - alert_id=row[8], - code_snippet=row[9], - trigger_string=row[10], - trigger_type=row[11], - trigger_category=row[12], + request=row[3], + type=row[4], + output_id=row[5], + output=row[6], + alert_id=row[7], + code_snippet=row[8], 
+ trigger_string=row[9], + trigger_type=row[10], + trigger_category=row[11], ) def get_settings(self) -> Optional[models.Setting]: @@ -351,9 +335,8 @@ def list_prompts(self, *, limit: Any, offset: Any) -> Iterator[models.Prompt]: id=row[0], timestamp=row[1], provider=row[2], - system_prompt=row[3], - user_prompt=row[4], - type=row[5], + request=row[3], + type=row[4], ) def upsert_settings(self, arg: UpsertSettingsParams) -> Optional[models.Setting]: @@ -436,17 +419,18 @@ async def create_output( output=row[3], ) - async def create_prompt(self, arg: CreatePromptParams) -> Optional[models.Prompt]: + async def create_prompt( + self, *, id: Any, timestamp: Any, provider: Optional[Any], request: Any, type: Any + ) -> Optional[models.Prompt]: row = ( await self._conn.execute( sqlalchemy.text(CREATE_PROMPT), { - "p1": arg.id, - "p2": arg.timestamp, - "p3": arg.provider, - "p4": arg.system_prompt, - "p5": arg.user_prompt, - "p6": arg.type, + "p1": id, + "p2": timestamp, + "p3": provider, + "p4": request, + "p5": type, }, ) ).first() @@ -456,9 +440,8 @@ async def create_prompt(self, arg: CreatePromptParams) -> Optional[models.Prompt id=row[0], timestamp=row[1], provider=row[2], - system_prompt=row[3], - user_prompt=row[4], - type=row[5], + request=row[3], + type=row[4], ) async def get_alert(self, *, id: Any) -> Optional[models.Alert]: @@ -507,9 +490,8 @@ async def get_prompt(self, *, id: Any) -> Optional[models.Prompt]: id=row[0], timestamp=row[1], provider=row[2], - system_prompt=row[3], - user_prompt=row[4], - type=row[5], + request=row[3], + type=row[4], ) async def get_prompt_with_outputs_and_alerts( @@ -523,16 +505,15 @@ async def get_prompt_with_outputs_and_alerts( id=row[0], timestamp=row[1], provider=row[2], - system_prompt=row[3], - user_prompt=row[4], - type=row[5], - output_id=row[6], - output=row[7], - alert_id=row[8], - code_snippet=row[9], - trigger_string=row[10], - trigger_type=row[11], - trigger_category=row[12], + request=row[3], + type=row[4], + 
output_id=row[5], + output=row[6], + alert_id=row[7], + code_snippet=row[8], + trigger_string=row[9], + trigger_type=row[10], + trigger_category=row[11], ) async def get_settings(self) -> Optional[models.Setting]: @@ -569,9 +550,8 @@ async def list_prompts(self, *, limit: Any, offset: Any) -> AsyncIterator[models id=row[0], timestamp=row[1], provider=row[2], - system_prompt=row[3], - user_prompt=row[4], - type=row[5], + request=row[3], + type=row[4], ) async def upsert_settings(self, arg: UpsertSettingsParams) -> Optional[models.Setting]: diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 509f8e9c..c56bc4a9 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -175,7 +175,7 @@ async def complete( """ normalized_request = self._input_normalizer.normalize(data) streaming = data.get("stream", False) - await self.db_recorder.record_request( + prompt_db = await self.db_recorder.record_request( normalized_request, is_fim_request, self.provider_route_name ) @@ -194,10 +194,12 @@ async def complete( provider_request, api_key=api_key, stream=streaming, is_fim_request=is_fim_request ) if not streaming: + await self.db_recorder.record_output_non_stream(prompt_db, model_response) normalized_response = self._output_normalizer.normalize(model_response) pipeline_output = self._run_output_pipeline(normalized_response) return self._output_normalizer.denormalize(pipeline_output) + model_response = self.db_recorder.record_output_stream(prompt_db, model_response) normalized_stream = self._output_normalizer.normalize_streaming(model_response) pipeline_output_stream = await self._run_output_stream_pipeline( input_pipeline_result.context,