Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: FIM not caching correctly non-python files #408

Merged
merged 2 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/codegate/codegate_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def _missing_(cls, value: str) -> Optional["LogFormat"]:

def add_origin(logger, log_method, event_dict):
# Add 'origin' if it's bound to the logger but not explicitly in the event dict
if 'origin' not in event_dict and hasattr(logger, '_context'):
origin = logger._context.get('origin')
if "origin" not in event_dict and hasattr(logger, "_context"):
origin = logger._context.get("origin")
if origin:
event_dict['origin'] = origin
event_dict["origin"] = origin
return event_dict


Expand Down
89 changes: 15 additions & 74 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import asyncio
import hashlib
import json
import re
from datetime import timedelta
from pathlib import Path
from typing import List, Optional

Expand All @@ -11,7 +8,7 @@
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

from codegate.config import Config
from codegate.db.fim_cache import FimCache
from codegate.db.models import Alert, Output, Prompt
from codegate.db.queries import (
AsyncQuerier,
Expand All @@ -22,7 +19,7 @@

logger = structlog.get_logger("codegate")
alert_queue = asyncio.Queue()
fim_entries = {}
fim_cache = FimCache()


class DbCodeGate:
Expand Down Expand Up @@ -183,47 +180,6 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
logger.debug(f"Recorded alerts: {recorded_alerts}")
return recorded_alerts

def _extract_request_message(self, request: str) -> Optional[dict]:
"""Extract the user message from the FIM request"""
try:
parsed_request = json.loads(request)
except Exception as e:
logger.exception(f"Failed to extract request message: {request}", error=str(e))
return None

messages = [message for message in parsed_request["messages"] if message["role"] == "user"]
if len(messages) != 1:
logger.warning(f"Expected one user message, found {len(messages)}.")
return None

content_message = messages[0].get("content")
return content_message

def _create_hash_key(self, message: str, provider: str) -> str:
"""Creates a hash key from the message and includes the provider"""
# Try to extract the path from the FIM message. The path is in FIM request in these formats:
# folder/testing_file.py
# Path: file3.py
pattern = r"^#.*?\b([a-zA-Z0-9_\-\/]+\.\w+)\b"
matches = re.findall(pattern, message, re.MULTILINE)
# If no path is found, hash the entire prompt message.
if not matches:
logger.warning("No path found in messages. Creating hash cache from message.")
message_to_hash = f"{message}-{provider}"
else:
# Copilot puts the path at the top of the file. Continue providers contain
# several paths, the one in which the fim is triggered is the last one.
if provider == "copilot":
filepath = matches[0]
else:
filepath = matches[-1]
message_to_hash = f"{filepath}-{provider}"

logger.debug(f"Message to hash: {message_to_hash}")
hashed_content = hashlib.sha256(message_to_hash.encode("utf-8")).hexdigest()
logger.debug(f"Hashed contnet: {hashed_content}")
return hashed_content

def _should_record_context(self, context: Optional[PipelineContext]) -> bool:
"""Check if the context should be recorded in DB"""
if context is None or context.metadata.get("stored_in_db", False):
Expand All @@ -237,37 +193,22 @@ def _should_record_context(self, context: Optional[PipelineContext]) -> bool:
if context.input_request.type != "fim":
return True

# Couldn't process the user message. Skip creating a mapping entry.
message = self._extract_request_message(context.input_request.request)
if message is None:
logger.warning(f"Couldn't read FIM message: {message}. Will not record to DB.")
return False

hash_key = self._create_hash_key(message, context.input_request.provider)
old_timestamp = fim_entries.get(hash_key, None)
if old_timestamp is None:
fim_entries[hash_key] = context.input_request.timestamp
return True
return fim_cache.could_store_fim_request(context)

elapsed_seconds = (context.input_request.timestamp - old_timestamp).total_seconds()
if elapsed_seconds < Config.get_config().max_fim_hash_lifetime:
async def record_context(self, context: Optional[PipelineContext]) -> None:
try:
if not self._should_record_context(context):
return
await self.record_request(context.input_request)
await self.record_outputs(context.output_responses)
await self.record_alerts(context.alerts_raised)
context.metadata["stored_in_db"] = True
logger.info(
f"Skipping DB context recording. "
f"Elapsed time since last FIM cache: {timedelta(seconds=elapsed_seconds)}."
f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
f"Alerts: {len(context.alerts_raised)}."
)
return False

async def record_context(self, context: Optional[PipelineContext]) -> None:
if not self._should_record_context(context):
return
await self.record_request(context.input_request)
await self.record_outputs(context.output_responses)
await self.record_alerts(context.alerts_raised)
context.metadata["stored_in_db"] = True
logger.info(
f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
f"Alerts: {len(context.alerts_raised)}."
)
except Exception as e:
logger.error(f"Failed to record context: {context}.", error=str(e))


class DbReader(DbCodeGate):
Expand Down
136 changes: 136 additions & 0 deletions src/codegate/db/fim_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import datetime
import hashlib
import json
import re
from typing import Dict, List, Optional

import structlog
from pydantic import BaseModel

from codegate.config import Config
from codegate.db.models import Alert
from codegate.pipeline.base import AlertSeverity, PipelineContext

logger = structlog.get_logger("codegate")


class CachedFim(BaseModel):

timestamp: datetime.datetime
critical_alerts: List[Alert]


class FimCache:

def __init__(self):
self.cache: Dict[str, CachedFim] = {}

def _extract_message_from_fim_request(self, request: str) -> Optional[str]:
"""Extract the user message from the FIM request"""
try:
parsed_request = json.loads(request)
except Exception as e:
logger.error(f"Failed to extract request message: {request}", error=str(e))
return None

if not isinstance(parsed_request, dict):
logger.warning(f"Expected a dictionary, got {type(parsed_request)}.")
return None

messages = [
message
for message in parsed_request.get("messages", [])
if isinstance(message, dict) and message.get("role", "") == "user"
]
if len(messages) != 1:
logger.warning(f"Expected one user message, found {len(messages)}.")
return None

content_message = messages[0].get("content")
return content_message

def _match_filepath(self, message: str, provider: str) -> Optional[str]:
# Try to extract the path from the FIM message. The path is in FIM request as a comment:
# folder/testing_file.py
# Path: file3.py
# // Path: file3.js <-- Javascript
pattern = r"^(#|//|<!--|--|%|;).*?\b([a-zA-Z0-9_\-\/]+\.\w+)\b"
matches = re.findall(pattern, message, re.MULTILINE)
# If no path is found, hash the entire prompt message.
if not matches:
return None

# Extract only the paths (2nd group from the match)
paths = [match[1] for match in matches]

# Copilot puts the path at the top of the file. Continue providers contain
# several paths, the one in which the fim is triggered is the last one.
if provider == "copilot":
aponcedeleonch marked this conversation as resolved.
Show resolved Hide resolved
return paths[0]
else:
return paths[-1]

def _calculate_hash_key(self, message: str, provider: str) -> str:
"""Creates a hash key from the message and includes the provider"""
filepath = self._match_filepath(message, provider)
if filepath is None:
logger.warning("No path found in messages. Creating hash key from message.")
message_to_hash = f"{message}-{provider}"
else:
message_to_hash = f"{filepath}-{provider}"

logger.debug(f"Message to hash: {message_to_hash}")
hashed_content = hashlib.sha256(message_to_hash.encode("utf-8")).hexdigest()
logger.debug(f"Hashed content: {hashed_content}")
return hashed_content

def _add_cache_entry(self, hash_key: str, context: PipelineContext):
"""Add a new cache entry"""
critical_alerts = [
alert
for alert in context.alerts_raised
if alert.trigger_category == AlertSeverity.CRITICAL.value
]
new_cache = CachedFim(
timestamp=context.input_request.timestamp, critical_alerts=critical_alerts
)
self.cache[hash_key] = new_cache
logger.info(f"Added cache entry for hash key: {hash_key}")

def _are_new_alerts_present(self, context: PipelineContext, cached_entry: CachedFim) -> bool:
"""Check if there are new alerts present"""
new_critical_alerts = [
aponcedeleonch marked this conversation as resolved.
Show resolved Hide resolved
alert
for alert in context.alerts_raised
if alert.trigger_category == AlertSeverity.CRITICAL.value
]
return len(new_critical_alerts) > len(cached_entry.critical_alerts)

def _is_cached_entry_old(self, context: PipelineContext, cached_entry: CachedFim) -> bool:
"""Check if the cached entry is old"""
elapsed_seconds = (context.input_request.timestamp - cached_entry.timestamp).total_seconds()
return elapsed_seconds > Config.get_config().max_fim_hash_lifetime

def could_store_fim_request(self, context: PipelineContext):
# Couldn't process the user message. Skip creating a mapping entry.
message = self._extract_message_from_fim_request(context.input_request.request)
if message is None:
logger.warning(f"Couldn't read FIM message: {message}. Will not record to DB.")
return False

hash_key = self._calculate_hash_key(message, context.input_request.provider)
cached_entry = self.cache.get(hash_key, None)
if cached_entry is None:
self._add_cache_entry(hash_key, context)
return True

if self._is_cached_entry_old(context, cached_entry):
self._add_cache_entry(hash_key, context)
return True

if self._are_new_alerts_present(context, cached_entry):
self._add_cache_entry(hash_key, context)
return True

logger.debug(f"FIM entry already in cache: {hash_key}.")
return False
18 changes: 8 additions & 10 deletions src/codegate/providers/copilot/provider.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import asyncio
import re
import ssl
from src.codegate.codegate_logging import setup_logging
import structlog
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from urllib.parse import unquote, urljoin, urlparse

import structlog
from litellm.types.utils import Delta, ModelResponse, StreamingChoices

from codegate.ca.codegate_ca import CertificateAuthority
Expand All @@ -22,6 +21,7 @@
CopilotPipeline,
)
from codegate.providers.copilot.streaming import SSEProcessor
from src.codegate.codegate_logging import setup_logging

setup_logging()
logger = structlog.get_logger("codegate").bind(origin="copilot_proxy")
Expand Down Expand Up @@ -206,7 +206,7 @@ async def _request_to_target(self, headers: list[str], body: bytes):
logger.debug("=" * 40)

for i in range(0, len(body), CHUNK_SIZE):
chunk = body[i: i + CHUNK_SIZE]
chunk = body[i : i + CHUNK_SIZE]
self.target_transport.write(chunk)

def connection_made(self, transport: asyncio.Transport) -> None:
Expand Down Expand Up @@ -269,9 +269,7 @@ def _check_buffer_size(self, new_data: bytes) -> bool:
"""Check if adding new data would exceed buffer size limit"""
return len(self.buffer) + len(new_data) <= MAX_BUFFER_SIZE

async def _forward_data_through_pipeline(
self, data: bytes
) -> Union[HttpRequest, HttpResponse]:
async def _forward_data_through_pipeline(self, data: bytes) -> Union[HttpRequest, HttpResponse]:
http_request = http_request_from_bytes(data)
if not http_request:
# we couldn't parse this into an HTTP request, so we just pass through
Expand All @@ -287,7 +285,7 @@ async def _forward_data_through_pipeline(

if context and context.shortcut_response:
# Send shortcut response
data_prefix = b'data:'
data_prefix = b"data:"
http_response = HttpResponse(
http_request.version,
200,
Expand All @@ -299,7 +297,7 @@ async def _forward_data_through_pipeline(
"Content-Type: application/json",
"Transfer-Encoding: chunked",
],
data_prefix + body
data_prefix + body,
)
return http_response

Expand Down Expand Up @@ -639,7 +637,7 @@ async def get_target_url(path: str) -> Optional[str]:
# Check for prefix match
for route in VALIDATED_ROUTES:
# For prefix matches, keep the rest of the path
remaining_path = path[len(route.path):]
remaining_path = path[len(route.path) :]
logger.debug(f"Remaining path: {remaining_path}")
# Make sure we don't end up with double slashes
if remaining_path and remaining_path.startswith("/"):
Expand Down Expand Up @@ -793,7 +791,7 @@ def data_received(self, data: bytes) -> None:
self._proxy_transport_write(headers)
logger.debug(f"Headers sent: {headers}")

data = data[header_end + 4:]
data = data[header_end + 4 :]

self._process_chunk(data)

Expand Down
36 changes: 0 additions & 36 deletions tests/db/test_connection.py

This file was deleted.

Loading
Loading