feat: inference control over max_tokens and temperature
TianyiQ committed Dec 6, 2024
1 parent 53ecb79 commit 82cff8a
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions src/abstractions/model.py
@@ -21,6 +21,7 @@
from src.abstractions.configs.templates_configs import *
import multiprocessing
from src.download_models import download_model
+ from functools import partial

from src.abstractions.backends import (
start_inference_backend,
@@ -41,6 +42,7 @@ def inference_standalone(
prompt_field_name: str,
query_field_name: str,
temperature: float,
+ max_tokens: int,
backend_type: Literal["sglang", "vllm"],
purpose: Literal["responses", "logprobs"],
conn: multiprocessing.connection.Connection,
@@ -58,7 +60,7 @@
prompt_field_name=prompt_field_name, query_field_name=query_field_name
)
result_data = data.transform(
- transformation=process_batch,
+ transformation=partial(process_batch, temperature=temperature, max_tokens=max_tokens),
result_data_name=result_data_name,
forced_rewrite=(
Model.always_force_rewrite
@@ -666,7 +668,8 @@ def inference(
result_data_name: str,
backend: Literal["sglang", "vllm", "deepspeed", "serial"] = "sglang",
batch_size_multiplier_log2: int = 0,
- temperature=0.25,
+ temperature: float = 0.25,
+ max_tokens: int = 8192,
purpose: Literal["responses", "logprobs"] = "responses",
) -> Union[Data, List[Dict[str, str]]]:
"""Performance inference on a dataset (currently only instruction datasets are tested, with the same format as SFT datasets),
@@ -684,9 +687,12 @@ def inference(
:param batch_size_multiplier_log2: The log base 2 of the batch size multiplier
:type batch_size_multiplier_log2: int = 0
- :param temperature: The temperature parameter
+ :param temperature: The temperature parameter.
+ :type temperature: float = 0.25
+ :param max_tokens: The maximum number of tokens to generate. Ignored if purpose is "logprobs".
+ :type max_tokens: int = 8192
:param purpose: The purpose of the inference. It can be "responses" or "logprobs". If "logprobs", the log probability of the prompt itself (and the assistant response supplied in the `predict` field, if exists) is returned in the `logprob` field of the resulting dataset, without doing any completion. If "responses", the completion text is saved in the `predict` field of the resulting dataset.
:type purpose: Literal["responses", "logprobs"] = "responses"
@@ -753,7 +759,7 @@ def inference(

result = (
self.__inference_parallel_segregated(
- data, result_data_name, temperature, backend, purpose
+ data, result_data_name, temperature, max_tokens, backend, purpose
)
if backend in ["vllm", "sglang"]
else self.__inference_parallel_deepspeed(
@@ -786,7 +792,7 @@ def inference(
return result

def __inference_parallel_segregated(
- self, data: Data, result_data_name: str, temperature: float, backend_type: str, purpose: str
+ self, data: Data, result_data_name: str, temperature: float, max_tokens: int, backend_type: str, purpose: str
) -> Data:
"""sglang/vllm implementation for `inference()`, but performed in a separate process to free up GPU memory. This is the recommended implementation, due to its superior speed and robustness."""
data_path = data.data_path
@@ -814,6 +820,7 @@ def __inference_parallel_segregated(
prompt_field_name,
query_field_name,
temperature,
+ max_tokens,
backend_type,
purpose,
child_conn,
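
For context, a minimal usage sketch of the two new knobs. Only the inference() keyword arguments shown below come from the diff above; the Data/Model constructor calls, module path for Data, and the dataset and model names are hypothetical placeholders, not part of this commit.

from src.abstractions.model import Model
from src.abstractions.data import Data  # module path assumed, not shown in this commit

# Hypothetical model and dataset names -- substitute your own.
model = Model("llama-3-8b-instruct")
data = Data("my_instruction_set")

# New in this commit: temperature and max_tokens are forwarded to the
# sglang/vllm backend via functools.partial; max_tokens is ignored when
# purpose="logprobs".
results = model.inference(
    data,
    result_data_name="my_instruction_set_responses",
    backend="sglang",
    temperature=0.25,
    max_tokens=8192,
    purpose="responses",
)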
