From 34ca65c618c8ceed0b4c5af40b16964d45dbfe98 Mon Sep 17 00:00:00 2001 From: Luke Hinds Date: Sat, 30 Nov 2024 22:58:16 +0000 Subject: [PATCH] Ollama provider Ollama provider, tested against various models using chat and FIM Closes: #70 --- docs/cli.md | 10 ++ docs/configuration.md | 11 +- docs/development.md | 15 +- src/codegate/config.py | 1 + .../extract_snippets/extract_snippets.py | 1 + src/codegate/providers/__init__.py | 2 + src/codegate/providers/ollama/__init__.py | 3 + src/codegate/providers/ollama/adapter.py | 86 ++++++++++ src/codegate/providers/ollama/provider.py | 161 ++++++++++++++++++ src/codegate/server.py | 4 + tests/providers/ollama/test_ollama_adapter.py | 154 +++++++++++++++++ .../providers/ollama/test_ollama_provider.py | 157 +++++++++++++++++ 12 files changed, 599 insertions(+), 6 deletions(-) create mode 100644 src/codegate/providers/ollama/__init__.py create mode 100644 src/codegate/providers/ollama/adapter.py create mode 100644 src/codegate/providers/ollama/provider.py create mode 100644 tests/providers/ollama/test_ollama_adapter.py create mode 100644 tests/providers/ollama/test_ollama_provider.py diff --git a/docs/cli.md b/docs/cli.md index 9142d405..58f85f96 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -61,6 +61,11 @@ codegate serve [OPTIONS] - Base URL for Anthropic provider - Overrides configuration file and environment variables +- `--ollama-url TEXT`: Ollama provider URL (default: http://localhost:11434) + - Optional + - Base URL for Ollama provider (/api path is added automatically) + - Overrides configuration file and environment variables + ### show-prompts Display the loaded system prompts: @@ -120,6 +125,11 @@ Start server with custom vLLM endpoint: codegate serve --vllm-url https://vllm.example.com ``` +Start server with custom Ollama endpoint: +```bash +codegate serve --ollama-url http://localhost:11434 +``` + Show default system prompts: ```bash codegate show-prompts diff --git a/docs/configuration.md b/docs/configuration.md index 6a015d28..e4971fdc 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -20,6 +20,7 @@ The configuration system in Codegate is managed through the `Config` class in `c - vLLM: "http://localhost:8000" - OpenAI: "https://api.openai.com/v1" - Anthropic: "https://api.anthropic.com/v1" + - Ollama: "http://localhost:11434" ## Configuration Methods @@ -41,6 +42,7 @@ provider_urls: vllm: "https://vllm.example.com" openai: "https://api.openai.com/v1" anthropic: "https://api.anthropic.com/v1" + ollama: "http://localhost:11434" ``` ### From Environment Variables @@ -55,6 +57,7 @@ Environment variables are automatically loaded with these mappings: - `CODEGATE_PROVIDER_VLLM_URL`: vLLM provider URL - `CODEGATE_PROVIDER_OPENAI_URL`: OpenAI provider URL - `CODEGATE_PROVIDER_ANTHROPIC_URL`: Anthropic provider URL +- `CODEGATE_PROVIDER_OLLAMA_URL`: Ollama provider URL ```python config = Config.from_env() @@ -72,6 +75,7 @@ Provider URLs can be configured in several ways: vllm: "https://vllm.example.com" # /v1 path is added automatically openai: "https://api.openai.com/v1" anthropic: "https://api.anthropic.com/v1" + ollama: "http://localhost:11434" # /api path is added automatically ``` 2. 
Via Environment Variables:
@@ -79,14 +83,17 @@ Provider URLs can be configured in several ways:
    export CODEGATE_PROVIDER_VLLM_URL=https://vllm.example.com
    export CODEGATE_PROVIDER_OPENAI_URL=https://api.openai.com/v1
    export CODEGATE_PROVIDER_ANTHROPIC_URL=https://api.anthropic.com/v1
+   export CODEGATE_PROVIDER_OLLAMA_URL=http://localhost:11434
    ```
 
 3. Via CLI Flags:
    ```bash
-   codegate serve --vllm-url https://vllm.example.com
+   codegate serve --vllm-url https://vllm.example.com --ollama-url http://localhost:11434
    ```
 
-Note: For the vLLM provider, the /v1 path is automatically appended to the base URL if not present.
+Note:
+- For the vLLM provider, the /v1 path is automatically appended to the base URL if not present.
+- For the Ollama provider, the /api path is automatically appended to the base URL if not present.
 
 ### Log Levels
 
diff --git a/docs/development.md b/docs/development.md
index 430cd87d..791b54a1 100644
--- a/docs/development.md
+++ b/docs/development.md
@@ -1,5 +1,3 @@
-# Development Guide
-
 This guide provides comprehensive information for developers working on the Codegate project.
 
 ## Project Overview
@@ -157,6 +155,12 @@ Codegate supports multiple AI providers through a modular provider system.
    - Default URL: https://api.anthropic.com/v1
    - Anthropic Claude API implementation
 
+4. **Ollama Provider**
+   - Default URL: http://localhost:11434
+   - Endpoints:
+     * Native Ollama API: `/ollama/api/chat`
+     * OpenAI-compatible: `/ollama/chat/completions`
+
 ### Configuring Providers
 
 Provider URLs can be configured through:
@@ -167,6 +171,7 @@ Provider URLs can be configured through:
    vllm: "https://vllm.example.com"
    openai: "https://api.openai.com/v1"
    anthropic: "https://api.anthropic.com/v1"
+   ollama: "http://localhost:11434"  # /api path added automatically
    ```
 
 2. Environment variables:
@@ -174,11 +179,12 @@ Provider URLs can be configured through:
    export CODEGATE_PROVIDER_VLLM_URL=https://vllm.example.com
    export CODEGATE_PROVIDER_OPENAI_URL=https://api.openai.com/v1
    export CODEGATE_PROVIDER_ANTHROPIC_URL=https://api.anthropic.com/v1
+   export CODEGATE_PROVIDER_OLLAMA_URL=http://localhost:11434
    ```
 
 3. CLI flags:
    ```bash
-   codegate serve --vllm-url https://vllm.example.com
+   codegate serve --vllm-url https://vllm.example.com --ollama-url http://localhost:11434
    ```
 
 ### Implementing New Providers
@@ -276,4 +282,4 @@ codegate serve --prompts my-prompts.yaml
 codegate serve --vllm-url https://vllm.example.com
 ```
 
-See [CLI Documentation](cli.md) for detailed command information.
+See [CLI Documentation](cli.md) for detailed command information.
\ No newline at end of file
diff --git a/src/codegate/config.py b/src/codegate/config.py
index c88fc4d0..5eb57326 100644
--- a/src/codegate/config.py
+++ b/src/codegate/config.py
@@ -19,6 +19,7 @@
     "openai": "https://api.openai.com/v1",
     "anthropic": "https://api.anthropic.com/v1",
     "vllm": "http://localhost:8000",  # Base URL without /v1 path
+    "ollama": "http://localhost:11434",  # Default Ollama server URL
 }
 
 
diff --git a/src/codegate/pipeline/extract_snippets/extract_snippets.py b/src/codegate/pipeline/extract_snippets/extract_snippets.py
index a50460ee..bee756e0 100644
--- a/src/codegate/pipeline/extract_snippets/extract_snippets.py
+++ b/src/codegate/pipeline/extract_snippets/extract_snippets.py
@@ -13,6 +13,7 @@
 logger = structlog.get_logger("codegate")
 
 
+
 def ecosystem_from_filepath(filepath: str) -> Optional[str]:
     """
     Determine language from filepath.
diff --git a/src/codegate/providers/__init__.py b/src/codegate/providers/__init__.py index 1ec69183..d8a35e8e 100644 --- a/src/codegate/providers/__init__.py +++ b/src/codegate/providers/__init__.py @@ -1,5 +1,6 @@ from codegate.providers.anthropic.provider import AnthropicProvider from codegate.providers.base import BaseProvider +from codegate.providers.ollama.provider import OllamaProvider from codegate.providers.openai.provider import OpenAIProvider from codegate.providers.registry import ProviderRegistry from codegate.providers.vllm.provider import VLLMProvider @@ -10,4 +11,5 @@ "OpenAIProvider", "AnthropicProvider", "VLLMProvider", + "OllamaProvider", ] diff --git a/src/codegate/providers/ollama/__init__.py b/src/codegate/providers/ollama/__init__.py new file mode 100644 index 00000000..1f35008c --- /dev/null +++ b/src/codegate/providers/ollama/__init__.py @@ -0,0 +1,3 @@ +from codegate.providers.ollama.provider import OllamaProvider + +__all__ = ["OllamaProvider"] diff --git a/src/codegate/providers/ollama/adapter.py b/src/codegate/providers/ollama/adapter.py new file mode 100644 index 00000000..afbd5a0a --- /dev/null +++ b/src/codegate/providers/ollama/adapter.py @@ -0,0 +1,86 @@ +from typing import Any, Dict + +from litellm import ChatCompletionRequest + +from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer + + +class OllamaInputNormalizer(ModelInputNormalizer): + def __init__(self): + super().__init__() + + def normalize(self, data: Dict) -> ChatCompletionRequest: + """ + Normalize the input data to the format expected by Ollama. + """ + # Make a copy of the data to avoid modifying the original + normalized_data = data.copy() + + # Format the model name + if "model" in normalized_data: + normalized_data["model"] = normalized_data["model"].strip() + + # Convert messages format if needed + if "messages" in normalized_data: + messages = normalized_data["messages"] + converted_messages = [] + for msg in messages: + if isinstance(msg.get("content"), list): + # Convert list format to string + content_parts = [] + for part in msg["content"]: + if part.get("type") == "text": + content_parts.append(part["text"]) + msg = msg.copy() + msg["content"] = " ".join(content_parts) + converted_messages.append(msg) + normalized_data["messages"] = converted_messages + + # Ensure the base_url ends with /api if provided + if "base_url" in normalized_data: + base_url = normalized_data["base_url"].rstrip("/") + if not base_url.endswith("/api"): + normalized_data["base_url"] = f"{base_url}/api" + + return ChatCompletionRequest(**normalized_data) + + def denormalize(self, data: ChatCompletionRequest) -> Dict: + """ + Convert back to raw format for the API request + """ + return data + + +class OllamaOutputNormalizer(ModelOutputNormalizer): + def __init__(self): + super().__init__() + + def normalize_streaming( + self, + model_reply: Any, + ) -> Any: + """ + Pass through Ollama response + """ + return model_reply + + def normalize(self, model_reply: Any) -> Any: + """ + Pass through Ollama response + """ + return model_reply + + def denormalize(self, normalized_reply: Any) -> Any: + """ + Pass through Ollama response + """ + return normalized_reply + + def denormalize_streaming( + self, + normalized_reply: Any, + ) -> Any: + """ + Pass through Ollama response + """ + return normalized_reply diff --git a/src/codegate/providers/ollama/provider.py b/src/codegate/providers/ollama/provider.py new file mode 100644 index 00000000..762fffc2 --- /dev/null +++ 
b/src/codegate/providers/ollama/provider.py @@ -0,0 +1,161 @@ +import asyncio +import json +from typing import Optional + +import httpx +from fastapi import Header, HTTPException, Request +from fastapi.responses import StreamingResponse + +from codegate.config import Config +from codegate.providers.base import BaseProvider, SequentialPipelineProcessor +from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator +from codegate.providers.ollama.adapter import OllamaInputNormalizer, OllamaOutputNormalizer + + +async def stream_ollama_response(client: httpx.AsyncClient, url: str, data: dict): + """Stream response directly from Ollama API.""" + try: + async with client.stream("POST", url, json=data, timeout=30.0) as response: + response.raise_for_status() + async for line in response.aiter_lines(): + if line.strip(): + try: + # Parse the response to ensure it's valid JSON + response_data = json.loads(line) + # Add newline to ensure proper streaming + yield line.encode("utf-8") + b"\n" + # If this is the final response, break + if response_data.get("done", False): + break + # Small delay to prevent overwhelming the client + await asyncio.sleep(0.01) + except json.JSONDecodeError: + yield json.dumps({"error": "Invalid JSON response"}).encode("utf-8") + b"\n" + break + except Exception as e: + yield json.dumps({"error": str(e)}).encode("utf-8") + b"\n" + break + except Exception as e: + yield json.dumps({"error": f"Stream error: {str(e)}"}).encode("utf-8") + b"\n" + + +class OllamaProvider(BaseProvider): + def __init__( + self, + pipeline_processor: Optional[SequentialPipelineProcessor] = None, + fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, + ): + completion_handler = LiteLLmShim(stream_generator=sse_stream_generator) + super().__init__( + OllamaInputNormalizer(), + OllamaOutputNormalizer(), + completion_handler, + pipeline_processor, + fim_pipeline_processor, + ) + self.client = httpx.AsyncClient(timeout=30.0) + + @property + def provider_route_name(self) -> str: + return "ollama" + + def _setup_routes(self): + """ + Sets up Ollama API routes. 
+ """ + + # Native Ollama API routes + @self.router.post(f"/{self.provider_route_name}/api/chat") + async def ollama_chat( + request: Request, + authorization: str = Header(..., description="Bearer token"), + ): + if not authorization.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Invalid authorization header") + + _api_key = authorization.split(" ")[1] + body = await request.body() + data = json.loads(body) + + # Get the Ollama base URL + config = Config.get_config() + base_url = config.provider_urls.get("ollama", "http://localhost:11434") + + # Convert chat format to Ollama generate format + messages = [] + for msg in data.get("messages", []): + role = msg.get("role", "") + content = msg.get("content", "") + if isinstance(content, list): + # Handle list-based content format + content = " ".join( + part["text"] for part in content if part.get("type") == "text" + ) + messages.append({"role": role, "content": content}) + + ollama_data = { + "model": data.get("model", "").strip(), + "messages": messages, + "stream": True, + "options": data.get("options", {}), + } + + # Stream response directly from Ollama + return StreamingResponse( + stream_ollama_response(self.client, f"{base_url}/api/chat", ollama_data), + media_type="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + + @self.router.post(f"/{self.provider_route_name}/api/generate") + async def ollama_generate( + request: Request, + authorization: str = Header(..., description="Bearer token"), + ): + if not authorization.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Invalid authorization header") + + _api_key = authorization.split(" ")[1] + body = await request.body() + data = json.loads(body) + + # Get the Ollama base URL + config = Config.get_config() + base_url = config.provider_urls.get("ollama", "http://localhost:11434") + + # Prepare generate request + ollama_data = { + "model": data.get("model", "").strip(), + "prompt": data.get("prompt", ""), + "stream": True, + "options": data.get("options", {}), + } + + # Add any context or system prompt if provided + if "context" in data: + ollama_data["context"] = data["context"] + if "system" in data: + ollama_data["system"] = data["system"] + + # Stream response directly from Ollama + return StreamingResponse( + stream_ollama_response(self.client, f"{base_url}/api/generate", ollama_data), + media_type="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + + # OpenAI-compatible routes for backward compatibility + @self.router.post(f"/{self.provider_route_name}/chat/completions") + @self.router.post(f"/{self.provider_route_name}/completions") + async def create_completion( + request: Request, + authorization: str = Header(..., description="Bearer token"), + ): + # Redirect to native Ollama endpoint + return await ollama_chat(request, authorization) diff --git a/src/codegate/server.py b/src/codegate/server.py index 631824bf..05da192c 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -10,6 +10,7 @@ from codegate.pipeline.version.version import CodegateVersion from codegate.providers.anthropic.provider import AnthropicProvider from codegate.providers.llamacpp.provider import LlamaCppProvider +from codegate.providers.ollama.provider import OllamaProvider from codegate.providers.openai.provider import OpenAIProvider from codegate.providers.registry import ProviderRegistry from codegate.providers.vllm.provider import VLLMProvider 
@@ -54,6 +55,9 @@ def init_app() -> FastAPI: registry.add_provider( "vllm", VLLMProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline) ) + registry.add_provider( + "ollama", OllamaProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline) + ) # Create and add system routes system_router = APIRouter(tags=["System"]) # Tags group endpoints in the docs diff --git a/tests/providers/ollama/test_ollama_adapter.py b/tests/providers/ollama/test_ollama_adapter.py new file mode 100644 index 00000000..7c3728c1 --- /dev/null +++ b/tests/providers/ollama/test_ollama_adapter.py @@ -0,0 +1,154 @@ +"""Tests for Ollama adapter.""" + +from codegate.providers.ollama.adapter import OllamaInputNormalizer, OllamaOutputNormalizer + + +def test_normalize_ollama_input(): + """Test input normalization for Ollama.""" + normalizer = OllamaInputNormalizer() + + # Test model name handling + data = {"model": "llama2"} + normalized = normalizer.normalize(data) + assert type(normalized) == dict # noqa: E721 + assert normalized["model"] == "llama2" # No prefix needed for Ollama + + # Test model name with spaces + data = {"model": "codellama:7b-instruct "} # Extra space + normalized = normalizer.normalize(data) + assert normalized["model"] == "codellama:7b-instruct" # Space removed + + # Test base URL handling + data = {"model": "llama2", "base_url": "http://localhost:11434"} + normalized = normalizer.normalize(data) + assert normalized["base_url"] == "http://localhost:11434/api" + + # Test base URL already has /api + data = {"model": "llama2", "base_url": "http://localhost:11434/api"} + normalized = normalizer.normalize(data) + assert normalized["base_url"] == "http://localhost:11434/api" + + # Test base URL with trailing slash + data = {"model": "llama2", "base_url": "http://localhost:11434/"} + normalized = normalizer.normalize(data) + assert normalized["base_url"] == "http://localhost:11434/api" + + +def test_normalize_native_ollama_input(): + """Test input normalization for native Ollama API requests.""" + normalizer = OllamaInputNormalizer() + + # Test native Ollama request format + data = { + "model": "codellama:7b-instruct", + "messages": [{"role": "user", "content": "Hello"}], + "options": {"num_ctx": 8096, "num_predict": 6}, + } + normalized = normalizer.normalize(data) + assert type(normalized) == dict # noqa: E721 + assert normalized["model"] == "codellama:7b-instruct" + assert "options" in normalized + assert normalized["options"]["num_ctx"] == 8096 + + # Test native Ollama request with base URL + data = { + "model": "codellama:7b-instruct", + "messages": [{"role": "user", "content": "Hello"}], + "options": {"num_ctx": 8096, "num_predict": 6}, + "base_url": "http://localhost:11434", + } + normalized = normalizer.normalize(data) + assert normalized["base_url"] == "http://localhost:11434/api" + + +def test_normalize_ollama_message_format(): + """Test normalization of Ollama message formats.""" + normalizer = OllamaInputNormalizer() + + # Test list-based content format + data = { + "model": "codellama:7b-instruct", + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}], + } + ], + } + normalized = normalizer.normalize(data) + assert normalized["messages"][0]["content"] == "Hello world" + + # Test mixed content format + data = { + "model": "codellama:7b-instruct", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + {"type": "other", "text": "ignored"}, + {"type": "text", 
"text": "world"}, + ], + } + ], + } + normalized = normalizer.normalize(data) + assert normalized["messages"][0]["content"] == "Hello world" + + +def test_normalize_ollama_generate_format(): + """Test normalization of Ollama generate format.""" + normalizer = OllamaInputNormalizer() + + # Test basic generate request + data = { + "model": "codellama:7b-instruct", + "prompt": "def hello_world", + "options": {"temperature": 0.7}, + } + normalized = normalizer.normalize(data) + assert normalized["model"] == "codellama:7b-instruct" + assert normalized["prompt"] == "def hello_world" + assert normalized["options"]["temperature"] == 0.7 + + # Test generate request with context + data = { + "model": "codellama:7b-instruct", + "prompt": "def hello_world", + "context": [1, 2, 3], + "system": "You are a helpful assistant", + "options": {"temperature": 0.7}, + } + normalized = normalizer.normalize(data) + assert normalized["context"] == [1, 2, 3] + assert normalized["system"] == "You are a helpful assistant" + + +def test_normalize_ollama_output(): + """Test output normalization for Ollama.""" + normalizer = OllamaOutputNormalizer() + + # Test streaming response passthrough + response = {"message": {"role": "assistant", "content": "test"}} + normalized = normalizer.normalize_streaming(response) + assert normalized == response + + # Test regular response passthrough + response = {"message": {"role": "assistant", "content": "test"}} + normalized = normalizer.normalize(response) + assert normalized == response + + # Test generate response passthrough + response = {"response": "def hello_world():", "done": False} + normalized = normalizer.normalize(response) + assert normalized == response + + # Test denormalize passthrough + response = {"message": {"role": "assistant", "content": "test"}} + denormalized = normalizer.denormalize(response) + assert denormalized == response + + # Test streaming denormalize passthrough + response = {"message": {"role": "assistant", "content": "test"}} + denormalized = normalizer.denormalize_streaming(response) + assert denormalized == response diff --git a/tests/providers/ollama/test_ollama_provider.py b/tests/providers/ollama/test_ollama_provider.py new file mode 100644 index 00000000..49bff909 --- /dev/null +++ b/tests/providers/ollama/test_ollama_provider.py @@ -0,0 +1,157 @@ +"""Tests for Ollama provider.""" + +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from codegate.providers.ollama.provider import OllamaProvider + + +class MockConfig: + def __init__(self): + self.provider_urls = {"ollama": "http://localhost:11434"} + + +@pytest.fixture +def app(): + """Create FastAPI app with Ollama provider.""" + app = FastAPI() + provider = OllamaProvider() + app.include_router(provider.get_routes()) + return app + + +@pytest.fixture +def test_client(app): + """Create test client.""" + return TestClient(app) + + +async def async_iter(items): + """Helper to create async iterator.""" + for item in items: + yield item + + +@patch("codegate.config.Config.get_config", return_value=MockConfig()) +def test_ollama_chat(mock_config, test_client): + """Test chat endpoint.""" + data = { + "model": "codellama:7b-instruct", + "messages": [{"role": "user", "content": "Hello"}], + "options": {"temperature": 0.7}, + } + + with patch("httpx.AsyncClient.stream") as mock_stream: + # Mock the streaming response + mock_response = AsyncMock() + mock_response.raise_for_status = AsyncMock() + mock_response.aiter_lines 
= AsyncMock( + return_value=async_iter( + [ + '{"response": "Hello!", "done": false}', + '{"response": " How can I help?", "done": true}', + ] + ) + ) + mock_stream.return_value.__aenter__.return_value = mock_response + + response = test_client.post( + "/ollama/api/chat", json=data, headers={"Authorization": "Bearer test-key"} + ) + + assert response.status_code == 200 + assert response.headers["content-type"] == "application/x-ndjson" + + # Verify the request to Ollama + mock_stream.assert_called_once() + call_args = mock_stream.call_args + assert call_args[0][0] == "POST" + assert call_args[0][1].endswith("/api/chat") + sent_data = call_args[1]["json"] + assert sent_data["model"] == "codellama:7b-instruct" + assert sent_data["messages"] == data["messages"] + assert sent_data["options"] == data["options"] + assert sent_data["stream"] is True + + +@patch("codegate.config.Config.get_config", return_value=MockConfig()) +def test_ollama_generate(mock_config, test_client): + """Test generate endpoint.""" + data = { + "model": "codellama:7b-instruct", + "prompt": "def hello_world", + "options": {"temperature": 0.7}, + "context": [1, 2, 3], + "system": "You are a helpful assistant", + } + + with patch("httpx.AsyncClient.stream") as mock_stream: + # Mock the streaming response + mock_response = AsyncMock() + mock_response.raise_for_status = AsyncMock() + mock_response.aiter_lines = AsyncMock( + return_value=async_iter( + [ + '{"response": "():\\n", "done": false}', + '{"response": " print(\\"Hello, World!\\")", "done": true}', + ] + ) + ) + mock_stream.return_value.__aenter__.return_value = mock_response + + response = test_client.post( + "/ollama/api/generate", json=data, headers={"Authorization": "Bearer test-key"} + ) + + assert response.status_code == 200 + assert response.headers["content-type"] == "application/x-ndjson" + + # Verify the request to Ollama + mock_stream.assert_called_once() + call_args = mock_stream.call_args + assert call_args[0][0] == "POST" + assert call_args[0][1].endswith("/api/generate") + sent_data = call_args[1]["json"] + assert sent_data["model"] == "codellama:7b-instruct" + assert sent_data["prompt"] == "def hello_world" + assert sent_data["options"] == data["options"] + assert sent_data["context"] == data["context"] + assert sent_data["system"] == data["system"] + assert sent_data["stream"] is True + + +@patch("codegate.config.Config.get_config", return_value=MockConfig()) +def test_ollama_error_handling(mock_config, test_client): + """Test error handling.""" + data = {"model": "invalid-model"} + + with patch("httpx.AsyncClient.stream") as mock_stream: + # Mock an error response + mock_stream.side_effect = Exception("Model not found") + + response = test_client.post( + "/ollama/api/generate", json=data, headers={"Authorization": "Bearer test-key"} + ) + + assert response.status_code == 200 # Errors are returned in the stream + content = response.content.decode().strip() + assert "error" in content + assert "Model not found" in content + + +def test_ollama_auth_required(test_client): + """Test authentication requirement.""" + data = {"model": "codellama:7b-instruct"} + + # Test without auth header + response = test_client.post("/ollama/api/generate", json=data) + assert response.status_code == 422 + + # Test with invalid auth header + response = test_client.post( + "/ollama/api/generate", json=data, headers={"Authorization": "Invalid"} + ) + assert response.status_code == 401
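
To exercise the new endpoints by hand, the sketch below uses curl against a running codegate instance. It assumes codegate is listening on localhost:8989 (a placeholder; substitute the host/port your `codegate serve` uses), that a local Ollama server has the requested model pulled, and that any bearer token value will do, since the provider only checks the header format.

```bash
# Stream a chat completion through the native Ollama route (NDJSON output).
# localhost:8989 is a placeholder for wherever codegate is listening.
curl -N http://localhost:8989/ollama/api/chat \
  -H "Authorization: Bearer test-key" \
  -H "Content-Type: application/json" \
  -d '{
        "model": "codellama:7b-instruct",
        "messages": [{"role": "user", "content": "Write a hello world in Python"}]
      }'

# The generate route streams the same way, taking a prompt instead of messages.
curl -N http://localhost:8989/ollama/api/generate \
  -H "Authorization: Bearer test-key" \
  -H "Content-Type: application/json" \
  -d '{"model": "codellama:7b-instruct", "prompt": "def hello_world"}'
```

Both routes respond with `application/x-ndjson`: one JSON object per line, with the final object carrying `"done": true`.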