Skip to content

Commit

Permalink
change to only serve 1 model per instance
Browse files Browse the repository at this point in the history
  • Loading branch information
samos123 committed Oct 28, 2023
1 parent 8def3ec commit c6e5a04
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 49 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ docker run -e MODEL=multi-qa-MiniLM-L6-cos-v1 -p 8080:8080 -d \
ghcr.io/substratusai/sentence-transformers-api
```

In addition to preloading models, you can also specify other models at runtime. A single
API endpoint can serve multiple models in parallel.
Note that STAPI serves only the model it was preloaded with. To serve an
additional model, start a separate STAPI instance for it. The `model`
parameter in the request body is ignored.


## Integrations
It's easy to utilize the embedding server with various other tools because
Expand Down
47 changes: 33 additions & 14 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,30 @@
from sentence_transformers import SentenceTransformer

models: Dict[str, SentenceTransformer] = {}
default_model_name = os.getenv("MODEL", 'all-MiniLM-L6-v2')
model_name = os.getenv("MODEL", "all-MiniLM-L6-v2")


class EmbeddingRequest(BaseModel):
    """Request body for ``POST /v1/embeddings`` (OpenAI-compatible shape).

    Attributes:
        input: Text to embed — a single string or a list of strings.
        model: Accepted for OpenAI API compatibility only; each STAPI
            instance serves exactly one model, so this value is ignored
            by the endpoint.
    """

    input: Union[str, List[str]] = Field(
        examples=["substratus.ai provides the best LLM tools"]
    )
    # Defaults to the instance's configured model so clients may omit it.
    model: str = Field(
        examples=[model_name],
        default=model_name,
    )


class EmbeddingData(BaseModel):
    """One embedding result in the response ``data`` array (OpenAI-compatible)."""

    # The embedding vector for a single input string.
    embedding: List[float]
    # Position of the corresponding string in the request's input list.
    index: int
    # Object type tag; the endpoint sets this to "embedding".
    object: str


class Usage(BaseModel):
    """Token-usage accounting block of the response (OpenAI-compatible)."""

    # NOTE(review): the endpoint fills both fields from len(vectors), i.e. the
    # embedding dimensionality, not an actual token count — values are approximate.
    prompt_tokens: int
    total_tokens: int


class EmbeddingResponse(BaseModel):
data: List[EmbeddingData]
model: str
Expand All @@ -31,39 +40,49 @@ class EmbeddingResponse(BaseModel):

@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler: preload the single configured model.

    Loading a SentenceTransformer can be slow (it may download weights),
    so it happens once at startup instead of per request. The app serves
    requests only after ``yield`` is reached.
    """
    models[model_name] = SentenceTransformer(model_name)
    yield


app = FastAPI(lifespan=lifespan)


@app.post("/v1/embeddings")
async def embedding(item: EmbeddingRequest) -> EmbeddingResponse:
    """Compute embeddings for a string or a list of strings.

    Args:
        item: Request body; ``item.model`` is ignored — this instance
            always uses the preloaded ``model_name`` model.

    Returns:
        An OpenAI-compatible EmbeddingResponse with one EmbeddingData
        entry per input string.

    Raises:
        HTTPException: 400 if the input is neither a string nor a list
            of strings.
    """
    # Single-model instance: always serve the preloaded model, never item.model.
    model: SentenceTransformer = models[model_name]
    if isinstance(item.input, str):
        vectors = model.encode(item.input)
        # NOTE(review): len(vectors) is the embedding dimensionality, not a
        # token count — the reported usage is approximate. TODO confirm intent.
        tokens = len(vectors)
        return EmbeddingResponse(
            data=[EmbeddingData(embedding=vectors, index=0, object="embedding")],
            model=model_name,
            usage=Usage(prompt_tokens=tokens, total_tokens=tokens),
            object="list",
        )
    if isinstance(item.input, list):
        embeddings = []
        tokens = 0
        for index, text_input in enumerate(item.input):
            # Reject mixed/non-string lists explicitly rather than failing
            # inside the encoder with an opaque error.
            if not isinstance(text_input, str):
                raise HTTPException(
                    status_code=400,
                    detail="input needs to be an array of strings or a string",
                )
            vectors = model.encode(text_input)
            tokens += len(vectors)
            embeddings.append(
                EmbeddingData(embedding=vectors, index=index, object="embedding")
            )
        return EmbeddingResponse(
            data=embeddings,
            model=model_name,
            usage=Usage(prompt_tokens=tokens, total_tokens=tokens),
            object="list",
        )
    # Pydantic already constrains the type, but guard defensively anyway.
    raise HTTPException(
        status_code=400, detail="input needs to be an array of strings or a string"
    )


@app.get("/")
@app.get("/healthz")
Expand Down
72 changes: 39 additions & 33 deletions test_main.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,47 @@
from fastapi.testclient import TestClient
from main import app, default_model_name

client = TestClient(app)
from main import app, model_name


def test_read_healthz():
    """Health endpoint responds 200 once the app has started."""
    # Using TestClient as a context manager runs the lifespan handler,
    # so the model is preloaded before any request is served.
    with TestClient(app) as client:
        response = client.get("/healthz")
        assert response.status_code == 200


def test_embedding_str():
    """A single-string input yields one embedding vector of floats."""
    # Context-manager TestClient triggers lifespan, preloading the model.
    with TestClient(app) as client:
        embedding_request = {
            "input": "substratus.ai has some great LLM OSS projects for K8s",
            "model": model_name,
        }
        response = client.post("/v1/embeddings", json=embedding_request)
        assert response.status_code == 200
        embedding_response = response.json()
        # Response shape: data is a list whose first entry holds a float vector.
        assert isinstance(embedding_response["data"], list)
        assert isinstance(embedding_response["data"][0]["embedding"], list)
        assert isinstance(embedding_response["data"][0]["embedding"][0], float)


def test_embedding_list_str():
    """A list input yields one distinct embedding per input string."""
    # Context-manager TestClient triggers lifespan, preloading the model.
    with TestClient(app) as client:
        embedding_request = {
            "input": [
                "substratus.ai has some great LLM OSS projects for K8s",
                "2nd string",
            ],
            "model": model_name,
        }
        response = client.post("/v1/embeddings", json=embedding_request)
        assert response.status_code == 200
        embedding_response = response.json()
        # Both entries must be float vectors.
        assert isinstance(embedding_response["data"], list)
        assert isinstance(embedding_response["data"][0]["embedding"], list)
        assert isinstance(embedding_response["data"][0]["embedding"][0], float)

        assert isinstance(embedding_response["data"], list)
        assert isinstance(embedding_response["data"][1]["embedding"], list)
        assert isinstance(embedding_response["data"][1]["embedding"][0], float)

        # Different inputs should not map to the same embedding.
        embedding_1 = embedding_response["data"][0]["embedding"]
        embedding_2 = embedding_response["data"][1]["embedding"]
        assert embedding_1 != embedding_2

0 comments on commit c6e5a04

Please sign in to comment.