From 4967c815a2a7bb10571e48920d3a1beec75dd48e Mon Sep 17 00:00:00 2001 From: "Tianyi (Alex) Qiu" Date: Thu, 5 Dec 2024 22:46:47 -0800 Subject: [PATCH] feat(abstractions): continuous backend + remove default max_tokens limit --- __init__.py | 2 + src/abstractions/backends.py | 17 +++++++-- src/abstractions/configs/templates_configs.py | 38 ++++++++++++++++++- 3 files changed, 53 insertions(+), 4 deletions(-) diff --git a/__init__.py b/__init__.py index 8a293b5..c30afb7 100644 --- a/__init__.py +++ b/__init__.py @@ -17,6 +17,7 @@ from src.abstractions.model import Model, fill_in_QA_template from src.abstractions.data import Data, DataFileCollection +from src.abstractions.configs.templates_configs import GlobalState __all__ = [ "run_benchmark", @@ -35,4 +36,5 @@ "ExtrapolativeDPOExaminee", "ExtrapolativeRLHFExaminee", "fill_in_QA_template", + "GlobalState", ] diff --git a/src/abstractions/backends.py b/src/abstractions/backends.py index 94f37e6..d78daac 100644 --- a/src/abstractions/backends.py +++ b/src/abstractions/backends.py @@ -33,6 +33,7 @@ import tqdm from transformers import AutoTokenizer import random +from src.abstractions.configs.templates_configs import GlobalState # create output directories os.makedirs(f"{root}/output/benchmark_results", exist_ok=True) @@ -259,6 +260,11 @@ def start_inference_backend( if purpose == "logprobs": raise ValueError("VLLM backend does not support logprobs purpose.") + if GlobalState.continuous_backend: + warnings.warn( + "VLLM backend does not support continuous_backend=True. Ignoring this setting." + ) + LLM, SamplingParams, destroy_model_parallel = import_from_vllm() if template_type == "auto": @@ -281,7 +287,7 @@ def start_inference_backend( ) def vllm_process_batch( - sample_dicts: List[dict], temperature: float = 0.2, max_tokens: int = 1024 + sample_dicts: List[dict], temperature: float = 0.2, max_tokens: int = None ) -> List[dict]: nonlocal template_type sampling_params = SamplingParams( @@ -477,7 +483,7 @@ def vllm_free_gpu_memory(): @sgl.function def get_response( - s, conversation: List, temperature: float = 0.2, max_tokens: int = 256, options: list = [] + s, conversation: List, temperature: float = 0.2, max_tokens: int = None, options: list = [] ) -> str: nonlocal purpose @@ -511,7 +517,7 @@ def get_response( ) def sglang_process_batch( - sample_dicts: List[dict], temperature: float = 0.2, max_tokens: int = 256 + sample_dicts: List[dict], temperature: float = 0.2, max_tokens: int = None ) -> List[dict]: """Process a batch of samples using the sglang backend. When purpose is "logprobs", it will return the log probability of the prompt text itself, without generating any text. The probability will be stored in the "logprob" field of the output dictionary, with all other fields staying the same. @@ -644,6 +650,11 @@ def sglang_free_gpu_memory(): os.kill(process.pid, signal.SIGINT) os.kill(process.pid, signal.SIGKILL) + if GlobalState.continuous_backend: + # If continuous_backend=True, keep the backend alive for reuse until the exit of the context manager. + GlobalState.__active_backend_destroyers.append(sglang_free_gpu_memory) + return backend, sglang_process_batch, (lambda: None) + return backend, sglang_process_batch, sglang_free_gpu_memory raise ValueError(f"Backend type {backend_type} not recognized.") diff --git a/src/abstractions/configs/templates_configs.py b/src/abstractions/configs/templates_configs.py index 9d3ea99..6727350 100644 --- a/src/abstractions/configs/templates_configs.py +++ b/src/abstractions/configs/templates_configs.py @@ -2,7 +2,43 @@ from string import Template import json import os -from typing import Dict, Any, Literal, Optional, List, Union +from typing import Dict, Any, Literal, Optional, List, Union, Callable + + +class GlobalState: + + # Public variables + continuous_backend: bool = False + + # Private variables + __active_backend_destroyers: List[Callable[[], None]] = [] + + def __init__(self, **kwargs: Dict[str, Any]): + """ + Temporarily set global state variables of ProgressGym for the duration of a context manager block. + + Example: + ``` + with GlobalState(continuous_backend=True): + # code that requires continuous backends (i.e. backend does not die after each call to inference, and it kept alive for reuse) + ``` + """ + self.kwargs = kwargs + + def __enter__(self): + self.prior_state = {k: getattr(GlobalState, k) for k in self.kwargs.keys()} + for k, v in self.kwargs.items(): + setattr(GlobalState, k, v) + + def __exit__(self, type, value, traceback): + for k, v in self.prior_state.items(): + setattr(GlobalState, k, v) + + for destroy in GlobalState.__active_backend_destroyers: + destroy() + + GlobalState.__active_backend_destroyers = [] + bash_command_template = """PYTHONNOUSERSITE=1 MASTER_PORT=9902 conda run --no-capture-output -n %s deepspeed %s --master_port=9902 ./libs/llama_factory/src/train_bash.py \\ --deepspeed %s \\