server: tests: assert embeddings are actually computed, make the embeddings endpoint configurable.

Add logs to investigate why the CI server test job is not starting
phymbert committed Feb 23, 2024
1 parent cba6d4e commit 1bd07e5
Showing 2 changed files with 28 additions and 5 deletions.
1 change: 1 addition & 0 deletions examples/server/tests/features/server.feature
@@ -8,6 +8,7 @@ Feature: llama.cpp server
And 42 as server seed
And 32 KV cache size
And 1 slots
And embeddings extraction
And 32 server max tokens to predict
Then the server is starting
Then the server is healthy
32 changes: 27 additions & 5 deletions examples/server/tests/features/steps/steps.py
@@ -4,6 +4,7 @@
import re
import socket
import subprocess
import time
from contextlib import closing
from re import RegexFlag

@@ -21,13 +22,14 @@ def step_server_config(context, server_fqdn, server_port):

context.base_url = f'http://{context.server_fqdn}:{context.server_port}'

context.server_continuous_batching = False
context.model_alias = None
context.n_ctx = None
context.n_predict = None
context.n_server_predict = None
context.n_slots = None
context.server_api_key = None
context.server_continuous_batching = False
context.server_embeddings = False
context.server_seed = None
context.user_api_key = None

@@ -70,15 +72,26 @@ def step_server_n_predict(context, n_predict):
def step_server_continuous_batching(context):
context.server_continuous_batching = True

@step(u'embeddings extraction')
def step_server_embeddings(context):
context.server_embeddings = True


@step(u"the server is starting")
def step_start_server(context):
start_server_background(context)
attempts = 0
while True:
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
result = sock.connect_ex((context.server_fqdn, context.server_port))
if result == 0:
print("server started!")
return
attempts += 1
if attempts > 20:
assert False, "server not started"
print("waiting for server to start...")
time.sleep(0.1)
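
The loop added here gives the freshly spawned server a bounded window to start accepting connections: it polls the TCP port, and with 20 attempts and a 0.1 s sleep between them the step fails after roughly two seconds. A standalone sketch of the same readiness check (a hypothetical helper, not part of this commit) could look like:

```python
import socket
import time
from contextlib import closing

def wait_for_port(host: str, port: int, attempts: int = 20, delay: float = 0.1) -> None:
    # Hypothetical helper mirroring the loop above: poll the TCP port until it
    # accepts a connection or the attempt budget is exhausted.
    for _ in range(attempts):
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            if sock.connect_ex((host, port)) == 0:
                print("server started!")
                return
        print("waiting for server to start...")
        time.sleep(delay)
    assert False, "server not started"
```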


@step(u"the server is {expecting_status}")
@@ -301,6 +314,11 @@ def step_compute_embedding(context):
@step(u'embeddings are generated')
def step_compute_embeddings(context):
assert len(context.embeddings) > 0
embeddings_computed = False
for emb in context.embeddings:
if emb != 0:
embeddings_computed = True
assert embeddings_computed, f"Embeddings: {context.embeddings}"
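
The loop above marks the embeddings as computed as soon as any component is non-zero. For reference, a more compact equivalent of the same assertion (a sketch only, not what the commit adds) would be:

```python
# Equivalent check (sketch): the embeddings vector must be non-empty and must
# contain at least one non-zero component.
assert len(context.embeddings) > 0
assert any(emb != 0 for emb in context.embeddings), f"Embeddings: {context.embeddings}"
```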


@step(u'an OAI compatible embeddings computation request for')
@@ -436,7 +454,8 @@ async def oai_chat_completions(user_prompt,
json=payload,
headers=headers) as response:
if enable_streaming:
print("payload", payload)
# FIXME: does not work; the server is generating only one token
print("DEBUG payload", payload)
assert response.status == 200
assert response.headers['Access-Control-Allow-Origin'] == origin
assert response.headers['Content-Type'] == "text/event-stream"
@@ -453,7 +472,7 @@ async def oai_chat_completions(user_prompt,
if 'content' in delta:
completion_response['content'] += delta['content']
completion_response['timings']['predicted_n'] += 1
print(f"XXXXXXXXXXXXXXXXXcompletion_response: {completion_response}")
print(f"DEBUG completion_response: {completion_response}")
else:
if expect_api_error is None or not expect_api_error:
assert response.status == 200
@@ -500,7 +519,7 @@ async def oai_chat_completions(user_prompt,
'predicted_n': chat_completion.usage.completion_tokens
}
}
print("OAI response formatted to llama.cpp", completion_response)
print("OAI response formatted to llama.cpp:", completion_response)
return completion_response


@@ -567,7 +586,7 @@ async def wait_for_health_status(context,
# Sometimes health requests are triggered after completions are predicted
if expected_http_status_code == 503:
if len(context.completions) == 0:
print("\x1b[5;37;43mWARNING: forcing concurrents completions tasks,"
print("\x1b[33;42mWARNING: forcing concurrents completions tasks,"
" busy health check missed\x1b[0m")
n_completions = await gather_concurrent_completions_tasks(context)
if n_completions > 0:
@@ -604,6 +623,8 @@ def start_server_background(context):
]
if context.server_continuous_batching:
server_args.append('--cont-batching')
if context.server_embeddings:
server_args.append('--embedding')
if context.model_alias is not None:
server_args.extend(['--alias', context.model_alias])
if context.server_seed is not None:
@@ -620,3 +641,4 @@ def start_server_background(context):
context.server_process = subprocess.Popen(
[str(arg) for arg in [context.server_path, *server_args]],
close_fds=True)
print(f"server pid={context.server_process.pid}")
