fix max tokens (and optimize imports) (#37)
* accept max_tokens

* settle on min/max_tokens to match openai

* fix type errors

* fix descriptions
technillogue authored May 2, 2024
1 parent 384a39e commit 7c96b42
Showing 2 changed files with 75 additions and 54 deletions.
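
A short orientation before the diff: the substantive change renames max_new_tokens/min_new_tokens to max_tokens/min_tokens (matching OpenAI's naming) while keeping the old names as deprecated aliases that may not be combined with the new ones. The following is a standalone sketch of that resolution logic, pulled out of the added code for readability; resolve_token_limits is an illustrative helper name, not a function in this repository.

from typing import Optional


def resolve_token_limits(
    max_tokens: Optional[int] = 512,       # new name; 512 is the new default
    min_tokens: Optional[int] = None,      # new name
    max_new_tokens: Optional[int] = None,  # legacy name, kept for compatibility
    min_new_tokens: Optional[int] = None,  # legacy name, kept for compatibility
) -> "tuple[Optional[int], Optional[int]]":
    # Honor the legacy names only when the new ones were left untouched.
    if max_new_tokens:
        if max_tokens == 512 or max_tokens is None:
            max_tokens = max_new_tokens
        else:
            raise Exception(
                f"Can't set both max_tokens ({max_tokens}) "
                f"and max_new_tokens ({max_new_tokens})"
            )
    if min_new_tokens:
        if min_tokens is None:
            min_tokens = min_new_tokens
        else:
            raise Exception(
                f"Can't set both min_tokens ({min_tokens}) "
                f"and min_new_tokens ({min_new_tokens})"
            )
    return max_tokens, min_tokens


# A caller passing only the legacy name still gets the expected limit:
assert resolve_token_limits(max_new_tokens=256) == (256, None)
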
119 changes: 71 additions & 48 deletions predict.py
@@ -1,32 +1,29 @@
# Prediction interface for Cog ⚙️
import asyncio
import json
import os
import subprocess
import httpx
from cog import BasePredictor, ConcatenateIterator, Input
import time
import json
import multiprocessing as mp
from typing import Optional

from sse import receive_sse
from triton_config_generator import generate_configs, load_yaml_config

import pytriton.utils.distribution

TRITONSERVER_DIST_DIR = (
pytriton.utils.distribution.get_root_module_path() / "tritonserver"
)
TRITONSERVER_BACKEND_DIR = os.getenv(
"TRITONSERVER_BACKEND_DIR", str(TRITONSERVER_DIST_DIR / "backends")
)
from cog import BasePredictor, ConcatenateIterator, Input

import numpy as np
if mp.current_process().name != "MainProcess":
import httpx
import numpy as np
import pytriton.utils.distribution
from transformers import AutoTokenizer

from utils import (
maybe_download_tarball_with_pget,
StreamingTokenStopSequenceHandler,
)
from sse import receive_sse
from triton_config_generator import generate_configs, load_yaml_config
from utils import StreamingTokenStopSequenceHandler, maybe_download_tarball_with_pget

from transformers import AutoTokenizer
TRITONSERVER_DIST_DIR = (
pytriton.utils.distribution.get_root_module_path() / "tritonserver"
)
TRITONSERVER_BACKEND_DIR = os.getenv(
"TRITONSERVER_BACKEND_DIR", str(TRITONSERVER_DIST_DIR / "backends")
)


class Predictor(BasePredictor):
@@ -67,8 +64,9 @@ async def setup(self, weights: str = "") -> None:
self.trt_llm_config = config = json.load(f)
print(f"tensorrt_llm config: {config}")

if os.getenv("MAX_SEQUENCE_LENGTH", None):
self.max_sequence_length = int(os.getenv("MAX_SEQUENCE_LENGTH"))
max_seqlen_env = os.getenv("MAX_SEQUENCE_LENGTH", None)
if max_seqlen_env:
self.max_sequence_length = int(max_seqlen_env)
else:
try:
self.max_sequence_length = self.trt_llm_config["pretrained_config"][
@@ -91,7 +89,7 @@ async def setup(self, weights: str = "") -> None:
return
raise Exception(f"Couldn't start Triton (exit code {self.proc.poll()})")

async def start_triton(self) -> None:
async def start_triton(self) -> bool:
# # launch triton server
# # python3 scripts/launch_triton_server.py --world_size=1 --model_repo=/src/tensorrtllm_backend/triton_model
world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -138,12 +136,12 @@ async def predict(
description="System prompt to send to the model. This is prepended to the prompt and helps guide system behavior.",
default=os.getenv("SYSTEM_PROMPT", ""),
),
max_new_tokens: int = Input(
description="Maximum number of tokens to generate. A word is generally 2-3 tokens",
max_tokens: int = Input(
description="Maximum number of tokens to generate. A word is generally 2-3 tokens.",
ge=1,
default=128,
default=512,
),
min_new_tokens: int = Input(
min_tokens: int = Input(
description="Minimum number of tokens to generate. To disable, set to -1. A word is generally 2-3 tokens.",
ge=-1,
default=None,
@@ -155,13 +153,13 @@ async def predict(
default=0.7,
),
top_p: float = Input(
description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens",
description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens.",
ge=0.0,
le=1.0,
default=0.95,
),
top_k: int = Input(
description="When decoding text, samples from the top k most likely tokens; lower to ignore less likely tokens",
description="When decoding text, samples from the top k most likely tokens; lower to ignore less likely tokens.",
ge=-1,
default=0,
),
@@ -180,14 +178,24 @@ async def predict(
default=0.0,
),
seed: int = Input(
description="Random seed. Leave blank to randomize the seed",
description="Random seed. Leave blank to randomize the seed.",
default=None,
),
prompt_template: str = Input(
description="Template for formatting the prompt. Can be an arbitrary string, but must contain the substring `{prompt}`.",
default=os.getenv("PROMPT_TEMPLATE", "{prompt}"),
),
log_performance_metrics: bool = False,
max_new_tokens: int = Input(
description="This parameter has been renamed to max_tokens. max_new_tokens only exists for backwards compatibility purposes. We recommend you use max_tokens instead. Both may not be specified.",
ge=1,
default=None,
),
min_new_tokens: int = Input(
description="This parameter has been renamed to min_tokens. min_new_tokens only exists for backwards compatibility purposes. We recommend you use min_tokens instead. Both may not be specified.",
ge=-1,
default=None,
),
) -> ConcatenateIterator:
if not self.model_exists:
self.log(
@@ -201,10 +209,23 @@ async def predict(
if formatted_prompt == "":
raise Exception("A prompt is required, but your formatted prompt is blank")

# compatibility with older language models
if max_new_tokens:
# 512 is the default
if max_tokens == 512 or max_tokens is None:
max_tokens = max_new_tokens
else:
raise Exception(f"Can't set both max_tokens ({max_tokens}) and max_new_tokens ({max_new_tokens})")
if min_new_tokens:
if min_tokens is None:
min_tokens = min_new_tokens
else:
raise Exception(f"Can't set both min_tokens ({min_tokens}) and min_new_tokens ({min_new_tokens})")

args = self._process_args(
prompt=formatted_prompt,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
max_tokens=max_tokens,
min_tokens=min_tokens,
top_k=top_k,
top_p=top_p,
temperature=temperature,
@@ -229,6 +250,7 @@ async def predict(
start_time = time.time()
n_tokens = 0
tokens = np.array([], dtype=np.int32)
first_token_time = None

async with req as resp:
async for event in receive_sse(resp):
@@ -270,13 +292,14 @@ async def predict(
if self.log_performance_metrics or log_performance_metrics:
latency = end_time - start_time
actual_tps = n_tokens / latency
time_to_first_token = first_token_time - start_time
self.log(f"Tokens processed: {n_tokens}\n")
self.log(f"Serverside tokens per second: {round(actual_tps, 2)}\n")
self.log(f"Serverside execution time: {round(latency, 2)} seconds\n")
self.log(
f"Serverside time to first token: {round(time_to_first_token, 2)} seconds\n"
)
if first_token_time:
time_to_first_token = first_token_time - start_time
self.log(
f"Serverside time to first token: {round(time_to_first_token, 2)} seconds\n"
)

self.log(f"Random seed used: `{args['random_seed']}`\n")
self.log(
@@ -293,27 +316,27 @@ def _process_args(
def _process_args(
self,
prompt: str,
max_new_tokens: int = 250,
min_new_tokens: int = None,
max_tokens: int = 250,
min_tokens: Optional[int] = None,
top_k: int = 0,
top_p: float = 0.0,
temperature: float = 1.0,
length_penalty: float = 1.0,
presence_penalty: float = 0.0,
stop_words: str = None,
seed: int = None,
stop_words: Optional[str] = None,
seed: Optional[int] = None,
stream: bool = True,
):
stop_words_list = stop_words.split(",") if stop_words else []
min_new_tokens = 0 if min_new_tokens is None else min_new_tokens
min_tokens = 0 if min_tokens is None else min_tokens

pad_id = self.pad_id
end_id = self.end_id

if top_k < 0:
top_k = 0
if min_new_tokens < 0:
min_new_tokens = 0
if min_tokens < 0:
min_tokens = 0

if not seed:
seed = int(np.random.randint(0, 100000))
@@ -322,13 +345,13 @@ def _process_args(

if self.max_sequence_length:
token_budget = self.max_sequence_length - n_prompt_tokens
max_new_tokens = min(max_new_tokens, token_budget)
min_new_tokens = min(min_new_tokens, token_budget)
max_tokens = min(max_tokens, token_budget)
min_tokens = min(min_tokens, token_budget)

args = {
"text_input": prompt,
"max_tokens": max_new_tokens,
"min_length": min_new_tokens,
"max_tokens": max_tokens,
"min_length": min_tokens,
"top_k": top_k,
"temperature": temperature,
"top_p": top_p,
10 changes: 4 additions & 6 deletions utils.py
@@ -1,14 +1,12 @@
import os
import subprocess
import requests
import time
import shutil
import subprocess
import sys
from pathlib import Path
import shutil
import time
import typing as tp
from collections import deque
from pathlib import Path

import requests

def maybe_download_tarball_with_pget(
url: str,

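Finally, a hypothetical client-side view of the renamed inputs, assuming a model served by this predictor and the Replicate Python client; the model identifier below is a placeholder, not a real model, and is only meant to show the new parameter names in use.

import replicate  # assumes the Replicate Python client is installed

# Placeholder model identifier for illustration only.
output = replicate.run(
    "your-org/your-trt-llm-model",
    input={
        "prompt": "Write a haiku about GPUs.",
        "max_tokens": 256,  # previously max_new_tokens
        "min_tokens": 16,   # previously min_new_tokens
        "temperature": 0.7,
        "top_p": 0.95,
    },
)

# The predictor streams tokens (ConcatenateIterator), so the output is iterable.
for token in output:
    print(token, end="")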