diff --git a/ollama_backend/ollama_backend.py b/ollama_backend/ollama_backend.py
index 9875126..d4658dd 100644
--- a/ollama_backend/ollama_backend.py
+++ b/ollama_backend/ollama_backend.py
@@ -5,7 +5,6 @@ import subprocess
 
 from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi import FastAPI, Request
-from typing import Iterator, Union
 
 app = FastAPI()
 
@@ -16,7 +15,7 @@ mount = modal.Mount.from_local_dir(local_path='modelfiles',
                                    remote_path="modelfiles")
 
 def serve_ollama():
-    """Ensure Ollama server is running."""
+    '''Ensure Ollama server is running.'''
     subprocess.Popen(["ollama", "serve"])
 
 image = (modal.Image
@@ -43,24 +42,23 @@ def serve_ollama():
 )
 class Ollama:
     '''Ollama class for handling calls to the endpoint'''
-    # model: str = modal.parameter(init=True)
     BASE_URL: str = "http://localhost:11434"
 
     @modal.enter()
     def init_model(self):
-        """Start the Ollama server."""
+        '''Start the Ollama server.'''
        print("Starting server")
         serve_ollama()
 
     @modal.method()
     def warmup(self):
-        """Warmup the model."""
+        '''Warmup the model.'''
         ollama.generate(self.model)
         return {"status": "ok"}
 
     @modal.method()
     async def chat(self, model: str, messages: list, tools: list = [], stream=True, **kwargs):
-        """Handle chat interaction and stream the response."""
+        '''Handle chat interaction and stream the response.'''
         payload = {"model":model,"messages":messages,"tools":tools,"stream":stream, **kwargs}
         response = requests.post(self.BASE_URL+"/api/chat", json=payload, stream=stream)
         if stream:
@@ -74,7 +72,7 @@ async def chat(self, model: str, messages: list, tools: list = [], stream=True,
 
     @modal.method()
     def generate(self, model: str, prompt: str, stream=True, **kwargs):
-        """Generate a response from the given user prompt."""
+        '''Generate a response from the given user prompt.'''
         payload = {"model":model,"prompt":prompt,"stream":stream, **kwargs}
         response = requests.post(self.BASE_URL+"/api/generate", json=payload, stream=stream)
         if stream:
@@ -87,7 +85,7 @@ def generate(self, model: str, prompt: str, stream=True, **kwargs):
 
     @modal.method()
     def embed(self, model: str, input: str, **kwargs):
-        """Embed a given text."""
+        '''Embed a given text.'''
         payload = {"model":model,"input":input, **kwargs}
         response = requests.post(self.BASE_URL+"/api/embed", json=payload)
         embeddings = json.loads(response.content).get("embeddings")
@@ -98,7 +96,7 @@ def embed(self, model: str, input: str, **kwargs):
 
     @modal.method()
     def embeddings(self, model: str, prompt: str, **kwargs):
-        """Embed a given text."""
+        '''Embed a given text.'''
         payload = {"model":model,"prompt":prompt, **kwargs}
         print("payload \\|/")
         print(payload)
@@ -117,14 +115,14 @@ def list(self):
 
     @modal.method()
     def list_running(self):
-        """List all running models."""
+        '''List all running models.'''
         models = requests.get(self.BASE_URL+"/api/ps")
         print(models.content)
         return json.loads(models.content) if models else b'[]'
 
     @modal.method()
     def pull(self, model: str, stream = True, **kwargs):
-        """Pull a model from the model store."""
+        '''Pull a model from the model store.'''
         payload = {"model":model, "stream":stream, **kwargs}
         response = requests.post(self.BASE_URL+"/api/pull", json=payload, stream=stream)
         if stream:
@@ -135,7 +133,7 @@ def pull(self, model: str, stream = True, **kwargs):
 
     @modal.method()
     def create(self, model:str, modelfile: str = None, path: str = None, stream = True, **kwargs):
-        """Create a new model."""
+        '''Create a new model.'''
         payload = {"model":model, "modelfile":modelfile, "path":path, "stream":stream, **kwargs}
         response = requests.post(self.BASE_URL+"/api/create", json=payload, stream=stream)
         if stream:
@@ -186,7 +184,10 @@ async def embed(request: Request):
 
 @web_app.post("/api/embeddings")
 async def embed(request: Request):
-    '''DEPRECATED--Get vector embeddings for given text, use /api/embed instead, only maintained for backwards compatibility with llamaParse'''
+    '''
+    DEPRECATED--Get vector embeddings for given text, use /api/embed instead,
+    only maintained for backwards compatibility with llamaParse
+    '''
     print("vectorizing text")
     params = await request.json()
     res = Ollama().embeddings.remote(**params)
@@ -225,7 +226,10 @@ async def pull(request: Request):
 
 @web_app.post("/api/create")
 async def create(request: Request):
-    '''Create a new model'''
+    '''
+    Create a new model, based on parameters within a json request, such as
+    model name, path to model files, and whether to stream the progress response
+    '''
     print("Creating model")
     ollama = Ollama()
     params = await request.json()
@@ -272,7 +276,7 @@ def init_and_setup():
     print("\ntest for streaming chat completion")
     res = ollama_client.chat(messages=messages, model=llm, stream=True)
     for chunk in res:
-        print(chunk.get('message').get("content"), end="", flush=True)
+        print(chunk.get('message').get("content"), end='', flush=True)
 
     print("\ntest for non streaming normal completion")
 
@@ -281,7 +285,7 @@ def init_and_setup():
     print("\ntest for streaming normal completion")
     res = ollama_client.generate(model=llm, prompt=prompt, stream=True)
     for chunk in res:
-        print(chunk.get('response'), end="", flush=True)
+        print(chunk.get('response'), end='', flush=True)
 
     print("\ntest for current embedding function")
 
diff --git a/ollama_backend/requirements.txt b/ollama_backend/requirements.txt
new file mode 100644
index 0000000..373ee11
--- /dev/null
+++ b/ollama_backend/requirements.txt
@@ -0,0 +1,5 @@
+fastapi[standard]
+requests
+pydantic
+ollama
+modal
\ No newline at end of file