From 82cff8a1b65700caeb7b5a42abee8bd4a0ab1eee Mon Sep 17 00:00:00 2001 From: "Tianyi (Alex) Qiu" Date: Fri, 6 Dec 2024 00:35:35 -0800 Subject: [PATCH] feat: inference control over max_tokens and temperature --- src/abstractions/model.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/abstractions/model.py b/src/abstractions/model.py index d223274..41d565b 100644 --- a/src/abstractions/model.py +++ b/src/abstractions/model.py @@ -21,6 +21,7 @@ from src.abstractions.configs.templates_configs import * import multiprocessing from src.download_models import download_model +from functools import partial from src.abstractions.backends import ( start_inference_backend, @@ -41,6 +42,7 @@ def inference_standalone( prompt_field_name: str, query_field_name: str, temperature: float, + max_tokens: int, backend_type: Literal["sglang", "vllm"], purpose: Literal["responses", "logprobs"], conn: multiprocessing.connection.Connection, @@ -58,7 +60,7 @@ def inference_standalone( prompt_field_name=prompt_field_name, query_field_name=query_field_name ) result_data = data.transform( - transformation=process_batch, + transformation=partial(process_batch, temperature=temperature, max_tokens=max_tokens), result_data_name=result_data_name, forced_rewrite=( Model.always_force_rewrite @@ -666,7 +668,8 @@ def inference( result_data_name: str, backend: Literal["sglang", "vllm", "deepspeed", "serial"] = "sglang", batch_size_multiplier_log2: int = 0, - temperature=0.25, + temperature: float = 0.25, + max_tokens: int = 8192, purpose: Literal["responses", "logprobs"] = "responses", ) -> Union[Data, List[Dict[str, str]]]: """Performance inference on a dataset (currently only instruction datasets are tested, with the same format as SFT datasets), @@ -684,9 +687,12 @@ def inference( :param batch_size_multiplier_log2: The log base 2 of the batch size multiplier :type batch_size_multiplier_log2: int = 0 - :param temperature: The temperature parameter + :param temperature: The temperature parameter. :type temperature: float = 0.25 + :param max_tokens: The maximum number of tokens to generate. Ignored if purpose is "logprobs". + :type max_tokens: int = 8192 + :param purpose: The purpose of the inference. It can be "responses" or "logprobs". If "logprobs", the log probability of the prompt itself (and the assistant response supplied in the `predict` field, if exists) is returned in the `logprob` field of the resulting dataset, without doing any completion. If "responses", the completion text is saved in the `predict` field of the resulting dataset. :type purpose: Literal["responses", "logprobs"] = "responses" @@ -753,7 +759,7 @@ def inference( result = ( self.__inference_parallel_segregated( - data, result_data_name, temperature, backend, purpose + data, result_data_name, temperature, max_tokens, backend, purpose ) if backend in ["vllm", "sglang"] else self.__inference_parallel_deepspeed( @@ -786,7 +792,7 @@ def inference( return result def __inference_parallel_segregated( - self, data: Data, result_data_name: str, temperature: float, backend_type: str, purpose: str + self, data: Data, result_data_name: str, temperature: float, max_tokens: int, backend_type: str, purpose: str ) -> Data: """sglang/vllm implementation for `inference()`, but performed in a separate process to free up GPU memory. This is the recommended implementation, due to its superior speed and robustness.""" data_path = data.data_path @@ -814,6 +820,7 @@ def __inference_parallel_segregated( prompt_field_name, query_field_name, temperature, + max_tokens, backend_type, purpose, child_conn,