Record the output of the LLM in the DB
We are going to store the complete output of the LLMs.
- If the response was streamed, a list of JSON objects gets stored.
- If it was not streamed, a single JSON object gets stored.

The PR also changes the initial schema: it now also stores the
complete input request sent to the LLM.

Storing the complete input and output should help debug any problems
with CodeGate, as well as faithfully reproduce all conversations in
the dashboard.
aponcedeleonch committed Dec 3, 2024
1 parent 47c4330 commit 4d573ae
Showing 7 changed files with 168 additions and 135 deletions.
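
To make the two storage shapes concrete, here is a minimal sketch (the payloads are illustrative, not actual provider responses):

import json

# Streamed response: every chunk is serialized, and the whole list is
# stored as one JSON string in outputs.output.
streamed_output = json.dumps([
    {"choices": [{"delta": {"content": "Hel"}}]},
    {"choices": [{"delta": {"content": "lo"}}]},
])

# Non-streamed response: a single JSON object is stored.
non_streamed_output = json.dumps({"choices": [{"message": {"content": "Hello"}}]})

# A reader can distinguish the two cases by the decoded type.
assert isinstance(json.loads(streamed_output), list)
assert isinstance(json.loads(non_streamed_output), dict)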
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -51,6 +51,9 @@ target-version = ["py310"]
 line-length = 100
 target-version = "py310"
 fix = true
+exclude = [
+    "src/codegate/db/queries.py", # Ignore auto-generated file from sqlc
+]
 
 [tool.ruff.lint]
 select = ["E", "F", "I", "N", "W"]
5 changes: 2 additions & 3 deletions sql/queries/queries.sql
@@ -3,10 +3,9 @@ INSERT INTO prompts (
     id,
     timestamp,
     provider,
-    system_prompt,
-    user_prompt,
+    request,
     type
-) VALUES (?, ?, ?, ?, ?, ?) RETURNING *;
+) VALUES (?, ?, ?, ?, ?) RETURNING *;
 
 -- name: GetPrompt :one
 SELECT * FROM prompts WHERE id = ?;
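
A quick way to sanity-check the new five-parameter insert is a minimal sketch with Python's sqlite3 against an in-memory table (RETURNING needs SQLite 3.35+; the payload is illustrative):

import json
import sqlite3
import uuid
from datetime import datetime, timezone

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE prompts (id TEXT PRIMARY KEY, timestamp DATETIME NOT NULL, "
    "provider TEXT, request TEXT NOT NULL, type TEXT NOT NULL)"
)
row = conn.execute(
    "INSERT INTO prompts (id, timestamp, provider, request, type) "
    "VALUES (?, ?, ?, ?, ?) RETURNING *",  # five placeholders now, not six
    (
        str(uuid.uuid4()),
        datetime.now(timezone.utc).isoformat(),
        "openai",
        json.dumps({"messages": [{"role": "user", "content": "hi"}]}),
        "chat",
    ),
).fetchone()
print(row)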
5 changes: 2 additions & 3 deletions sql/schema/schema.sql
@@ -5,8 +5,7 @@ CREATE TABLE prompts (
     id TEXT PRIMARY KEY, -- UUID stored as TEXT
     timestamp DATETIME NOT NULL,
     provider TEXT, -- VARCHAR(255)
-    system_prompt TEXT,
-    user_prompt TEXT NOT NULL,
+    request TEXT NOT NULL, -- Record the full request that arrived at the server
     type TEXT NOT NULL -- VARCHAR(50) (e.g. "fim", "chat")
 );
 
@@ -15,7 +14,7 @@ CREATE TABLE outputs (
     id TEXT PRIMARY KEY, -- UUID stored as TEXT
     prompt_id TEXT NOT NULL,
     timestamp DATETIME NOT NULL,
-    output TEXT NOT NULL,
+    output TEXT NOT NULL, -- Record the full response; if streamed, a list of JSON objects
     FOREIGN KEY (prompt_id) REFERENCES prompts(id)
 );
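
Since everything needed to replay a conversation now lives in these two tables, reconstructing one is a join plus two json.loads calls. A hedged sketch, assuming a SQLite database file (the codegate.db path is hypothetical; the real location is configured by CodeGate):

import json
import sqlite3

conn = sqlite3.connect("codegate.db")  # hypothetical path
rows = conn.execute(
    "SELECT p.timestamp, p.type, p.request, o.output "
    "FROM prompts p JOIN outputs o ON o.prompt_id = p.id "
    "ORDER BY p.timestamp"
)
for ts, prompt_type, request_json, output_json in rows:
    request = json.loads(request_json)  # the full request as it arrived at the server
    output = json.loads(output_json)    # dict for non-stream, list of chunks for stream
    print(ts, prompt_type, len(output) if isinstance(output, list) else 1, "chunk(s)")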
143 changes: 97 additions & 46 deletions src/codegate/db/connection.py
@@ -1,15 +1,18 @@
 import asyncio
+import copy
 import datetime
+import json
 import uuid
 from pathlib import Path
-from typing import Optional
+from typing import AsyncGenerator, AsyncIterator, Optional
 
 import structlog
-from litellm import ChatCompletionRequest
+from litellm import ChatCompletionRequest, ModelResponse
 from pydantic import BaseModel
 from sqlalchemy import create_engine, text
 from sqlalchemy.ext.asyncio import create_async_engine
 
-from codegate.db.models import Prompt
+from codegate.db.models import Output, Prompt
 
 logger = structlog.get_logger("codegate")
@@ -68,64 +71,112 @@ async def init_db(self):
         finally:
             await self._async_db_engine.dispose()
 
+    async def _insert_pydantic_model(
+        self, model: BaseModel, sql_insert: text
+    ) -> Optional[BaseModel]:
+        # There are create methods in queries.py automatically generated by sqlc.
+        # However, the methods are buggy for Pydantic models and don't work as expected.
+        # Manually writing the SQL query to insert Pydantic models.
+        async with self._async_db_engine.begin() as conn:
+            result = await conn.execute(sql_insert, model.model_dump())
+            row = result.first()
+            if row is None:
+                return None
+
+            # Get the class of the Pydantic object to create a new object
+            model_class = model.__class__
+            return model_class(**row._asdict())
+
     async def record_request(
         self, normalized_request: ChatCompletionRequest, is_fim_request: bool, provider_str: str
     ) -> Optional[Prompt]:
-        # Extract system prompt and user prompt from the messages
-        messages = normalized_request.get("messages", [])
-        system_prompt = []
-        user_prompt = []
-
-        for msg in messages:
-            if msg.get("role") == "system":
-                system_prompt.append(msg.get("content"))
-            elif msg.get("role") == "user":
-                user_prompt.append(msg.get("content"))
-
-        # If no user prompt found in messages, try to get from the prompt field
-        # (for non-chat completions)
-        if not user_prompt:
-            prompt = normalized_request.get("prompt")
-            if prompt:
-                user_prompt.append(prompt)
-
-        if not user_prompt:
-            logger.warning("No user prompt found in request.")
-            return None
+        request_str = None
+        if isinstance(normalized_request, BaseModel):
+            request_str = normalized_request.model_dump_json(exclude_none=True, exclude_unset=True)
+        else:
+            try:
+                request_str = json.dumps(normalized_request)
+            except Exception as e:
+                logger.error(f"Failed to serialize request: {normalized_request}", error=str(e))
+
+        if request_str is None:
+            logger.warning("No request found to record.")
+            return None
 
         # Create a new prompt record
         prompt_params = Prompt(
             id=str(uuid.uuid4()),  # Generate a new UUID for the prompt
             timestamp=datetime.datetime.now(datetime.timezone.utc),
             provider=provider_str,
             type="fim" if is_fim_request else "chat",
-            user_prompt="<|>".join(user_prompt),
-            system_prompt="<|>".join(system_prompt),
+            request=request_str,
         )
-        async with self._async_db_engine.begin() as conn:
-            sql = text(
-                """
-                INSERT INTO prompts (id, timestamp, provider, system_prompt, user_prompt, type)
-                VALUES (:id, :timestamp, :provider, :system_prompt, :user_prompt, :type)
-                RETURNING *
-                """
-            )
-            result = await conn.execute(sql, prompt_params.model_dump())
-            row = result.first()
-            if row is None:
-                return None
-
-            return Prompt(
-                id=row.id,
-                timestamp=row.timestamp,
-                provider=row.provider,
-                system_prompt=row.system_prompt,
-                user_prompt=row.user_prompt,
-                type=row.type,
-            )
+        # There is a `create_prompt` method in queries.py automatically generated by sqlc.
+        # However, the method is buggy and doesn't work as expected.
+        # Manually writing the SQL query to insert the prompt record.
+        sql = text(
+            """
+            INSERT INTO prompts (id, timestamp, provider, request, type)
+            VALUES (:id, :timestamp, :provider, :request, :type)
+            RETURNING *
+            """
+        )
+        return await self._insert_pydantic_model(prompt_params, sql)
+
+    async def _record_output(self, prompt: Prompt, output_str: str) -> Optional[Output]:
+        output_params = Output(
+            id=str(uuid.uuid4()),
+            prompt_id=prompt.id,
+            timestamp=datetime.datetime.now(datetime.timezone.utc),
+            output=output_str,
+        )
+        sql = text(
+            """
+            INSERT INTO outputs (id, prompt_id, timestamp, output)
+            VALUES (:id, :prompt_id, :timestamp, :output)
+            RETURNING *
+            """
+        )
+        return await self._insert_pydantic_model(output_params, sql)
+
+    async def record_output_stream(
+        self, prompt: Prompt, model_response: AsyncIterator
+    ) -> AsyncGenerator:
+        output_chunks = []
+        async for chunk in model_response:
+            if isinstance(chunk, BaseModel):
+                chunk_to_record = chunk.model_dump(exclude_none=True, exclude_unset=True)
+                output_chunks.append(chunk_to_record)
+            elif isinstance(chunk, dict):
+                output_chunks.append(copy.deepcopy(chunk))
+            else:
+                output_chunks.append({"chunk": str(chunk)})
+            yield chunk
+
+        if output_chunks:
+            # Record the full list of output chunks as a single JSON string
+            output_str = json.dumps(output_chunks)
+            logger.info(f"Recorded chunks: {output_chunks}. Str: {output_str}")
+            await self._record_output(prompt, output_str)
+
+    async def record_output_non_stream(
+        self, prompt: Optional[Prompt], model_response: ModelResponse
+    ) -> Optional[Output]:
+        if prompt is None:
+            logger.warning("No prompt found to record output.")
+            return None
+
+        output_str = None
+        if isinstance(model_response, BaseModel):
+            output_str = model_response.model_dump_json(exclude_none=True, exclude_unset=True)
+        else:
+            try:
+                output_str = json.dumps(model_response)
+            except Exception as e:
+                logger.error(f"Failed to serialize output: {model_response}", error=str(e))
+
+        if output_str is None:
+            logger.warning("No output found to record.")
+            return None
+
+        return await self._record_output(prompt, output_str)
 
 
 def init_db_sync():
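
A hedged sketch of how a provider handler might wire these methods in; the diff does not show the call sites, so names like db_recorder and llm_response are illustrative, and db_recorder stands in for the class defined in connection.py:

async def handle_completion(db_recorder, request, llm_response, is_stream: bool):
    # Record the incoming request first.
    prompt = await db_recorder.record_request(
        request, is_fim_request=False, provider_str="openai"
    )
    if is_stream:
        # record_output_stream is an async generator: it yields each chunk
        # through to the caller and writes the accumulated list at the end.
        return db_recorder.record_output_stream(prompt, llm_response)
    # Non-stream: record the single ModelResponse and pass it through.
    await db_recorder.record_output_non_stream(prompt, llm_response)
    return llm_response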
3 changes: 1 addition & 2 deletions src/codegate/db/models.py
@@ -28,8 +28,7 @@ class Prompt(pydantic.BaseModel):
     id: Any
     timestamp: Any
     provider: Optional[Any]
-    system_prompt: Optional[Any]
-    user_prompt: Any
+    request: Any
     type: Any
 
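
The generic _insert_pydantic_model helper relies on the model's fields matching the table's columns, so a returned row can be rebuilt into the model class. A small sketch of that round trip (the sample row values are illustrative):

from typing import Any, Optional

import pydantic

class Prompt(pydantic.BaseModel):  # mirrors the updated model in this diff
    id: Any
    timestamp: Any
    provider: Optional[Any]
    request: Any
    type: Any

row = {
    "id": "1234",
    "timestamp": "2024-12-03T00:00:00Z",
    "provider": "openai",
    "request": "{}",
    "type": "chat",
}
prompt = Prompt(**row)  # same pattern as model_class(**row._asdict()) in connection.py
print(prompt.request)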