Merge pull request #102 from stacklok/add-fim-pipeline
Add a FIM pipeline to Providers
aponcedeleonch authored Nov 28, 2024
2 parents f005ecc + 8bb074c commit 189aee9
Showing 9 changed files with 293 additions and 24 deletions.
52 changes: 52 additions & 0 deletions src/codegate/pipeline/fim/secret_analyzer.py
@@ -0,0 +1,52 @@
from litellm import ChatCompletionRequest

from codegate.codegate_logging import setup_logging
from codegate.pipeline.base import PipelineContext, PipelineResponse, PipelineResult, PipelineStep

logger = setup_logging()


class SecretAnalyzer(PipelineStep):
"""Pipeline step that handles analyzing secrets in FIM pipeline."""

message_blocked = """
⚠️ CodeGate Security Warning! Analysis Report ⚠️
Potential leak of sensitive credentials blocked
Recommendations:
- Use environment variables for secrets
"""

@property
def name(self) -> str:
"""
Returns the name of this pipeline step.
Returns:
str: The identifier 'fim-secret-analyzer'
"""
return "fim-secret-analyzer"

async def process(
self,
request: ChatCompletionRequest,
context: PipelineContext
) -> PipelineResult:
# We should call the Secrets Blocking module here to check whether the request messages contain secrets
# messages_contain_secrets = [analyze_msg_secrets(msg) for msg in request.messages]
# message_with_secrets = any(messages_contain_secrets)

# For the moment, treat all messages as if they contain no secrets; flip this to True to test short-circuiting the pipeline
message_with_secrets = False
if message_with_secrets:
logger.info('Blocking message with secrets.')
return PipelineResult(
response=PipelineResponse(
step_name=self.name,
content=self.message_blocked,
model=request["model"],
),
)

# No messages with secrets, execute the rest of the pipeline
return PipelineResult(request=request)
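
The real check is still commented out above. As a minimal sketch, a regex-based helper could back that commented-out call; the name `analyze_msg_secrets` comes from the comment, while the patterns below are assumptions, not part of this commit:

import re

# Hypothetical patterns; a production implementation would delegate to a
# dedicated secrets-scanning module rather than a short hard-coded list.
_SECRET_PATTERNS = [
    re.compile(r"AKIA[0-9A-Z]{16}"),                          # AWS access key ID
    re.compile(r"ghp_[A-Za-z0-9]{36}"),                       # GitHub personal access token
    re.compile(r"-----BEGIN (?:RSA |EC )?PRIVATE KEY-----"),  # PEM private keys
]

def analyze_msg_secrets(msg: dict) -> bool:
    """Return True if the message content matches any known secret pattern."""
    content = msg.get("content") or ""
    if isinstance(content, list):
        # Multi-part message content: join the text parts before scanning.
        content = " ".join(part.get("text", "") for part in content)
    return any(pattern.search(content) for pattern in _SECRET_PATTERNS)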
33 changes: 33 additions & 0 deletions src/codegate/providers/anthropic/completion_handler.py
@@ -0,0 +1,33 @@
from typing import AsyncIterator, Optional, Union

from litellm import ChatCompletionRequest, ModelResponse

from codegate.providers.litellmshim import LiteLLmShim


class AnthropicCompletion(LiteLLmShim):
"""
AnthropicCompletion used by the Anthropic provider to execute completions
"""

async def execute_completion(
self,
request: ChatCompletionRequest,
api_key: Optional[str],
stream: bool = False,
) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
"""
Ensures the model name is prefixed with 'anthropic/' to explicitly route to Anthropic's API.
LiteLLM automatically maps most model names, but prepending 'anthropic/' forces the request
to Anthropic. This avoids issues with unrecognized names like 'claude-3-5-sonnet-latest',
which LiteLLM doesn't accept as a valid Anthropic model. This safeguard may be unnecessary
but ensures compatibility.
For more details, refer to the
[LiteLLM Documentation](https://docs.litellm.ai/docs/providers/anthropic).
"""
model_in_request = request['model']
if not model_in_request.startswith('anthropic/'):
request['model'] = f'anthropic/{model_in_request}'
return await super().execute_completion(request, api_key, stream)
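
As an illustration of the docstring above, here is the prefix logic pulled out into a standalone helper; the helper name is hypothetical (the class keeps this check inline):

def _ensure_anthropic_prefix(model: str) -> str:
    """Mirror of the inline check in execute_completion."""
    return model if model.startswith("anthropic/") else f"anthropic/{model}"

assert _ensure_anthropic_prefix("claude-3-5-sonnet-latest") == "anthropic/claude-3-5-sonnet-latest"
assert _ensure_anthropic_prefix("anthropic/claude-3-haiku") == "anthropic/claude-3-haiku"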
18 changes: 13 additions & 5 deletions src/codegate/providers/anthropic/provider.py
@@ -1,20 +1,27 @@
import json
from typing import Optional

from fastapi import Header, HTTPException, Request

from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer
from codegate.providers.base import BaseProvider
from codegate.providers.litellmshim import LiteLLmShim, anthropic_stream_generator
from codegate.providers.anthropic.completion_handler import AnthropicCompletion
from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
from codegate.providers.litellmshim import anthropic_stream_generator


class AnthropicProvider(BaseProvider):
def __init__(self, pipeline_processor=None):
completion_handler = LiteLLmShim(stream_generator=anthropic_stream_generator)
def __init__(
self,
pipeline_processor: Optional[SequentialPipelineProcessor] = None,
fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None
):
completion_handler = AnthropicCompletion(stream_generator=anthropic_stream_generator)
super().__init__(
AnthropicInputNormalizer(),
AnthropicOutputNormalizer(),
completion_handler,
pipeline_processor,
fim_pipeline_processor
)

@property
@@ -39,5 +46,6 @@ async def create_message(
body = await request.body()
data = json.loads(body)

stream = await self.complete(data, x_api_key)
is_fim_request = self._is_fim_request(request, data)
stream = await self.complete(data, x_api_key, is_fim_request)
return self._completion_handler.create_streaming_response(stream)
74 changes: 65 additions & 9 deletions src/codegate/providers/base.py
@@ -1,15 +1,17 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Callable, Dict, Optional, Union

from fastapi import APIRouter
from fastapi import APIRouter, Request
from litellm import ModelResponse
from litellm.types.llms.openai import ChatCompletionRequest

from codegate.codegate_logging import setup_logging
from codegate.pipeline.base import PipelineResult, SequentialPipelineProcessor
from codegate.providers.completion.base import BaseCompletionHandler
from codegate.providers.formatting.input_pipeline import PipelineResponseFormatter
from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer

logger = setup_logging()
StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]]


@@ -25,12 +27,14 @@ def __init__(
output_normalizer: ModelOutputNormalizer,
completion_handler: BaseCompletionHandler,
pipeline_processor: Optional[SequentialPipelineProcessor] = None,
fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
):
self.router = APIRouter()
self._completion_handler = completion_handler
self._input_normalizer = input_normalizer
self._output_normalizer = output_normalizer
self._pipeline_processor = pipeline_processor
self._fim_pipeline_processor = fim_pipeline_processor

self._pipeline_response_formatter = PipelineResponseFormatter(output_normalizer)

@@ -48,22 +52,76 @@ def provider_route_name(self) -> str:
async def _run_input_pipeline(
self,
normalized_request: ChatCompletionRequest,
is_fim_request: bool
) -> PipelineResult:
if self._pipeline_processor is None:
# Decide which pipeline processor to use
if is_fim_request:
pipeline_processor = self._fim_pipeline_processor
logger.info('FIM pipeline selected for execution.')
else:
pipeline_processor = self._pipeline_processor
logger.info('Chat completion pipeline selected for execution.')
if pipeline_processor is None:
return PipelineResult(request=normalized_request)

result = await self._pipeline_processor.process_request(normalized_request)
result = await pipeline_processor.process_request(normalized_request)

# TODO(jakub): handle this by returning a message to the client
if result.error_message:
raise Exception(result.error_message)

return result

def _is_fim_request_url(self, request: Request) -> bool:
"""
Checks the request URL to determine if a request is FIM or chat completion.
Used by: llama.cpp
"""
request_path = request.url.path
# Check the longer suffix first so "/chat/completions" is not mistaken for FIM.
if request_path.endswith("/chat/completions"):
return False

if request_path.endswith("/completions"):
return True

return False

def _is_fim_request_body(self, data: Dict) -> bool:
"""
Determine from the raw incoming data if it's a FIM request.
Used by: OpenAI and Anthropic
"""
messages = data.get('messages', [])
if not messages:
return False

first_message_content = messages[0].get('content')
if first_message_content is None:
return False

fim_stop_sequences = ['</COMPLETION>', '<COMPLETION>', '</QUERY>', '<QUERY>']
if isinstance(first_message_content, str):
msg_prompt = first_message_content
elif isinstance(first_message_content, list):
msg_prompt = first_message_content[0].get('text', '')
else:
logger.warning(f'Could not determine if message was FIM from data: {data}')
return False
return all([stop_sequence in msg_prompt for stop_sequence in fim_stop_sequences])

def _is_fim_request(self, request: Request, data: Dict) -> bool:
"""
Determine whether the request is FIM from the URL or the body of the request.
"""
# Avoid more expensive inspection of body by just checking the URL.
if self._is_fim_request_url(request):
return True

return self._is_fim_request_body(data)

async def complete(
self,
data: Dict,
api_key: Optional[str],
self, data: Dict, api_key: Optional[str], is_fim_request: bool
) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
"""
Main completion flow with pipeline integration
@@ -78,8 +136,7 @@ async def complete(
"""
normalized_request = self._input_normalizer.normalize(data)
streaming = data.get("stream", False)

input_pipeline_result = await self._run_input_pipeline(normalized_request)
input_pipeline_result = await self._run_input_pipeline(normalized_request, is_fim_request)
if input_pipeline_result.response:
return self._pipeline_response_formatter.handle_pipeline_response(
input_pipeline_result.response, streaming
@@ -93,7 +150,6 @@ async def complete(
model_response = await self._completion_handler.execute_completion(
provider_request, api_key=api_key, stream=streaming
)

if not streaming:
return self._output_normalizer.denormalize(model_response)
return self._output_normalizer.denormalize_streaming(model_response)
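
To make the body heuristic above concrete: a request whose first message carries all four FIM markers is routed to the FIM pipeline, while an ordinary chat request is not. A sketch (model name and payload shape are illustrative):

# Illustrative payload: the first message contains all four markers
# (<QUERY>, </QUERY>, <COMPLETION>, </COMPLETION>), so _is_fim_request_body
# returns True and the FIM pipeline is selected.
fim_like_data = {
    "model": "gpt-4o-mini",
    "messages": [{
        "role": "user",
        "content": (
            "You are a code completion engine.\n"
            "<QUERY>def add(a, b):\n    return </QUERY>\n"
            "Reply inside <COMPLETION></COMPLETION> tags."
        ),
    }],
}

# A regular chat request has none of the markers and stays on the chat pipeline.
chat_data = {"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "Hi!"}]}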
10 changes: 10 additions & 0 deletions src/codegate/providers/llamacpp/normalizer.py
@@ -10,6 +10,11 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
"""
Normalize the input data
"""
# When doing FIM, we receive "prompt" instead of "messages". Normalize it here.
if "prompt" in data:
data["messages"] = [{"content": data.pop("prompt"), "role": "user"}]
# We can add as many parameters as we like to data. ChatCompletionRequest is not strict.
data["had_prompt_before"] = True
try:
return ChatCompletionRequest(**data)
except Exception as e:
@@ -19,6 +24,11 @@ def denormalize(self, data: ChatCompletionRequest) -> Dict:
"""
Denormalize the input data
"""
# If the request originally arrived with "prompt" (FIM), convert it back.
if data.get("had_prompt_before", False):
data["prompt"] = data["messages"][0]["content"]
del data["had_prompt_before"]
del data["messages"]
return data
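
A quick round-trip sketch of the prompt handling above, assuming both methods live on LLamaCppInputNormalizer as the hunks suggest (model name and prompt are illustrative):

from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer

normalizer = LLamaCppInputNormalizer()

fim_body = {"model": "qwen2.5-coder", "prompt": "def fib(n):", "stream": True}
chat_request = normalizer.normalize(fim_body)
# chat_request["messages"] == [{"content": "def fib(n):", "role": "user"}]
# chat_request["had_prompt_before"] is True, marking it for denormalization.

restored = normalizer.denormalize(chat_request)
# "prompt" is back; "messages" and "had_prompt_before" have been removed.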


13 changes: 10 additions & 3 deletions src/codegate/providers/llamacpp/provider.py
@@ -1,20 +1,26 @@
import json
from typing import Optional

from fastapi import Request

from codegate.providers.base import BaseProvider
from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler
from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer


class LlamaCppProvider(BaseProvider):
def __init__(self, pipeline_processor=None):
def __init__(
self,
pipeline_processor: Optional[SequentialPipelineProcessor] = None,
fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None
):
completion_handler = LlamaCppCompletionHandler()
super().__init__(
LLamaCppInputNormalizer(),
LLamaCppOutputNormalizer(),
completion_handler,
pipeline_processor,
fim_pipeline_processor
)

@property
@@ -34,5 +40,6 @@ async def create_completion(
body = await request.body()
data = json.loads(body)

stream = await self.complete(data, api_key=None)
is_fim_request = self._is_fim_request(request, data)
stream = await self.complete(data, None, is_fim_request=is_fim_request)
return self._completion_handler.create_streaming_response(stream)
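
For llama.cpp the FIM decision is purely URL-based (see `_is_fim_request_url` in base.py). A small standalone illustration of that suffix check; the helper is hypothetical, the logic mirrors the method:

def _looks_like_fim(path: str) -> bool:
    # Same ordering as _is_fim_request_url: test the longer suffix first.
    if path.endswith("/chat/completions"):
        return False
    return path.endswith("/completions")

assert _looks_like_fim("/llamacpp/completions") is True
assert _looks_like_fim("/llamacpp/chat/completions") is False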
14 changes: 11 additions & 3 deletions src/codegate/providers/openai/provider.py
@@ -1,20 +1,26 @@
import json
from typing import Optional

from fastapi import Header, HTTPException, Request

from codegate.providers.base import BaseProvider
from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
from codegate.providers.openai.adapter import OpenAIInputNormalizer, OpenAIOutputNormalizer


class OpenAIProvider(BaseProvider):
def __init__(self, pipeline_processor=None):
def __init__(
self,
pipeline_processor: Optional[SequentialPipelineProcessor] = None,
fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None
):
completion_handler = LiteLLmShim(stream_generator=sse_stream_generator)
super().__init__(
OpenAIInputNormalizer(),
OpenAIOutputNormalizer(),
completion_handler,
pipeline_processor,
fim_pipeline_processor
)

@property
@@ -29,6 +35,7 @@ def _setup_routes(self):
"""

@self.router.post(f"/{self.provider_route_name}/chat/completions")
@self.router.post(f"/{self.provider_route_name}/completions")
async def create_completion(
request: Request,
authorization: str = Header(..., description="Bearer token"),
@@ -40,5 +47,6 @@ async def create_completion(
body = await request.body()
data = json.loads(body)

stream = await self.complete(data, api_key)
is_fim_request = self._is_fim_request(request, data)
stream = await self.complete(data, api_key, is_fim_request=is_fim_request)
return self._completion_handler.create_streaming_response(stream)
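
With the extra `/completions` route registered above, a FIM-style request can be sent to it directly. A hedged sketch using `requests`; host, port, the payload accepted by the normalizer, and the route prefix resolving to "openai" are all assumptions about a local CodeGate deployment:

import requests

# Host/port are assumptions; adjust to your CodeGate instance.
resp = requests.post(
    "http://localhost:8989/openai/completions",
    headers={"Authorization": "Bearer sk-..."},  # placeholder API key
    json={
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "<QUERY>def add(a, b):</QUERY> <COMPLETION></COMPLETION>"}],
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line.decode())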
21 changes: 17 additions & 4 deletions src/codegate/server.py
@@ -21,15 +21,28 @@ def init_app() -> FastAPI:
steps: List[PipelineStep] = [
CodegateVersion(),
]

# Leaving the pipeline empty for now
fim_steps: List[PipelineStep] = [
]
pipeline = SequentialPipelineProcessor(steps)
fim_pipeline = SequentialPipelineProcessor(fim_steps)

# Create provider registry
registry = ProviderRegistry(app)

# Register all known providers
registry.add_provider("openai", OpenAIProvider(pipeline_processor=pipeline))
registry.add_provider("anthropic", AnthropicProvider(pipeline_processor=pipeline))
registry.add_provider("llamacpp", LlamaCppProvider(pipeline_processor=pipeline))
registry.add_provider("openai", OpenAIProvider(
pipeline_processor=pipeline,
fim_pipeline_processor=fim_pipeline
))
registry.add_provider("anthropic", AnthropicProvider(
pipeline_processor=pipeline,
fim_pipeline_processor=fim_pipeline
))
registry.add_provider("llamacpp", LlamaCppProvider(
pipeline_processor=pipeline,
fim_pipeline_processor=fim_pipeline
))

# Create and add system routes
system_router = APIRouter(tags=["System"]) # Tags group endpoints in the docs
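
The FIM pipeline is registered empty above; once the SecretAnalyzer step added in this PR is meant to run, wiring it in would presumably look like this (a sketch, not part of this commit):

from codegate.pipeline.fim.secret_analyzer import SecretAnalyzer

fim_steps: List[PipelineStep] = [
    SecretAnalyzer(),
]
fim_pipeline = SequentialPipelineProcessor(fim_steps)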