Merge pull request #24 from TianyiQ/main
feat(abstractions): continuous backend + remove default max_tokens limit
TianyiQ authored Dec 6, 2024
2 parents 5ff60d9 + 7fc8985 commit 53ecb79
Showing 3 changed files with 53 additions and 4 deletions.
2 changes: 2 additions & 0 deletions __init__.py
@@ -17,6 +17,7 @@

from src.abstractions.model import Model, fill_in_QA_template
from src.abstractions.data import Data, DataFileCollection
from src.abstractions.configs.templates_configs import GlobalState

__all__ = [
"run_benchmark",
@@ -35,4 +36,5 @@
"ExtrapolativeDPOExaminee",
"ExtrapolativeRLHFExaminee",
"fill_in_QA_template",
"GlobalState",
]
17 changes: 14 additions & 3 deletions src/abstractions/backends.py
@@ -33,6 +33,7 @@
import tqdm
from transformers import AutoTokenizer
import random
from src.abstractions.configs.templates_configs import GlobalState

# create output directories
os.makedirs(f"{root}/output/benchmark_results", exist_ok=True)
@@ -259,6 +260,11 @@ def start_inference_backend(
if purpose == "logprobs":
raise ValueError("VLLM backend does not support logprobs purpose.")

if GlobalState.continuous_backend:
warnings.warn(
"VLLM backend does not support continuous_backend=True. Ignoring this setting."
)

LLM, SamplingParams, destroy_model_parallel = import_from_vllm()

if template_type == "auto":
@@ -281,7 +287,7 @@
)

def vllm_process_batch(
sample_dicts: List[dict], temperature: float = 0.2, max_tokens: int = 1024
sample_dicts: List[dict], temperature: float = 0.2, max_tokens: int = None
) -> List[dict]:
nonlocal template_type
sampling_params = SamplingParams(
@@ -477,7 +483,7 @@ def vllm_free_gpu_memory():

@sgl.function
def get_response(
s, conversation: List, temperature: float = 0.2, max_tokens: int = 256, options: list = []
s, conversation: List, temperature: float = 0.2, max_tokens: int = None, options: list = []
) -> str:
nonlocal purpose

@@ -511,7 +517,7 @@ def get_response(
)

def sglang_process_batch(
sample_dicts: List[dict], temperature: float = 0.2, max_tokens: int = 256
sample_dicts: List[dict], temperature: float = 0.2, max_tokens: int = None
) -> List[dict]:
"""Process a batch of samples using the sglang backend.
When purpose is "logprobs", it will return the log probability of the prompt text itself, without generating any text. The probability will be stored in the "logprob" field of the output dictionary, with all other fields staying the same.
@@ -644,6 +650,11 @@ def sglang_free_gpu_memory():
os.kill(process.pid, signal.SIGINT)
os.kill(process.pid, signal.SIGKILL)

if GlobalState.continuous_backend:
# If continuous_backend=True, keep the backend alive for reuse until the exit of the context manager.
GlobalState._active_backend_destroyers.append(sglang_free_gpu_memory)
return backend, sglang_process_batch, (lambda: None)

return backend, sglang_process_batch, sglang_free_gpu_memory

raise ValueError(f"Backend type {backend_type} not recognized.")
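In the sglang branch above, when `GlobalState.continuous_backend` is set, the real teardown function is registered with `GlobalState` and the caller receives a no-op instead, so the same server survives across inference calls until the surrounding context manager exits. Below is a minimal, self-contained sketch of that deferral pattern; the names `start_backend`, `shutdown`, and `active_destroyers` are illustrative stand-ins, not ProgressGym's API:

```python
from typing import Callable, List

# Stand-ins for GlobalState's flag and destroyer registry (illustrative only).
continuous_backend = True
active_destroyers: List[Callable[[], None]] = []

def start_backend() -> Callable[[], None]:
    """Return a 'free' function, mimicking the third return value of start_inference_backend."""
    def shutdown() -> None:
        print("GPU memory freed")           # the real code kills the sglang server processes here
    if continuous_backend:
        active_destroyers.append(shutdown)  # defer the real teardown
        return lambda: None                 # the caller gets a no-op
    return shutdown

free = start_backend()
free()                                      # no-op: the backend stays alive for further batches
for destroy in active_destroyers:           # what GlobalState.__exit__ does when the with-block ends
    destroy()                               # prints "GPU memory freed"
active_destroyers.clear()
```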
38 changes: 37 additions & 1 deletion src/abstractions/configs/templates_configs.py
@@ -2,7 +2,43 @@
from string import Template
import json
import os
from typing import Dict, Any, Literal, Optional, List, Union
from typing import Dict, Any, Literal, Optional, List, Union, Callable


class GlobalState:

# Public variables
continuous_backend: bool = False

# Private variables
_active_backend_destroyers: List[Callable[[], None]] = []

def __init__(self, **kwargs: Any):
"""
Temporarily set global state variables of ProgressGym for the duration of a context manager block.
Example:
```
with GlobalState(continuous_backend=True):
# code that requires continuous backends (i.e., the backend does not die after each inference call, but is kept alive for reuse)
```
"""
self.kwargs = kwargs

def __enter__(self):
self.prior_state = {k: getattr(GlobalState, k) for k in self.kwargs.keys()}
for k, v in self.kwargs.items():
setattr(GlobalState, k, v)

def __exit__(self, type, value, traceback):
for k, v in self.prior_state.items():
setattr(GlobalState, k, v)

for destroy in GlobalState._active_backend_destroyers:
destroy()

GlobalState._active_backend_destroyers = []


bash_command_template = """PYTHONNOUSERSITE=1 MASTER_PORT=9902 conda run --no-capture-output -n %s deepspeed %s --master_port=9902 ./libs/llama_factory/src/train_bash.py \\
--deepspeed %s \\
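As a usage sketch of the class added above: entering the context manager overrides the named class attributes, and exiting restores their prior values and then runs (and clears) any backend destroyers registered during the block. The snippet assumes it is run from the ProgressGym repository root so that the import resolves:

```python
from src.abstractions.configs.templates_configs import GlobalState

print(GlobalState.continuous_backend)        # False, the module-level default
with GlobalState(continuous_backend=True):
    print(GlobalState.continuous_backend)    # True; backends started in this block stay alive for reuse
    # ... several inference batches could reuse the same long-lived backend here ...
print(GlobalState.continuous_backend)        # False again; __exit__ restored it and ran any registered destroyers
```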
