diff --git a/pyproject.toml b/pyproject.toml index 68273ae..ae53edf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -172,6 +172,7 @@ exclude_lines = [ [[tool.mypy.overrides]] module = [ "simplemma.*", - "litellm.*" + "litellm.*", + "distutils.*", ] ignore_missing_imports = true \ No newline at end of file diff --git a/src/banks/prompt.py b/src/banks/prompt.py index 5aff48e..3930ac1 100644 --- a/src/banks/prompt.py +++ b/src/banks/prompt.py @@ -7,22 +7,24 @@ from banks.errors import AsyncError -class Prompt: +class BasePrompt: def __init__(self, text: str) -> None: self._template = env.from_string(text) @classmethod - def from_template(cls, name: str) -> "Prompt": + def from_template(cls, name: str) -> "BasePrompt": p = cls("") p._template = env.get_template(name) return p + +class Prompt(BasePrompt): def text(self, data: Optional[dict] = None) -> str: data = data or {} return self._template.render(data) -class AsyncPrompt(Prompt): +class AsyncPrompt(BasePrompt): def __init__(self, text: str) -> None: super().__init__(text) @@ -32,4 +34,5 @@ def __init__(self, text: str) -> None: async def text(self, data: Optional[dict] = None) -> str: data = data or {} - return await self._template.render_async(data) + result: str = await self._template.render_async(data) + return result