Integrates LiteLLM for Unified Access to Multiple LLM Models #5925
Changes from 10 commits: c3f0327, e68ab40, 30b45bf, bcd37bc, 22a99e8, 6226613, 954fd2f, 65acf0c, 2360af5, 411e16a, 2d597b8
@@ -1,71 +1,174 @@
# Python imports
import requests
# Python import
import os
from typing import List, Dict, Tuple

# Third party import
import litellm
import requests

# Third party imports
from openai import OpenAI
from rest_framework.response import Response
from rest_framework import status
from rest_framework.response import Response

# Django imports
# Module import
from plane.app.permissions import ROLE, allow_permission
from plane.app.serializers import (ProjectLiteSerializer,
                                   WorkspaceLiteSerializer)
from plane.db.models import Project, Workspace
from plane.license.utils.instance_value import get_configuration_value
from plane.utils.exception_logger import log_exception

# Module imports
from ..base import BaseAPIView
from plane.app.permissions import allow_permission, ROLE
from plane.db.models import Workspace, Project
from plane.app.serializers import ProjectLiteSerializer, WorkspaceLiteSerializer
from plane.license.utils.instance_value import get_configuration_value


class LLMProvider:
    """Base class for LLM provider configurations"""
    name: str = ""
    models: List[str] = []
    api_key_env: str = ""
    default_model: str = ""

    @classmethod
    def get_config(cls) -> Dict[str, str | List[str]]:
        return {
            "name": cls.name,
            "models": cls.models,
            "default_model": cls.default_model,
        }

class OpenAIProvider(LLMProvider):
    name = "OpenAI"
    models = ["gpt-3.5-turbo", "gpt-4o-mini", "gpt-4o", "o1-mini", "o1-preview"]
    api_key_env = "OPENAI_API_KEY"
    default_model = "gpt-4o-mini"

class AnthropicProvider(LLMProvider):
    name = "Anthropic"
    models = [
        "claude-3-5-sonnet-20240620",
        "claude-3-haiku-20240307",
        "claude-3-opus-20240229",
        "claude-3-sonnet-20240229",
        "claude-2.1",
        "claude-2",
        "claude-instant-1.2",
        "claude-instant-1"
    ]
    api_key_env = "ANTHROPIC_API_KEY"
    default_model = "claude-3-sonnet-20240229"

class GeminiProvider(LLMProvider):
    name = "Gemini"
    models = ["gemini-pro", "gemini-1.5-pro-latest", "gemini-pro-vision"]
    api_key_env = "GEMINI_API_KEY"
    default_model = "gemini-pro"

SUPPORTED_PROVIDERS = {
    "openai": OpenAIProvider,
    "anthropic": AnthropicProvider,
    "gemini": GeminiProvider,
}

def get_llm_config() -> Tuple[str | None, str | None, str | None]:
    """
    Helper to get LLM configuration values, returns:
    - api_key, model, provider
    """
    provider_key, model = get_configuration_value([
        {
            "key": "LLM_PROVIDER",
            "default": os.environ.get("LLM_PROVIDER", "openai"),
        },
        {
            "key": "LLM_MODEL",
            "default": None,
        },
    ])

    provider = SUPPORTED_PROVIDERS.get(provider_key.lower())
    if not provider:
        log_exception(ValueError(f"Unsupported provider: {provider_key}"))
        return None, None, None

    api_key, _ = get_configuration_value([
        {
            "key": provider.api_key_env,
            "default": os.environ.get(provider.api_key_env, None),
        }
    ])

    if not api_key:
        log_exception(ValueError(f"Missing API key for provider: {provider.name}"))
        return None, None, None

    # If no model specified, use provider's default
    if not model:
        model = provider.default_model

    # Validate model is supported by provider
    if model not in provider.models:
        log_exception(ValueError(
            f"Model {model} not supported by {provider.name}. "
            f"Supported models: {', '.join(provider.models)}"
        ))
    return api_key, model, provider_key


def get_llm_response(task, prompt, api_key: str, model: str, provider: str) -> Tuple[str | None, str | None]:
    """Helper to get LLM completion response"""
    final_text = task + "\n" + prompt
    try:
        # For Gemini, prepend provider name to model
        if provider.lower() == "gemini":
            model = f"gemini/{model}"

        response = litellm.completion(
            model=model,
            messages=[{"role": "user", "content": final_text}],
            api_key=api_key,
        )
        text = response.choices[0].message.content.strip()
        return text, None
    except Exception as e:
        log_exception(e)
        error_type = e.__class__.__name__
        if error_type == "AuthenticationError":
            return None, f"Invalid API key for {provider}"
        elif error_type == "RateLimitError":
            return None, f"Rate limit exceeded for {provider}"
        else:
            return None, f"Error occurred while generating response from {provider}"

class GPTIntegrationEndpoint(BaseAPIView):
    @allow_permission([ROLE.ADMIN, ROLE.MEMBER])
    def post(self, request, slug, project_id):
        OPENAI_API_KEY, GPT_ENGINE = get_configuration_value(
            [
                {
                    "key": "OPENAI_API_KEY",
                    "default": os.environ.get("OPENAI_API_KEY", None),
                },
                {
                    "key": "GPT_ENGINE",
                    "default": os.environ.get("GPT_ENGINE", "gpt-3.5-turbo"),
                },
            ]
        )
        api_key, model, provider = get_llm_config()

        # Get the configuration value
        # Check the keys
        if not OPENAI_API_KEY or not GPT_ENGINE:
        if not api_key or not model or not provider:
            return Response(
                {"error": "OpenAI API key and engine is required"},
                {"error": "LLM provider API key and model are required"},
                status=status.HTTP_400_BAD_REQUEST,
            )

        prompt = request.data.get("prompt", False)
        task = request.data.get("task", False)

        if not task:
            return Response(
                {"error": "Task is required"}, status=status.HTTP_400_BAD_REQUEST
            )

        final_text = task + "\n" + prompt

        client = OpenAI(api_key=OPENAI_API_KEY)

        response = client.chat.completions.create(
            model=GPT_ENGINE, messages=[{"role": "user", "content": final_text}]
        )
        text, error = get_llm_response(task, request.data.get("prompt", False), api_key, model, provider)
        if not text and error:
            return Response(
                {"error": "An internal error has occurred."},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

akash5100 marked this conversation as resolved.
Comment on lines +139 to +158
🛠️ Refactor suggestion: Improve error handling in endpoints. Both endpoints use the same generic error message. Consider providing more specific error messages based on the error type. Apply this improvement to both endpoint handlers:
        if not text and error:
+           # Map error messages to user-friendly responses
+           error_mapping = {
+               "Authentication failed": ("Service configuration error", status.HTTP_503_SERVICE_UNAVAILABLE),
+               "Service temporarily unavailable": ("Please try again later", status.HTTP_429_TOO_MANY_REQUESTS),
+           }
+           error_msg, error_status = error_mapping.get(
+               error,
+               ("An unexpected error occurred", status.HTTP_500_INTERNAL_SERVER_ERROR)
+           )
            return Response(
-               {"error": "An internal error has occurred."},
-               status=status.HTTP_500_INTERNAL_SERVER_ERROR,
+               {"error": error_msg},
+               status=error_status,
            )
Also applies to: 177-196
🧰 Tools: 🪛 Ruff (0.8.2) — 153-153: Line too long (105 > 88) (E501)
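One caveat on the suggestion above: the error_mapping keys ("Authentication failed", "Service temporarily unavailable") only match if get_llm_response is also changed to return those exact strings; as written it returns provider-specific messages such as "Invalid API key for {provider}", so every lookup would fall through to the generic 500. A minimal endpoint-side sketch under that assumption (ERROR_MAPPING and map_llm_error are illustrative names, not part of this PR):

from rest_framework import status

# Assumes get_llm_response returns these exact generic strings; its current
# provider-specific messages would need to be genericized for the lookup to hit.
ERROR_MAPPING = {
    "Authentication failed": ("Service configuration error", status.HTTP_503_SERVICE_UNAVAILABLE),
    "Service temporarily unavailable": ("Please try again later", status.HTTP_429_TOO_MANY_REQUESTS),
}

def map_llm_error(error: str):
    # Anything unrecognized falls back to a generic 500.
    return ERROR_MAPPING.get(
        error,
        ("An unexpected error occurred", status.HTTP_500_INTERNAL_SERVER_ERROR),
    )

Both endpoints could then unpack map_llm_error(error) into the message and status passed to Response.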
        workspace = Workspace.objects.get(slug=slug)
        project = Project.objects.get(pk=project_id)

        text = response.choices[0].message.content.strip()
        text_html = text.replace("\n", "<br/>")
        return Response(
            {
                "response": text,
                "response_html": text_html,
                "response_html": text.replace("\n", "<br/>"),
                "project_detail": ProjectLiteSerializer(project).data,
                "workspace_detail": WorkspaceLiteSerializer(workspace).data,
            },

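An aside on the provider registry introduced in this hunk: because each provider is a plain LLMProvider subclass registered in SUPPORTED_PROVIDERS, supporting another backend should only take a new subclass and one registry entry. A hypothetical sketch — MistralProvider, its model list, and MISTRAL_API_KEY are illustrative, not part of this PR:

# Hypothetical provider registration, relying on the LLMProvider base class
# and SUPPORTED_PROVIDERS dict defined earlier in this file.
class MistralProvider(LLMProvider):
    name = "Mistral"
    models = ["mistral-large-latest", "mistral-small-latest"]
    api_key_env = "MISTRAL_API_KEY"
    default_model = "mistral-small-latest"

SUPPORTED_PROVIDERS["mistral"] = MistralProvider

Note that, like the Gemini special case in get_llm_response, some backends may need a provider-prefixed model name (e.g. "mistral/<model>") before the litellm.completion call; that routing would need a matching tweak.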
@@ -76,47 +179,33 @@ def post(self, request, slug, project_id):
class WorkspaceGPTIntegrationEndpoint(BaseAPIView):
    @allow_permission(allowed_roles=[ROLE.ADMIN, ROLE.MEMBER], level="WORKSPACE")
    def post(self, request, slug):
        OPENAI_API_KEY, GPT_ENGINE = get_configuration_value(
            [
                {
                    "key": "OPENAI_API_KEY",
                    "default": os.environ.get("OPENAI_API_KEY", None),
                },
                {
                    "key": "GPT_ENGINE",
                    "default": os.environ.get("GPT_ENGINE", "gpt-3.5-turbo"),
                },
            ]
        )

        # Get the configuration value
        # Check the keys
        if not OPENAI_API_KEY or not GPT_ENGINE:
        api_key, model, provider = get_llm_config()

        if not api_key or not model or not provider:
            return Response(
                {"error": "OpenAI API key and engine is required"},
                {"error": "LLM provider API key and model are required"},
                status=status.HTTP_400_BAD_REQUEST,
            )

        prompt = request.data.get("prompt", False)
        task = request.data.get("task", False)

        if not task:
            return Response(
                {"error": "Task is required"}, status=status.HTTP_400_BAD_REQUEST
            )

        final_text = task + "\n" + prompt

        client = OpenAI(api_key=OPENAI_API_KEY)

        response = client.chat.completions.create(
            model=GPT_ENGINE, messages=[{"role": "user", "content": final_text}]
        )
        text, error = get_llm_response(task, request.data.get("prompt", False), api_key, model, provider)
        if not text and error:
            return Response(
                {"error": "An internal error has occurred."},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        text = response.choices[0].message.content.strip()
        text_html = text.replace("\n", "<br/>")
        return Response(
            {"response": text, "response_html": text_html}, status=status.HTTP_200_OK
            {
                "response": text,
                "response_html": text.replace("\n", "<br/>"),
            },
            status=status.HTTP_200_OK,
        )

Enhance error message security.
While the error handling is good, the error messages could potentially leak sensitive information about the provider configuration.
Apply this diff to make error messages more generic:
📝 Committable suggestion
🧰 Tools: 🪛 Ruff (0.8.2) — 111-111: Line too long (109 > 88) (E501)
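The committable suggestion itself is not reproduced above. As an illustration only (not the reviewer's actual patch), genericizing the messages returned from get_llm_response's exception handler could look like the helper below; the classify_llm_error name is hypothetical:

def classify_llm_error(exc: Exception) -> str:
    # Illustrative sketch: return generic messages so API responses do not
    # echo the provider name or configuration; full details still reach the
    # server logs via log_exception.
    error_type = exc.__class__.__name__
    if error_type == "AuthenticationError":
        return "Authentication failed"
    if error_type == "RateLimitError":
        return "Service temporarily unavailable"
    return "An unexpected error occurred"

These generic strings would also line up with the keys used in the earlier error_mapping suggestion.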