Skip to content

Commit

Permalink
refactor: streamline entity completion args, remove config
Browse files Browse the repository at this point in the history
  • Loading branch information
lxobr committed Feb 20, 2025
1 parent 85aaee9 commit 8ac63a8
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 66 deletions.
92 changes: 38 additions & 54 deletions cognee/tasks/entity_completion/entity_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,34 @@

from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
from cognee.tasks.entity_completion.entity_completion_config import EntityCompletionConfig
from cognee.tasks.entity_completion.entity_extractors.entity_extractor_adapters import (
EntityExtractorAdapter,
from cognee.tasks.entity_completion.entity_extractors.base_entity_extractor import (
BaseEntityExtractor,
)
from cognee.tasks.entity_completion.context_providers.context_provider_adapters import (
ContextProviderAdapter,
from cognee.tasks.entity_completion.context_providers.base_context_provider import (
BaseContextProvider,
)

logger = logging.getLogger("entity_completion")

entity_completion_config = EntityCompletionConfig()
# Default prompt template paths
DEFAULT_SYSTEM_PROMPT_TEMPLATE = "answer_simple_question.txt"
DEFAULT_USER_PROMPT_TEMPLATE = "context_for_question.txt"


async def get_llm_response(query: str, context: str) -> str:
async def get_llm_response(
query: str,
context: str,
system_prompt_template: str = None,
user_prompt_template: str = None,
) -> str:
"""Generate LLM response based on query and context."""
try:
args = {
"question": query,
"context": context,
}
user_prompt = render_prompt(entity_completion_config.user_prompt_template, args)
system_prompt = read_query_prompt(entity_completion_config.system_prompt_template)
user_prompt = render_prompt(user_prompt_template or DEFAULT_USER_PROMPT_TEMPLATE, args)
system_prompt = read_query_prompt(system_prompt_template or DEFAULT_SYSTEM_PROMPT_TEMPLATE)

llm_client = get_llm_client()
return await llm_client.acreate_structured_output(
Expand All @@ -37,68 +43,35 @@ async def get_llm_response(query: str, context: str) -> str:
raise


def _get_entity_extractor(
extractor: Union[str, EntityExtractorAdapter, None],
) -> EntityExtractorAdapter:
"""Get entity extractor adapter from string or enum, using config default if None."""
if extractor is None:
extractor = entity_completion_config.entity_extractor

if isinstance(extractor, str):
try:
return EntityExtractorAdapter(extractor)
except ValueError:
raise ValueError(f"Unsupported entity extractor: {extractor}")
return extractor


def _get_context_provider(
getter: Union[str, ContextProviderAdapter, None],
) -> ContextProviderAdapter:
"""Get context provider adapter from string or enum, using config default if None."""
if getter is None:
getter = entity_completion_config.context_getter

if isinstance(getter, str):
try:
return ContextProviderAdapter(getter)
except ValueError:
raise ValueError(f"Unsupported context provider: {getter}")
return getter


async def entity_completion(
query: str,
extractor_type: Union[str, EntityExtractorAdapter] = None,
context_provider_type: Union[str, ContextProviderAdapter] = None,
extractor: BaseEntityExtractor,
context_provider: BaseContextProvider,
system_prompt_template: str = None,
user_prompt_template: str = None,
) -> List[str]:
"""Execute entity-based completion using configurable components."""
"""Execute entity-based completion using provided components."""
if not query or not isinstance(query, str):
logger.error("Invalid query type or empty query")
return ["Invalid query input"]

try:
extractor_adapter = _get_entity_extractor(extractor_type)
context_provider_adapter = _get_context_provider(context_provider_type)

logger.info(f"Processing query: {query[:100]}")

extractor_instance = extractor_adapter.adapter_class()
entities = await extractor_instance.extract_entities(query)
logger.debug(f"Extracted entities: {[e.name for e in entities]}")

entities = await extractor.extract_entities(query)
if not entities:
logger.info("No entities extracted")
return ["No entities found"]

context_provider_instance = context_provider_adapter.adapter_class()
context = await context_provider_instance.get_context(entities, query)

context = await context_provider.get_context(entities, query)
if not context:
logger.info("No context retrieved")
return ["No context found"]

return [await get_llm_response(query, context)]
response = await get_llm_response(
query, context, system_prompt_template, user_prompt_template
)
return [response]

except Exception as e:
logger.error(f"Entity completion failed: {str(e)}")
Expand All @@ -108,12 +81,23 @@ async def entity_completion(
if __name__ == "__main__":
# For testing purposes, will be removed by the end of the sprint
import asyncio
import logging
from cognee.tasks.entity_completion.entity_extractors.entity_extractor_adapters import (
EntityExtractorAdapter,
)
from cognee.tasks.entity_completion.context_providers.context_provider_adapters import (
ContextProviderAdapter,
)

logging.basicConfig(level=logging.INFO)

async def run_entity_completion():
# Uses config defaults
result = await entity_completion("Tell me about Einstein")
result = await entity_completion(
"Tell me about Einstein",
EntityExtractorAdapter.DUMMY.adapter_class(),
ContextProviderAdapter.DUMMY.adapter_class(),
)
print(f"Query Response: {result[0]}")

asyncio.run(run_entity_completion())
12 changes: 0 additions & 12 deletions cognee/tasks/entity_completion/entity_completion_config.py

This file was deleted.

0 comments on commit 8ac63a8

Please sign in to comment.