Commit
CHORE: added requirements.txt, cleaning up production server code
ProtoFaze committed Dec 10, 2024
1 parent 664b4b6 commit dec61be
Showing 2 changed files with 25 additions and 16 deletions.
36 changes: 20 additions & 16 deletions ollama_backend/ollama_backend.py
@@ -5,7 +5,6 @@
 import subprocess
 from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi import FastAPI, Request
-from typing import Iterator, Union

 app = FastAPI()

@@ -16,7 +15,7 @@
 mount = modal.Mount.from_local_dir(local_path='modelfiles', remote_path="modelfiles")

 def serve_ollama():
-    """Ensure Ollama server is running."""
+    '''Ensure Ollama server is running.'''
     subprocess.Popen(["ollama", "serve"])

 image = (modal.Image
@@ -43,24 +42,23 @@ def serve_ollama():
 )
 class Ollama:
     '''Ollama class for handling calls to the endpoint'''
-    # model: str = modal.parameter(init=True)
     BASE_URL: str = "http://localhost:11434"

     @modal.enter()
     def init_model(self):
-        """Start the Ollama server."""
+        '''Start the Ollama server.'''
         print("Starting server")
         serve_ollama()

     @modal.method()
     def warmup(self):
-        """Warmup the model."""
+        '''Warmup the model.'''
         ollama.generate(self.model)
         return {"status": "ok"}

     @modal.method()
     async def chat(self, model: str, messages: list, tools: list = [], stream=True, **kwargs):
-        """Handle chat interaction and stream the response."""
+        '''Handle chat interaction and stream the response.'''
         payload = {"model":model,"messages":messages,"tools":tools,"stream":stream, **kwargs}
         response = requests.post(self.BASE_URL+"/api/chat", json=payload, stream=stream)
         if stream:
@@ -74,7 +72,7 @@ async def chat(self, model: str, messages: list, tools: list = [], stream=True,

     @modal.method()
     def generate(self, model: str, prompt: str, stream=True, **kwargs):
-        """Generate a response from the given user prompt."""
+        '''Generate a response from the given user prompt.'''
         payload = {"model":model,"prompt":prompt,"stream":stream, **kwargs}
         response = requests.post(self.BASE_URL+"/api/generate", json=payload, stream=stream)
         if stream:
@@ -87,7 +85,7 @@ def generate(self, model: str, prompt: str, stream=True, **kwargs):

     @modal.method()
     def embed(self, model: str, input: str, **kwargs):
-        """Embed a given text."""
+        '''Embed a given text.'''
         payload = {"model":model,"input":input, **kwargs}
         response = requests.post(self.BASE_URL+"/api/embed", json=payload)
         embeddings = json.loads(response.content).get("embeddings")
@@ -98,7 +96,7 @@ def embed(self, model: str, input: str, **kwargs):

     @modal.method()
     def embeddings(self, model: str, prompt: str, **kwargs):
-        """Embed a given text."""
+        '''Embed a given text.'''
         payload = {"model":model,"prompt":prompt, **kwargs}
         print("payload \\|/")
         print(payload)
@@ -117,14 +115,14 @@ def list(self):

     @modal.method()
     def list_running(self):
-        """List all running models."""
+        '''List all running models.'''
         models = requests.get(self.BASE_URL+"/api/ps")
         print(models.content)
         return json.loads(models.content) if models else b'[]'

     @modal.method()
     def pull(self, model: str, stream = True, **kwargs):
-        """Pull a model from the model store."""
+        '''Pull a model from the model store.'''
         payload = {"model":model, "stream":stream, **kwargs}
         response = requests.post(self.BASE_URL+"/api/pull", json=payload, stream=stream)
         if stream:
@@ -135,7 +133,7 @@ def pull(self, model: str, stream = True, **kwargs):

     @modal.method()
     def create(self, model:str, modelfile: str = None, path: str = None, stream = True, **kwargs):
-        """Create a new model."""
+        '''Create a new model.'''
         payload = {"model":model, "modelfile":modelfile, "path":path, "stream":stream, **kwargs}
         response = requests.post(self.BASE_URL+"/api/create", json=payload, stream=stream)
         if stream:
@@ -186,7 +184,10 @@ async def embed(request: Request):

 @web_app.post("/api/embeddings")
 async def embed(request: Request):
-    '''DEPRECATED--Get vector embeddings for given text, use /api/embed instead, only maintained for backwards compatibility with llamaParse'''
+    '''
+    DEPRECATED--Get vector embeddings for given text, use /api/embed instead,
+    only maintained for backwards compatibility with llamaParse
+    '''
     print("vectorizing text")
     params = await request.json()
     res = Ollama().embeddings.remote(**params)
@@ -225,7 +226,10 @@ async def pull(request: Request):

 @web_app.post("/api/create")
 async def create(request: Request):
-    '''Create a new model'''
+    '''
+    Create a new model, based on parameters within a json request, such as
+    model name, path to model files, and whether to stream the progress response
+    '''
     print("Creating model")
     ollama = Ollama()
     params = await request.json()
@@ -272,7 +276,7 @@ def init_and_setup():
     print("\ntest for streaming chat completion")
     res = ollama_client.chat(messages=messages, model=llm, stream=True)
     for chunk in res:
-        print(chunk.get('message').get("content"), end="", flush=True)
+        print(chunk.get('message').get("content"), end='', flush=True)


     print("\ntest for non streaming normal completion")
@@ -281,7 +285,7 @@ def init_and_setup():
     print("\ntest for streaming normal completion")
     res = ollama_client.generate(model=llm, prompt=prompt, stream=True)
     for chunk in res:
-        print(chunk.get('response'), end="", flush=True)
+        print(chunk.get('response'), end='', flush=True)


     print("\ntest for current embedding function")
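The web endpoints shown above all follow the same pattern: the FastAPI route reads a JSON body, forwards it to the matching Ollama class method, and either returns JSON or streams newline-delimited JSON chunks back to the caller. Below is a minimal client sketch against the deployed app; it assumes a /api/chat route that mirrors the chat method above, and the base URL is a placeholder for whatever modal deploy prints for your workspace.

import json
import requests

# Placeholder URL -- substitute the URL of your own Modal deployment.
BASE_URL = "https://example--ollama-backend.modal.run"

def stream_chat(messages: list, model: str) -> None:
    '''POST to /api/chat and print streamed content chunks as they arrive.'''
    payload = {"model": model, "messages": messages, "stream": True}
    with requests.post(BASE_URL + "/api/chat", json=payload, stream=True) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if line:  # skip keep-alive blank lines
                chunk = json.loads(line)
                print(chunk.get("message", {}).get("content", ""), end="", flush=True)

stream_chat([{"role": "user", "content": "Hello!"}], model="<model-name>")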
5 changes: 5 additions & 0 deletions ollama_backend/requirements.txt
@@ -0,0 +1,5 @@
+fastapi[standard]
+requests
+pydantic
+ollama
+modal
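The new requirements file leaves versions unpinned, so the latest release of each package is installed. For reference, a short sketch of the client-side usage that the local test entrypoint above relies on, using the ollama package listed here; the host URL and model name are placeholders, and the dict-style .get access mirrors the test code in this commit rather than any particular ollama-python release.

import ollama

# Placeholder host -- point this at the deployed backend or a local Ollama server.
client = ollama.Client(host="https://example--ollama-backend.modal.run")

# Non-streaming completion, mirroring the init_and_setup tests above.
result = client.generate(model="<model-name>", prompt="Say hello", stream=False)
print(result.get("response"))

# Streaming chat completion.
messages = [{"role": "user", "content": "Hello!"}]
for chunk in client.chat(model="<model-name>", messages=messages, stream=True):
    print(chunk.get("message").get("content"), end="", flush=True)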
