diff --git a/README.md b/README.md index 50002ed..0d97d82 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,7 @@ ![Build & Tests](https://github.com/France-Travail/benchmark_llm_serving/actions/workflows/build_and_tests.yaml/badge.svg) ![Wheel setup](https://github.com/France-Travail/benchmark_llm_serving/actions/workflows/wheel.yaml/badge.svg) -benchmark_llm_serving is a script aimed at benchmarking the serving API of LLMs. For now, it is focused on LLMs served via [vllm](https://github.com/vllm-project/vllm) and more specifically via [happy-vllm](https://github.com/France-Travail/happy_vllm) which is an API layer on vLLM adding new endpoints and permitting a configuration via environment variables. - +benchmark_llm_serving is a script aimed at benchmarking the serving API of LLMs. For now, two backends are supported: `mistral` and [vLLM](https://github.com/vllm-project/vllm) (via [happy-vllm](https://github.com/France-Travail/happy_vllm), an API layer on top of vLLM that adds new endpoints and allows configuration via environment variables). ## Installation It is advised to clone the repository in order to get the datasets used for the benchmarks (you can find them in `src/benchmark_llm_serving/datasets`) and build it from source: diff --git a/src/benchmark_llm_serving/backends.py b/src/benchmark_llm_serving/backends.py index 560fede..0cc344b 100644 --- a/src/benchmark_llm_serving/backends.py +++ b/src/benchmark_llm_serving/backends.py @@ -3,149 +3,197 @@ from benchmark_llm_serving.io_classes import QueryOutput, QueryInput -IMPLEMENTED_BACKENDS = "'happy_vllm', 'mistral'" +class BackEnd(): + TEMPERATURE = 0 + REPETITION_PENALTY = 1.2 -def get_payload(query_input: QueryInput, args: argparse.Namespace) -> dict: - """Gets the payload to give to the model + def __init__(self, backend_name: str, chunk_prefix: str = "data: ", last_chunk: str = "[DONE]"): + self.backend_name = backend_name + self.chunk_prefix = chunk_prefix + self.last_chunk = last_chunk - Args: - query_input (QueryInput) : The query input to use - args (argparse.Namespace) : The cli args + def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict: + """Gets the payload to give to the model - Returns: - dict : The payload - """ - temperature = 0 - repetition_penalty = 1.2 - if args.backend == "happy_vllm": + Args: + query_input (QueryInput) : The query input to use + args (argparse.Namespace) : The cli args + + Returns: + dict : The payload + """ + raise NotImplementedError("The subclass should implement this method") + + def get_newly_generated_text(self, json_chunk: dict) -> str: + """Gets the newly generated text + + Args: + json_chunk (dict) : The chunk containing the generated text + + Returns: + str : The newly generated text + """ + raise NotImplementedError("The subclass should implement this method") + + def test_chunk_validity(self, chunk: str) -> bool: + """Tests if the chunk is valid or should not be considered.
+ + Args: + chunk (str) : The chunk to consider + + Returns: + bool : Whether the chunk is valid or not + """ + return True + + def get_completions_headers(self) -> dict: + """Gets the headers (depending on the backend) to use for the request + + Returns: + dict: The headers + + """ + return {} + + def remove_response_prefix(self, chunk: str) -> str: + """Removes the prefix in the response of a model + + Args: + chunk (str) : The chunk received + + Returns: + str : The string without the prefix + """ + return chunk.removeprefix(self.chunk_prefix) + + def check_end_of_stream(self, chunk: str) -> bool: + """Checks whether this is the last chunk of the stream + + Args: + chunk (str) : The chunk to test + + Returns: + bool : Whether it is the last chunk of the stream + """ + return chunk == self.last_chunk + + def add_prompt_length(self, json_chunk: dict, output: QueryOutput) -> None: + """Adds the prompt length to the QueryOutput if the key "usage" is in the chunk + + Args: + json_chunk (dict) : The chunk containing the prompt length + output (QueryOutput) : The output + """ + if "usage" in json_chunk: + if json_chunk['usage'] is not None: + output.prompt_length = json_chunk['usage']['prompt_tokens'] + + + +class BackendHappyVllm(BackEnd): + + def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict: + """Gets the payload to give to the model + + Args: + query_input (QueryInput) : The query input to use + args (argparse.Namespace) : The cli args + + Returns: + dict : The payload + """ return {"prompt": query_input.prompt, "model": args.model, "max_tokens": args.output_length, "min_tokens": args.output_length, - "temperature": temperature, - "repetition_penalty": repetition_penalty, + "temperature": self.TEMPERATURE, + "repetition_penalty": self.REPETITION_PENALTY, "stream": True, "stream_options": {"include_usage": True} } - elif args.backend == "mistral": + + def get_newly_generated_text(self, json_chunk: dict) -> str: + """Gets the newly generated text + + Args: + json_chunk (dict) : The chunk containing the generated text + + Returns: + str : The newly generated text + """ + if len(json_chunk['choices']): + data = json_chunk['choices'][0]['text'] + return data + else: + return "" + + +class BackEndMistral(BackEnd): + + def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict: + """Gets the payload to give to the model + + Args: + query_input (QueryInput) : The query input to use + args (argparse.Namespace) : The cli args + + Returns: + dict : The payload + """ return {"messages": [{"role": "user", "content": query_input.prompt}], "model": args.model, "max_tokens": args.output_length, "min_tokens": args.output_length, - "temperature": temperature, + "temperature": self.TEMPERATURE, "stream": True } - else: - raise ValueError(f"The specified backend {args.backend} is not implemented. Please use one of the following : {IMPLEMENTED_BACKENDS}") - + + def test_chunk_validity(self, chunk: str) -> bool: + """Tests if the chunk is valid or should not be considered. -def test_chunk_validity(chunk: str, args: argparse.Namespace) -> bool: - """Tests if the chunk is valid or should not be considered.
+ Args: + chunk (str) : The chunk to consider - Args: - chunk (str) : The chunk to consider - args (argparse.Namespace) : The cli args - - Returns: - bool : Whether the chunk is valid or not - """ - if args.backend in ["happy_vllm"]: - return True - elif args.backend in ["mistral"]: + Returns: + bool : Whether the chunk is valid or not + """ if chunk[:4] == "tok-": return False else: return True - else: - raise ValueError(f"The specified backend {args.backend} is not implemented. Please use one of the following : {IMPLEMENTED_BACKENDS}") + def get_completions_headers(self) -> dict: + """Gets the headers (depending on the backend) to use for the request -def get_completions_headers(args: argparse.Namespace) -> dict: - """Gets the headers (depending on the backend) to use for the request + Returns: + dict: The headers - Args: - args (argparse.Namespace) : The cli args - - Returns: - dict: The headers - - """ - if args.backend in ["happy_vllm"]: - return {} - elif args.backend == "mistral": + """ return {"Accept": "application/json", "Content-Type": "application/json"} - else: - raise ValueError(f"The specified backend {args.backend} is not implemented. Please use one of the following : {IMPLEMENTED_BACKENDS}") - -def decode_remove_response_prefix(chunk_bytes: bytes, args: argparse.Namespace) -> str: - """Removes the prefix in the response of a model and converts the bytes in str + def get_newly_generated_text(self, json_chunk: dict) -> str: + """Gets the newly generated text - Args: - chunk_bytes (bytes) : The chunk received - args (argparse.Namespace) : The cli args + Args: + json_chunk (dict) : The chunk containing the generated text - Returns: - str : The decoded string without the prefix - """ - chunk = chunk_bytes.decode("utf-8") - if args.backend in ["happy_vllm", "mistral"]: - return chunk.removeprefix("data: ") - else: - raise ValueError(f"The specified backend {args.backend} is not implemented. Please use one of the following : {IMPLEMENTED_BACKENDS}") - - -def check_end_of_stream(chunk: str, args: argparse.Namespace) -> bool: - """Checks if this is the last chunk of the stream - - Args: - chunk (str) : The chunk to test - args (argparse.Namespace) : The cli args - - Returns: - bool : Whether it is the last chunk of the stream - """ - if args.backend in ["happy_vllm", "mistral"]: - return chunk == "[DONE]" - else: - raise ValueError(f"The specified backend {args.backend} is not implemented. Please use one of the following : {IMPLEMENTED_BACKENDS}") - - -def get_newly_generated_text(json_chunk: dict, args: argparse.Namespace) -> str: - """Gets the newly generated text - - Args: - json_chunk (dict) : The chunk containing the generated text - args (argparse.Namespace) : The cli args - - Returns: - str : The newly generated text - """ - if args.backend == "happy_vllm": - if len(json_chunk['choices']): - data = json_chunk['choices'][0]['text'] - return data - elif args.backend == "mistral": + Returns: + str : The newly generated text + """ if len(json_chunk['choices']): data = json_chunk['choices'][0]['delta']["content"] return data - else: - raise ValueError(f"The specified backend {args.backend} is not implemented.
Please use one of the following : {IMPLEMENTED_BACKENDS}") - return "" - + else: + return "" -def add_prompt_length(json_chunk: dict, output: QueryOutput, args: argparse.Namespace) -> None: - """Add the prompt length to the QueryOutput - Args: - json_chunk (dict) : The chunk containing the prompt length - args (argparse.Namespace) : The cli args - """ - if args.backend in ["happy_vllm", 'mistral']: - if "usage" in json_chunk: - if json_chunk['usage'] is not None: - output.prompt_length = json_chunk['usage']['prompt_tokens'] - else: - raise ValueError(f"The specified backend {args.backend} is not implemented. Please use one of the following : {IMPLEMENTED_BACKENDS}") \ No newline at end of file +def get_backend(backend_name: str) -> BackEnd: + implemented_backends = ["mistral", "happy_vllm"] + if backend_name not in implemented_backends: + raise ValueError(f"The specified backend {backend_name} is not implemented. Please use one of the following : {implemented_backends}") + if backend_name == "happy_vllm": + return BackendHappyVllm(backend_name, chunk_prefix="data: ", last_chunk="[DONE]") + if backend_name == "mistral": + return BackEndMistral(backend_name, chunk_prefix="data: ", last_chunk="[DONE]") \ No newline at end of file diff --git a/src/benchmark_llm_serving/bench_suite.py b/src/benchmark_llm_serving/bench_suite.py index f43a42f..49f0381 100644 --- a/src/benchmark_llm_serving/bench_suite.py +++ b/src/benchmark_llm_serving/bench_suite.py @@ -12,6 +12,7 @@ from benchmark_llm_serving import utils from benchmark_llm_serving.io_classes import QueryInput from benchmark_llm_serving.make_readmes import make_readme +from benchmark_llm_serving.backends import get_backend, BackEnd from benchmark_llm_serving.make_graphs import draw_and_save_graphs from benchmark_llm_serving.benchmark import launch_benchmark, augment_dataset from benchmark_llm_serving.utils_args import get_parser_base_arguments, add_arguments_to_parser @@ -143,6 +144,8 @@ def main(): for input_length in input_lengths: for output_length in output_lengths: input_output_lengths.append((input_length, output_length)) + + backend = get_backend(args.backend) # Launch the benchmark for prompt ingestion speed now = utils.get_now() @@ -162,13 +165,13 @@ def main(): logger.info(f"{now} Benchmark for the prompt ingestion speed : instance {i} ") args.output_file = os.path.join(raw_results_folder, f"prompt_ingestion_{i}.json") dataset = add_prefixes_to_dataset(datasets[args.prompt_length], 4) - launch_benchmark(args, dataset, suite_id) + launch_benchmark(args, dataset, suite_id, backend=backend) now = utils.get_now() logger.info(f"{now} Benchmark for the prompt ingestion speed : instance {i} : DONE") now = utils.get_now() logger.info(f"{now} Benchmark for the prompt ingestion speed : DONE") - if args.backend == "happy_vllm": + if backend.backend_name == "happy_vllm": # Launch the benchmark for the KV cache profile now = utils.get_now() logger.info(f"{now} Beginning the benchmarks for the KV cache profile") @@ -185,7 +188,7 @@ def main(): now = utils.get_now() dataset = add_prefixes_to_dataset(datasets[args.prompt_length], 4) logger.info(f"{now} Beginning the benchmark for the KV cache profile, input length : {input_length}, output_length : {output_length}") - launch_benchmark(args, dataset, suite_id) + launch_benchmark(args, dataset, suite_id, backend=backend) now = utils.get_now() logger.info(f"{now} Benchmark for the KV cache profile, input length : {input_length}, output_length : {output_length} : DONE") now = utils.get_now() @@ -214,7 
+217,7 @@ def main(): now = utils.get_now() logger.info(f"{now} Benchmarks for the generation speed, input length : {input_length}, output_length : {output_length}, nb_requests : {nb_constant_requests}") dataset = add_prefixes_to_dataset(datasets[args.prompt_length], 4) - launch_benchmark(args, dataset, suite_id) + launch_benchmark(args, dataset, suite_id, backend=backend) now = utils.get_now() logger.info(f"{now} Benchmarks for the generation speed, input length : {input_length}, output_length : {output_length}, nb_requests : {nb_constant_requests} : DONE") current_timestamp = datetime.now().timestamp() @@ -228,7 +231,7 @@ def main(): now = utils.get_now() logger.info(f"{now} Drawing graphs") draw_and_save_graphs(output_folder, speed_threshold=args.speed_threshold, gpu_name=args.gpu_name, - min_number_of_valid_queries=args.min_number_of_valid_queries, backend=args.backend) + min_number_of_valid_queries=args.min_number_of_valid_queries, backend=backend) now = utils.get_now() logger.info(f"{now} Drawing graphs : DONE") diff --git a/src/benchmark_llm_serving/benchmark.py b/src/benchmark_llm_serving/benchmark.py index 0a45d75..60464bc 100644 --- a/src/benchmark_llm_serving/benchmark.py +++ b/src/benchmark_llm_serving/benchmark.py @@ -10,8 +10,8 @@ from typing import List, Tuple, Union, Any from benchmark_llm_serving import utils -from benchmark_llm_serving import backends from benchmark_llm_serving.utils_args import parse_args +from benchmark_llm_serving.backends import get_backend, BackEnd from benchmark_llm_serving.io_classes import QueryOutput, QueryInput from benchmark_llm_serving.query_profiles.query_functions import query_function from benchmark_llm_serving.query_profiles.constant_number import get_benchmark_results_constant_number @@ -24,13 +24,14 @@ async def get_benchmark_results(queries_dataset: List[QueryInput], args: argparse.Namespace, - logger: logging.Logger) -> Tuple[List[QueryOutput], List[dict]]: + logger: logging.Logger, backend: BackEnd) -> Tuple[List[QueryOutput], List[dict]]: """Gets the results of the benchmark Args: queries_dataset (list) : The queries we want to use args (argparse.Namespace) : The CLI args logger (logging.Logger) : The logger + backend (Backend) : The backend to consider Returns: list[QueryOutput] : The list of the result for each query @@ -41,20 +42,20 @@ async def get_benchmark_results(queries_dataset: List[QueryInput], args: argpars # Make one query in order to be sure that everything is ok query_input = QueryInput(prompt="Hey", internal_id=-1) - payload = backends.get_payload(query_input, args) - headers = backends.get_completions_headers(args) + payload = backend.get_payload(query_input, args) + headers = backend.get_completions_headers() response = requests.post(completions_url, json=payload, timeout=100, headers=headers) status_code = response.status_code # if status_code != 200: # raise ValueError(f"The status code of the response is {status_code} instead of 200") if args.query_profile == "constant_number_of_queries": - results, all_live_metrics = await get_benchmark_results_constant_number(queries_dataset, args, completions_url, metrics_url, logger) + results, all_live_metrics = await get_benchmark_results_constant_number(queries_dataset, args, completions_url, metrics_url, logger, backend) elif args.query_profile == "growing_requests": - results, all_live_metrics = await get_benchmark_results_growing_requests(queries_dataset, args, completions_url, metrics_url, logger) + results, all_live_metrics = await 
get_benchmark_results_growing_requests(queries_dataset, args, completions_url, metrics_url, logger, backend) else: results, all_live_metrics = await get_benchmark_results_scheduled_requests(queries_dataset, args, - completions_url, metrics_url, logger) + completions_url, metrics_url, logger, backend) return results, all_live_metrics @@ -184,13 +185,14 @@ def get_aggregated_metrics(benchmark_results: List[QueryOutput]) -> dict: return aggregated_metrics -def launch_benchmark(args: argparse.Namespace, provided_dataset: Union[List[str], None] = None, suite_id: Union[str, None] = None): +def launch_benchmark(args: argparse.Namespace, provided_dataset: Union[List[str], None] = None, suite_id: Union[str, None] = None, backend: Union[BackEnd, None] = None): """Calculates the results of a benchmark. We can explicitly give another dataset. Args: args (argparse.Namespace) : The cli args provided_dataset (list) : If provided, replace the loading suite_id (str) : The id to identify several benchmarks launched by the same bench suite + backend (BackEnd) : The backend to consider. If None, will infer it from args """ now = utils.get_now() @@ -204,9 +206,12 @@ def launch_benchmark(args: argparse.Namespace, provided_dataset: Union[List[str] if args.base_url is None: args.base_url = f"http://{args.host}:{args.port}" + if backend is None: + backend = get_backend(args.backend) + if args.model is None: - if args.backend != "happy_vllm": - raise ValueError(f"No model is specified and the backend is not happy_vllm (it is '{args.backend}'). Please provide a model name") + if backend.backend_name != "happy_vllm": + raise ValueError(f"No model is specified and the backend is not happy_vllm (it is '{backend.backend_name}'). Please provide a model name") args.model = get_model_name_from_info_endpoint(args) if args.output_length is None: @@ -226,7 +231,7 @@ def launch_benchmark(args: argparse.Namespace, provided_dataset: Union[List[str] "model_name": args.model_name } - if args.backend == "happy_vllm": + if backend.backend_name == "happy_vllm": parameters = add_application_parameters(parameters, args) if parameters["model_name"] is None: parameters["model_name"] = parameters["model"] @@ -248,7 +253,7 @@ def launch_benchmark(args: argparse.Namespace, provided_dataset: Union[List[str] now = utils.get_now() logger.info(f"{now} Beginning the requests to the completions endpoint") start_timestamp = datetime.now().timestamp() - benchmark_results, all_live_metrics = asyncio.run(get_benchmark_results(queries_dataset, args, logger)) + benchmark_results, all_live_metrics = asyncio.run(get_benchmark_results(queries_dataset, args, logger, backend)) end_timestamp = datetime.now().timestamp() now = utils.get_now() diff --git a/src/benchmark_llm_serving/make_graphs.py b/src/benchmark_llm_serving/make_graphs.py index 5ad51d8..b523dec 100644 --- a/src/benchmark_llm_serving/make_graphs.py +++ b/src/benchmark_llm_serving/make_graphs.py @@ -10,6 +10,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict from benchmark_llm_serving import utils +from benchmark_llm_serving.backends import get_backend, BackEnd logger = logging.getLogger("Making the graphs") logging.basicConfig(level=logging.INFO) @@ -92,7 +93,7 @@ def make_prompt_ingestion_graph(files: dict, report_folder: str) -> None: def make_speed_generation_graph_for_one_input_output(input_length: int, output_length: int, speed_generation_files: dict, report_folder: str, - speed_threshold: float, backend: str) -> dict: + speed_threshold: float, backend: BackEnd) -> dict: 
"""Draws the speed generation graph and save the corresponding data for a couple of input length/ output length. Also returns the corresponding speed and kv cache thresholds @@ -103,7 +104,7 @@ def make_speed_generation_graph_for_one_input_output(input_length: int, output_l speed_generation_files (dict) : The file containing the results of the speed generation benchmarks report_folder (str) : The folder where the report should be written speed_threshold (float) : The accepted speed generation to fix the threshold - backend (str) : The backend + backend (BackEnd) : The backend Returns: dict : The accepted thresholds @@ -169,7 +170,7 @@ def make_speed_generation_graph_for_one_input_output(input_length: int, output_l ax3.set_ylabel('Time to first token (ms)', fontsize='14') - if backend == "happy_vllm": + if backend.backend_name == "happy_vllm": ax2: Any = ax1.twinx() ax2.set_ylabel('Max KV cache percentage', fontsize='14') ax2.set_ylim((0, 1.0)) @@ -179,7 +180,7 @@ def make_speed_generation_graph_for_one_input_output(input_length: int, output_l yerr=[speed_generation_lower_percentiles, speed_generation_upper_percentiles], fmt='b-o', capsize=4, label="Speed generation") - if backend == "happy_vllm": + if backend.backend_name == "happy_vllm": # Max KV cache plot max_kv_cache_graph = ax2.plot(parallel_requests_nbs, max_kv_cache, color='green', linestyle="--", label="Max KV cache") # Time to first token generation plot @@ -187,7 +188,7 @@ def make_speed_generation_graph_for_one_input_output(input_length: int, output_l yerr=[time_to_first_token_lower_percentiles, time_to_first_token_upper_percentiles], fmt='r-o', capsize=4, label="Time to first token") - if backend == "happy_vllm": + if backend.backend_name == "happy_vllm": curves = [speed_generation_graph, max_kv_cache_graph[0], time_to_first_token_generation_graph] else: curves = [speed_generation_graph, time_to_first_token_generation_graph] @@ -212,7 +213,7 @@ def make_speed_generation_graph_for_one_input_output(input_length: int, output_l max_requests_speed = max(speed_treshold_reached) else: max_requests_speed = 0 - if backend == "happy_vllm": + if backend.backend_name == "happy_vllm": # KV cache threshold kv_treshold_not_reached = [key for key, value in data_summary.items() if value["max_kv_cache"] < 0.95] if len(kv_treshold_not_reached): @@ -224,7 +225,7 @@ def make_speed_generation_graph_for_one_input_output(input_length: int, output_l return {"speed_threshold": max_requests_speed} -def make_speed_generation_graphs(files: dict, report_folder: str, speed_threshold: float, backend: str): +def make_speed_generation_graphs(files: dict, report_folder: str, speed_threshold: float, backend: BackEnd): """Draws the speed generation graphs and save the corresponding data. 
Also saves the thresholds, namely the accepted number of parallel requests with a KV cache inferior to 1.0 and a speed generation above the speed_threshold @@ -233,7 +234,7 @@ def make_speed_generation_graphs(files: dict, report_folder: str, speed_threshol files (dict) : The files containing the results of the benchmarks report_folder (str) : The folder where the report should be written speed_threshold (float) : The accepted speed generation to fix the threshold - backend (str) : The backend + backend (BackEnd) : The backend """ speed_generation_files = {key: value for key, value in files.items() if 'generation_speed' in key} thresholds = [] @@ -422,7 +423,7 @@ def save_common_parameters(files: dict, report_folder: str, gpu_name: str): def draw_and_save_graphs(output_folder: str, speed_threshold: float = 20.0, gpu_name: Union[str, None] = None, - min_number_of_valid_queries: int = 50, backend: str = "happy_vllm"): + min_number_of_valid_queries: int = 50, backend: Union[BackEnd, str] = "happy_vllm"): """Draws and saves all the graphs and corresponding data for benchmark results obtained via bench_suite.py @@ -432,6 +433,9 @@ def draw_and_save_graphs(output_folder: str, speed_threshold: float = 20.0, gpu_ gpu_name (str) : The name of the gpu backend (str) : The backend """ + if isinstance(backend, str): + backend = get_backend(backend) + # Manage output path if not os.path.isabs(output_folder): current_directory = Path(os.path.dirname(os.path.realpath(__file__))) @@ -467,7 +471,7 @@ def draw_and_save_graphs(output_folder: str, speed_threshold: float = 20.0, gpu_ now = utils.get_now() logger.info(f"{now} Making speed generation graphs") make_speed_generation_graphs(files, report_folder, speed_threshold, backend) - if backend == "happy_vllm": + if backend.backend_name == "happy_vllm": now = utils.get_now() logger.info(f"{now} Making kv cache profile graphs") make_kv_cache_profile_graphs(files, report_folder) diff --git a/src/benchmark_llm_serving/query_profiles/constant_number.py b/src/benchmark_llm_serving/query_profiles/constant_number.py index cc26362..ef7aa3c 100644 --- a/src/benchmark_llm_serving/query_profiles/constant_number.py +++ b/src/benchmark_llm_serving/query_profiles/constant_number.py @@ -7,13 +7,14 @@ from typing import List, Tuple from benchmark_llm_serving.utils import get_now +from benchmark_llm_serving.backends import BackEnd from benchmark_llm_serving.utils_metrics import get_live_metrics from benchmark_llm_serving.io_classes import QueryOutput, QueryInput from benchmark_llm_serving.query_profiles.query_functions import query_function async def worker_func(session: aiohttp.ClientSession, queue: asyncio.Queue, completions_url: str, results: List[QueryOutput], - args: argparse.Namespace, logger: logging.Logger) -> None: + args: argparse.Namespace, logger: logging.Logger, backend: BackEnd) -> None: """Queries the completions API to get the output using a worker and a queue Args: @@ -23,6 +24,7 @@ async def worker_func(session: aiohttp.ClientSession, queue: asyncio.Queue, comp results (str) : The list of results to which we will add the output args (argparse.Namespace) : The cli args logger (logging.Logger) : The logger + backend (Backend) : The backend to consider """ while True: query_input = await queue.get() @@ -34,7 +36,7 @@ async def worker_func(session: aiohttp.ClientSession, queue: asyncio.Queue, comp await asyncio.sleep(random.uniform(0.0, 0.1)) else: await asyncio.sleep(random.uniform(0.0, 0.02)) - await query_function(query_input, session, completions_url, results, 
args) + await query_function(query_input, session, completions_url, results, args, backend) if len(results) % int(args.max_queries / 10) == 0: now = get_now() logger.info(f'{now} {len(results)} queries have been completed') @@ -73,7 +75,7 @@ def continue_condition(current_timestamp: float, start_queries_timestamp: float, async def get_benchmark_results_constant_number(queries_dataset: List[QueryInput], args: argparse.Namespace, completions_url: str, - metrics_url: str, logger: logging.Logger) -> Tuple[List[QueryOutput], List[dict]]: + metrics_url: str, logger: logging.Logger, backend: BackEnd) -> Tuple[List[QueryOutput], List[dict]]: """Gets the results for the benchmark and the live metrics, using workers so that there are always the same number of queries launched @@ -83,6 +85,7 @@ async def get_benchmark_results_constant_number(queries_dataset: List[QueryInput completions_url (str) : The url of the completions API metrics_url (str) : The url to the /metrics endpoint logger (logging.Logger) : The logger + backend (Backend) : The backend to consider Returns: list[QueryOutput] : The list of the result for each query @@ -95,7 +98,7 @@ async def get_benchmark_results_constant_number(queries_dataset: List[QueryInput connector = aiohttp.TCPConnector(limit=10000) async with aiohttp.ClientSession(connector=connector) as session: # Create workers - workers = [asyncio.create_task(worker_func(session, queue, completions_url, results, args, logger)) + workers = [asyncio.create_task(worker_func(session, queue, completions_url, results, args, logger, backend)) for _ in range(args.n_workers)] # Query the /metrics endpoint for one second before adding queries to the queue for i in range(int(1/args.step_live_metrics)): diff --git a/src/benchmark_llm_serving/query_profiles/growing_requests.py b/src/benchmark_llm_serving/query_profiles/growing_requests.py index 36ec731..b7a7565 100644 --- a/src/benchmark_llm_serving/query_profiles/growing_requests.py +++ b/src/benchmark_llm_serving/query_profiles/growing_requests.py @@ -5,6 +5,7 @@ from datetime import datetime from typing import List, Tuple +from benchmark_llm_serving.backends import BackEnd from benchmark_llm_serving.utils import tasks_are_done, get_now from benchmark_llm_serving.utils_metrics import get_live_metrics from benchmark_llm_serving.io_classes import QueryOutput, QueryInput @@ -38,7 +39,7 @@ def continue_condition(current_timestamp: float, start_queries_timestamp: float, async def get_benchmark_results_growing_requests(queries_dataset: List[QueryInput], args: argparse.Namespace, completions_url: str, - metrics_url: str, logger: logging.Logger) -> Tuple[List[QueryOutput], List[dict]]: + metrics_url: str, logger: logging.Logger, backend: BackEnd) -> Tuple[List[QueryOutput], List[dict]]: """Gets the results for the benchmark and the live metrics, using a growing number of queries. First one is sent, then when done, two are sent, then when they are done, three are sent, etc. 
@@ -48,6 +49,7 @@ async def get_benchmark_results_growing_requests(queries_dataset: List[QueryInpu completions_url (str) : The url of the completions API metrics_url (str) : The url to the /metrics endpoint logger (logging.Logger) : The logger + backend (Backend) : The backend to consider Returns: list[QueryOutput] : The list of the result for each query @@ -73,7 +75,7 @@ async def get_benchmark_results_growing_requests(queries_dataset: List[QueryInpu await asyncio.sleep(0.5) now = get_now() logger.info(f"{now} Launching {n} queries in parallel") - tasks += [asyncio.create_task(query_function(query_input, session, completions_url, results, args)) + tasks += [asyncio.create_task(query_function(query_input, session, completions_url, results, args, backend)) for query_input in queries_dataset[nb_queries_launched: nb_queries_launched + n]] nb_queries_launched += n # While we wait for the tasks to be done, we query the /metrics endpoint diff --git a/src/benchmark_llm_serving/query_profiles/query_functions.py b/src/benchmark_llm_serving/query_profiles/query_functions.py index 734dc7e..b2ee34c 100644 --- a/src/benchmark_llm_serving/query_profiles/query_functions.py +++ b/src/benchmark_llm_serving/query_profiles/query_functions.py @@ -6,12 +6,12 @@ from typing import List from datetime import datetime -from benchmark_llm_serving import backends +from benchmark_llm_serving.backends import BackEnd from benchmark_llm_serving.io_classes import QueryOutput, QueryInput async def query_function(query_input: QueryInput, session: aiohttp.ClientSession, completions_url: str, results: List[QueryOutput], - args: argparse.Namespace) -> QueryOutput: + args: argparse.Namespace, backend: BackEnd) -> QueryOutput: """Queries the completions API to get the output Args: @@ -20,12 +20,13 @@ async def query_function(query_input: QueryInput, session: aiohttp.ClientSession completions_url (str) : The url of the completions API results (list) : The list of results to which we will add the output args (argparse.Namespace) : The cli args + backend (Backend) : The backend to consider Returns: QueryOutput : The output of the query """ - body = backends.get_payload(query_input, args) - headers = backends.get_completions_headers(args) + body = backend.get_payload(query_input, args) + headers = backend.get_completions_headers() output = QueryOutput() output.starting_timestamp = datetime.now().timestamp() output.prompt = query_input.prompt @@ -39,22 +40,23 @@ async def query_function(query_input: QueryInput, session: aiohttp.ClientSession continue # Some backends add a prefix to the response. We remove it - chunk = backends.decode_remove_response_prefix(chunk_bytes, args) + chunk = chunk_bytes.decode("utf-8") + chunk = backend.remove_response_prefix(chunk) # Some backends send useless messages. 
We don't consider them - if backends.test_chunk_validity(chunk, args): + if backend.test_chunk_validity(chunk): # If the stream is ending, we save the timestamp as ending time - if backends.check_end_of_stream(chunk, args): + if backend.check_end_of_stream(chunk): output.ending_timestamp = datetime.now().timestamp() output.success = True # Otherwise, we add the response to the already generated text else: timestamp = datetime.now().timestamp() json_chunk = json.loads(chunk) - newly_generated_text = backends.get_newly_generated_text(json_chunk, args) + newly_generated_text = backend.get_newly_generated_text(json_chunk) if len(newly_generated_text): output.timestamp_of_tokens_arrival.append(timestamp) output.generated_text += newly_generated_text - backends.add_prompt_length(json_chunk, output, args) + backend.add_prompt_length(json_chunk, output) else: output.success = False output.error = response.reason or "" diff --git a/src/benchmark_llm_serving/query_profiles/scheduled_requests.py b/src/benchmark_llm_serving/query_profiles/scheduled_requests.py index b37bb16..d909700 100644 --- a/src/benchmark_llm_serving/query_profiles/scheduled_requests.py +++ b/src/benchmark_llm_serving/query_profiles/scheduled_requests.py @@ -6,6 +6,7 @@ from datetime import datetime from typing import List, Tuple +from benchmark_llm_serving.backends import BackEnd from benchmark_llm_serving.utils import tasks_are_done, get_now from benchmark_llm_serving.utils_metrics import get_live_metrics from benchmark_llm_serving.io_classes import QueryOutput, QueryInput @@ -110,7 +111,7 @@ def continue_condition(current_query_index_to_launch: int, max_queries_number: i async def get_benchmark_results_scheduled_requests(queries_dataset: List[QueryInput], args: argparse.Namespace, completions_url: str, - metrics_url: str, logger: logging.Logger) -> Tuple[List[QueryOutput], List[dict]]: + metrics_url: str, logger: logging.Logger, backend: BackEnd) -> Tuple[List[QueryOutput], List[dict]]: """Gets the results for the benchmark and the live metrics, using scheduled queries ie, queries whose timestamp we can calculate before actually launching the queries. 
@@ -120,6 +121,7 @@ async def get_benchmark_results_scheduled_requests(queries_dataset: List[QueryIn completions_url (str) : The url of the completions API metrics_url (str) : The url to the /metrics endpoint logger (logging.Logger) : The logger + backend (Backend) : The backend to consider Returns: list[QueryOutput] : The list of the result for each query @@ -160,7 +162,7 @@ async def get_benchmark_results_scheduled_requests(queries_dataset: List[QueryIn if current_query_index_to_launch // int(args.max_queries / 10) != old_query_index_to_launch // int(args.max_queries / 10): now = get_now() logger.info(f"{now} {current_query_index_to_launch} requests in total have been launched") - tasks += [asyncio.create_task(query_function(query_input, session, completions_url, results, args)) for query_input in queries_to_launch] + tasks += [asyncio.create_task(query_function(query_input, session, completions_url, results, args, backend)) for query_input in queries_to_launch] asyncio.create_task(get_live_metrics(session, metrics_url, all_live_metrics, args)) await asyncio.sleep(args.step_live_metrics) diff --git a/tests/test_backends.py b/tests/test_backends.py index fccaf23..9a11fa7 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -4,8 +4,27 @@ from benchmark_llm_serving import backends from benchmark_llm_serving.io_classes import QueryOutput, QueryInput -def test_get_payload(): + +def test_get_backend(): # happy_vllm backend + backend = backends.get_backend("happy_vllm") + assert isinstance(backend, backends.BackendHappyVllm) + assert backend.backend_name == "happy_vllm" + assert backend.chunk_prefix == "data: " + assert backend.last_chunk == "[DONE]" + # mistral backend + backend = backends.get_backend("mistral") + assert isinstance(backend, backends.BackEndMistral) + assert backend.backend_name == "mistral" + assert backend.chunk_prefix == "data: " + assert backend.last_chunk == "[DONE]" + # ValueError + with pytest.raises(ValueError): + backends.get_backend("backend_not_implemented") + + +def test_backend_happy_vllm_get_payload(): + backend = backends.get_backend("happy_vllm") prompts_list = ["Hey. 
How are you?", "Fine, you ?"] model_list = ["My_awesome_model", "yet_another_model"] output_length_list = [100, 1000] @@ -13,8 +32,8 @@ def test_get_payload(): for model in model_list: for output_length in output_length_list: query_input = QueryInput(prompt=prompt, internal_id=0) - args = argparse.Namespace(backend="happy_vllm", model=model, output_length=output_length) - payload = backends.get_payload(query_input, args) + args = argparse.Namespace(model=model, output_length=output_length) + payload = backend.get_payload(query_input, args) target_payload = {"prompt": prompt, "model": model, "max_tokens": output_length, @@ -25,8 +44,76 @@ def test_get_payload(): "stream_options": {"include_usage": True} } assert payload == target_payload + + +def test_backend_happy_vllm_test_chunk_validity(): + backend = backends.get_backend("happy_vllm") + for chunk in ['first chunk', "second chunk", "tok-i 1", "tok-o 1"]: + chunk_validity = backend.test_chunk_validity(chunk) + assert chunk_validity + + +def test_backend_happy_vllm_get_completions_headers(): + backend = backends.get_backend("happy_vllm") + assert backend.get_completions_headers() == {} + + +def test_backend_happy_vllm_remove_response_prefix(): + backend = backends.get_backend("happy_vllm") + for chunk_str in ["Hey", "How are you ?", "Fine thanks!"]: + chunk = "data: " + chunk_str + chunk = backend.remove_response_prefix(chunk) + assert chunk == chunk_str + + +def test_backend_happy_vllm_check_end_of_stream(): + backend = backends.get_backend("happy_vllm") + for chunk_not_done in ["not done", " [DONE]", "[DONE] "]: + assert not(backend.check_end_of_stream(chunk_not_done)) + assert backend.check_end_of_stream("[DONE]") + + +def test_backend_happy_vllm_get_newly_generated_text(): + generated_text = "Hey ! How are you ?" + json_chunk_completions = {"id": "cmpl-d85f82039b864ceb8d95be931b200745", "object": "chat.completion.chunk", + "created": 1716468615, "model": "CodeLlama-34B-AWQ", + "choices": [{"index": 0,"text": f"{generated_text}", + "stop_reason": None, "logprobs": None, "finish_reason": None}]} - # mistral backend + backend = backends.get_backend("happy_vllm") + newly_generated_text = backend.get_newly_generated_text(json_chunk_completions) + assert newly_generated_text == generated_text + + +def test_backend_happy_vllm_add_prompt_length(): + backend = backends.get_backend("happy_vllm") + prompt_tokens = 1234 + chunk_with_usage = {"id": "cmpl-d5edc7c2c3264f189b3c941630751d8e", + "object": "text_completion", + "created": 1722328225, + "model": "Vigostral-7B-Chat-AWQ", + "choices":[], + "usage": {"prompt_tokens": prompt_tokens,"total_tokens": 108,"completion_tokens": 86}} + + chunk_without_usage = {"id": "cmpl-d5edc7c2c3264f189b3c941630751d8e", + "object": "text_completion", + "created": 1722328225, + "model": "Vigostral-7B-Chat-AWQ", + "choices":[]} + # with usage key + output = QueryOutput() + assert output.prompt_length == 0 + backend.add_prompt_length(chunk_with_usage, output) + assert output.prompt_length == prompt_tokens + + output = QueryOutput() + assert output.prompt_length == 0 + backend.add_prompt_length(chunk_without_usage, output) + assert output.prompt_length == 0 + + +def test_backend_mistral_get_payload(): + backend = backends.get_backend("mistral") prompts_list = ["Hey. 
How are you?", "Fine, you ?"] model_list = ["My_awesome_model", "yet_another_model"] output_length_list = [100, 1000] @@ -34,8 +121,8 @@ def test_get_payload(): for model in model_list: for output_length in output_length_list: query_input = QueryInput(prompt=prompt, internal_id=0) - args = argparse.Namespace(backend="mistral", model=model, output_length=output_length) - payload = backends.get_payload(query_input, args) + args = argparse.Namespace(model=model, output_length=output_length) + payload = backend.get_payload(query_input, args) target_payload = {"messages": [{"role": "user", "content": prompt}], "model": model, "max_tokens": output_length, @@ -45,187 +132,77 @@ def test_get_payload(): } assert payload == target_payload - # ValueError - prompts_list = ["Hey. How are you?", "Fine, you ?"] - model_list = ["My_awesome_model", "yet_another_model"] - output_length_list = [100, 1000] - for prompt in prompts_list: - for model in model_list: - for output_length in output_length_list: - query_input = QueryInput(prompt=prompt, internal_id=0) - args = argparse.Namespace(backend="not_implemented_backend", model=model, output_length=output_length) - with pytest.raises(ValueError): - backends.get_payload(query_input, args) - -def test_test_chunk_validity(): - # happy_vllm backend - args = argparse.Namespace(backend="happy_vllm") - for chunk in ['first chunk', "second chunk", "tok-i 1", "tok-o 1"]: - chunk_validity = backends.test_chunk_validity(chunk, args) - assert chunk_validity - - # mistral backend - args = argparse.Namespace(backend="mistral") +def test_backend_mistral_test_chunk_validity(): + backend = backends.get_backend("mistral") # True for chunk in ['first chunk', "second chunk"]: - chunk_validity = backends.test_chunk_validity(chunk, args) + chunk_validity = backend.test_chunk_validity(chunk) assert chunk_validity # False for chunk in ["tok-i 1", "tok-o 1"]: - chunk_validity = backends.test_chunk_validity(chunk, args) + chunk_validity = backend.test_chunk_validity(chunk) assert not(chunk_validity) - # ValueError - args = argparse.Namespace(backend="not_implemented_backend") - for chunk in ['first chunk', "second chunk", "tok-i 1", "tok-o 1"]: - with pytest.raises(ValueError): - backends.test_chunk_validity(chunk, args) - -def test_get_completions_headers(): - # happy_vllm backend - args = argparse.Namespace(backend="happy_vllm") - assert backends.get_completions_headers(args) == {} +def test_backend_mistral_get_completions_headers(): + backend = backends.get_backend("mistral") + target_headers = {"Accept": "application/json", + "Content-Type": "application/json"} + assert backend.get_completions_headers() == target_headers - # mistral backend - args = argparse.Namespace(backend="mistral") - assert backends.get_completions_headers(args) == {"Accept": "application/json", - "Content-Type": "application/json"} - - # ValueError - args = argparse.Namespace(backend="not_implemented_backend") - with pytest.raises(ValueError): - backends.get_completions_headers(args) - - -def test_decode_remove_response_prefix(): - # happy_vllm backend - args = argparse.Namespace(backend="happy_vllm") - for chunk_str in ["Hey", "How are you ?", "Fine thanks!"]: - chunk_bytes = bytes("data: " + chunk_str, "utf-8") - chunk = backends.decode_remove_response_prefix(chunk_bytes, args) - assert chunk == chunk_str - # mistral backend - args = argparse.Namespace(backend="mistral") +def test_backend_mistral_remove_response_prefix(): + backend = backends.get_backend("mistral") for chunk_str in ["Hey", "How are you ?", 
"Fine thanks!"]: - chunk_bytes = bytes("data: " + chunk_str, "utf-8") - chunk = backends.decode_remove_response_prefix(chunk_bytes, args) + chunk = "data: " + chunk_str + chunk = backend.remove_response_prefix(chunk) assert chunk == chunk_str - # ValueError - args = argparse.Namespace(backend="not_implemented_backend") - for chunk_str in ["Hey", "How are you ?", "Fine thanks!"]: - chunk_bytes = bytes("data: " + chunk_str, "utf-8") - with pytest.raises(ValueError): - backends.decode_remove_response_prefix(chunk_bytes, args) - -def test_check_end_of_stream(): - # happy_vllm backend - args = argparse.Namespace(backend="happy_vllm") +def test_backend_mistral_check_end_of_stream(): + backend = backends.get_backend("mistral") for chunk_not_done in ["not done", " [DONE]", "[DONE] "]: - assert not(backends.check_end_of_stream(chunk_not_done, args)) - assert backends.check_end_of_stream("[DONE]", args) - - # mistral backend - args = argparse.Namespace(backend="mistral") - for chunk_not_done in ["not done", " [DONE]", "[DONE] "]: - assert not(backends.check_end_of_stream(chunk_not_done, args)) - assert backends.check_end_of_stream("[DONE]", args) - # ValueError - args = argparse.Namespace(backend="not_implemented_backend") - for chunk in ["not done", " [DONE]", "[DONE] ", "[DONE]"]: - with pytest.raises(ValueError): - backends.check_end_of_stream(chunk, args) + assert not(backend.check_end_of_stream(chunk_not_done)) + assert backend.check_end_of_stream("[DONE]") -def test_get_newly_generated_text(): +def test_backend_mistral_get_newly_generated_text(): generated_text = "Hey ! How are you ?" - json_chunk_completions = {"id": "cmpl-d85f82039b864ceb8d95be931b200745", "object": "chat.completion.chunk", - "created": 1716468615, "model": "CodeLlama-34B-AWQ", - "choices": [{"index": 0,"text": f"{generated_text}", - "stop_reason": None, "logprobs": None, "finish_reason": None}]} json_chunk_chat_completions = {"id":"cbaa5c28166d4b98b5256f1becc0364d", "object":"chat.completion.chunk", "created":1722322855, "model":"mistral", "choices":[{"index":0,"delta": {"content": f"{generated_text}"},"finish_reason": None,"logprobs": None}]} - - # happy_vllm backend - args = argparse.Namespace(backend="happy_vllm") - newly_generated_text = backends.get_newly_generated_text(json_chunk_completions, args) - assert newly_generated_text == generated_text - # mistral backend - args = argparse.Namespace(backend="mistral") - newly_generated_text = backends.get_newly_generated_text(json_chunk_chat_completions, args) + backend = backends.get_backend("mistral") + newly_generated_text = backend.get_newly_generated_text(json_chunk_chat_completions) assert newly_generated_text == generated_text - # ValueError - args = argparse.Namespace(backend="not_implemented_backend") - with pytest.raises(ValueError): - backends.get_newly_generated_text(json_chunk_chat_completions, args) - -def test_add_prompt_length(): +def test_backend_mistral_add_prompt_length(): prompt_tokens = 1234 - happy_vllm_chunk_with_usage = {"id": "cmpl-d5edc7c2c3264f189b3c941630751d8e", - "object": "text_completion", - "created": 1722328225, - "model": "Vigostral-7B-Chat-AWQ", - "choices":[], - "usage": {"prompt_tokens": prompt_tokens,"total_tokens": 108,"completion_tokens": 86}} - - happy_vllm_chunk_without_usage = {"id": "cmpl-d5edc7c2c3264f189b3c941630751d8e", - "object": "text_completion", - "created": 1722328225, - "model": "Vigostral-7B-Chat-AWQ", - "choices":[]} - - mistral_json_chunk_with_usage = {"id":"cbaa5c28166d4b98b5256f1becc0364d", + backend = 
backends.get_backend("mistral") + json_chunk_with_usage = {"id":"cbaa5c28166d4b98b5256f1becc0364d", "object": "chat.completion.chunk", "created":1722322855, "model":"mistral", "choices":[{"index": 0,"delta": {"content":""}, "finish_reason": "stop", "logprobs": None}], "usage": {"prompt_tokens": prompt_tokens,"total_tokens": 58,"completion_tokens": 46}} - mistral_json_chunk_without_usage = {"id":"cbaa5c28166d4b98b5256f1becc0364d", + json_chunk_without_usage = {"id":"cbaa5c28166d4b98b5256f1becc0364d", "object": "chat.completion.chunk", "created":1722322855, "model":"mistral", "choices":[{"index": 0,"delta": {"content":""}, "finish_reason": "stop", "logprobs": None}]} - # happy_vllm backend with usage key - args = argparse.Namespace(backend="happy_vllm") + #with usage key output = QueryOutput() assert output.prompt_length == 0 - backends.add_prompt_length(happy_vllm_chunk_with_usage, output, args) + backend.add_prompt_length(json_chunk_with_usage, output) assert output.prompt_length == prompt_tokens - # happy_vllm backend without usage key - args = argparse.Namespace(backend="happy_vllm") - output = QueryOutput() - assert output.prompt_length == 0 - backends.add_prompt_length(happy_vllm_chunk_without_usage, output, args) - assert output.prompt_length == 0 - - # mistral backend with usage key - args = argparse.Namespace(backend="mistral") + #without usage key output = QueryOutput() assert output.prompt_length == 0 - backends.add_prompt_length(mistral_json_chunk_with_usage, output, args) - assert output.prompt_length == prompt_tokens - - # mistral backend without usage key - args = argparse.Namespace(backend="mistral") - output = QueryOutput() - assert output.prompt_length == 0 - backends.add_prompt_length(mistral_json_chunk_without_usage, output, args) - assert output.prompt_length == 0 - - # ValueError - args = argparse.Namespace(backend="not_implemented_backend") - output = QueryOutput() - with pytest.raises(ValueError): - backends.add_prompt_length(mistral_json_chunk_with_usage, output, args) \ No newline at end of file + backend.add_prompt_length(json_chunk_without_usage, output) + assert output.prompt_length == 0 \ No newline at end of file diff --git a/tests/test_query_functions.py b/tests/test_query_functions.py index dc502a7..0df3524 100644 --- a/tests/test_query_functions.py +++ b/tests/test_query_functions.py @@ -3,18 +3,19 @@ import argparse from aioresponses import aioresponses +from benchmark_llm_serving.backends import get_backend from benchmark_llm_serving.query_profiles import query_functions from benchmark_llm_serving.io_classes import QueryInput, QueryOutput -async def mock_streaming_response_content(args): +async def mock_streaming_response_content(args, backend): tokens_list = ["Hey", " how", " are", " you", " ?"] * 1000 tokens_list = tokens_list[:args.output_length] - if args.backend == "happy_vllm": + if backend.backend_name == "happy_vllm": for element in iter(tokens_list): string = f"""data: {{"id":"cmpl-d85f82039b864ceb8d95be931b200745", "object":"chat.completion.chunk", "created":1716468615, "model":"CodeLlama-34B-AWQ", "choices":[{{"index":0,"text":"{element}","stop_reason":null,"logprobs":null,"finish_reason":null}}]}}""" yield bytes(string, "utf-8") - if args.backend == "mistral": + if backend.backend_name == "mistral": for element in iter(tokens_list): string = f"""data: 
{{"id":"cbaa5c28166d4b98b5256f1becc0364d","object":"chat.completion.chunk","created":1722322855,"model":"mistral","choices":[{{"index":0,"delta":{{"content":"{element}"}},"finish_reason":null,"logprobs":null}}]}}""" yield bytes(string, "utf-8") @@ -22,17 +23,18 @@ async def mock_streaming_response_content(args): class MockSession(): - def __init__(self, args): + def __init__(self, args, backend): self.args = args + self.backend = backend def post(self, **kwargs): - return MockResponse(self.args) + return MockResponse(self.args, self.backend) class MockResponse(): - def __init__(self, args): + def __init__(self, args, backend): self.status = 200 - self.content = mock_streaming_response_content(args) + self.content = mock_streaming_response_content(args, backend) async def __aenter__(self): return self @@ -44,44 +46,48 @@ async def __aexit__(self, exc_t, exc_v, exc_tb): async def test_query_function(): # First call output_length = 10 - args = argparse.Namespace(output_length=output_length, model="CodeLlama-34B-AWQ", backend="happy_vllm") + backend = get_backend("happy_vllm") + args = argparse.Namespace(output_length=output_length, model="CodeLlama-34B-AWQ") prompt = "Hey !" - async_generator = mock_streaming_response_content(args) + async_generator = mock_streaming_response_content(args, backend) query_input = QueryInput(prompt=prompt, internal_id=0) results = [] - session = MockSession(args) + session = MockSession(args, backend) await query_functions.query_function(query_input, session, "my_url", results=results, - args=args) + args=args, backend=backend) assert len(results) == 1 assert isinstance(results[0], QueryOutput) + print(results) assert len(results[0].timestamp_of_tokens_arrival) == output_length # Another call output_length = 5 - args = argparse.Namespace(output_length=output_length, model="CodeLlama-34B-AWQ", backend="happy_vllm") + backend = get_backend("happy_vllm") + args = argparse.Namespace(output_length=output_length, model="CodeLlama-34B-AWQ") prompt = "Hey !" - async_generator = mock_streaming_response_content(args) + async_generator = mock_streaming_response_content(args, backend) query_input = QueryInput(prompt=prompt, internal_id=1) - session = MockSession(args) + session = MockSession(args, backend) await query_functions.query_function(query_input, session, "my_url", results=results, - args=args) + args=args, backend=backend) assert len(results) == 2 assert isinstance(results[1], QueryOutput) assert len(results[1].timestamp_of_tokens_arrival) == output_length # mistral backend output_length = 10 - args = argparse.Namespace(output_length=output_length, model="mistral", backend="mistral") + backend = get_backend("mistral") + args = argparse.Namespace(output_length=output_length, model="mistral") prompt = "Hey !" - async_generator = mock_streaming_response_content(args) + async_generator = mock_streaming_response_content(args, backend) query_input = QueryInput(prompt=prompt, internal_id=0) results = [] - session = MockSession(args) + session = MockSession(args, backend) await query_functions.query_function(query_input, session, "my_url", results=results, - args=args) + args=args, backend=backend) assert len(results) == 1 assert isinstance(results[0], QueryOutput) assert len(results[0].timestamp_of_tokens_arrival) == output_length \ No newline at end of file