Skip to content

Commit b9f17b1

Browse files
authored
Add LocalAI provider (#42)
* Add LocalAI provider
1 parent b0d6f8e commit b9f17b1

File tree

3 files changed

+53
-2
lines changed

3 files changed

+53
-2
lines changed

src/shelloracle/provider.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@ def get_provider(name: str) -> type[Provider]:
3535
:param name: the provider name
3636
:return: the requested provider
3737
"""
38-
from .providers import Ollama, OpenAI
38+
from .providers import Ollama, OpenAI, LocalAI
3939
providers = {
4040
Ollama.name: Ollama,
41-
OpenAI.name: OpenAI
41+
OpenAI.name: OpenAI,
42+
LocalAI.name: LocalAI
4243
}
4344
return providers[name]
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .ollama import Ollama
22
from .openai import OpenAI
3+
from .localai import LocalAI
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from collections.abc import AsyncIterator
2+
3+
from openai import APIError
4+
from openai import AsyncOpenAI as OpenAIClient
5+
6+
from ..config import Setting
7+
from ..provider import Provider, ProviderError
8+
9+
10+
class LocalAI(Provider):
11+
name = "LocalAI"
12+
13+
host = Setting(default="localhost")
14+
port = Setting(default=8080)
15+
model = Setting(default="mistral-openorca")
16+
system_prompt = Setting(
17+
default=(
18+
"Based on the following user description, generate a corresponding Bash command. Focus solely "
19+
"on interpreting the requirements and translating them into a single, executable Bash command. "
20+
"Ensure accuracy and relevance to the user's description. The output should be a valid Bash "
21+
"command that directly aligns with the user's intent, ready for execution in a command-line "
22+
"environment. Output nothing except for the command. No code block, no English explanation, "
23+
"no start/end tags."
24+
)
25+
)
26+
27+
@property
28+
def endpoint(self) -> str:
29+
return f"http://{self.host}:{self.port}"
30+
31+
def __init__(self):
32+
# Use a placeholder API key so the client will work
33+
self.client = OpenAIClient(api_key="sk-xxx", base_url=self.endpoint)
34+
35+
async def generate(self, prompt: str) -> AsyncIterator[str]:
36+
try:
37+
stream = await self.client.chat.completions.create(
38+
model=self.model,
39+
messages=[
40+
{"role": "system", "content": self.system_prompt},
41+
{"role": "user", "content": prompt}
42+
],
43+
stream=True,
44+
)
45+
async for chunk in stream:
46+
if chunk.choices[0].delta.content is not None:
47+
yield chunk.choices[0].delta.content
48+
except APIError as e:
49+
raise ProviderError(f"Something went wrong while querying LocalAI: {e}") from e

0 commit comments

Comments
 (0)