[python] Update rolling batch params to output delta #2636

Merged: 1 commit, merged on Jan 30, 2025

Changes from all commits:
@@ -19,7 +19,7 @@
from lmi_dist.arg_utils import VllmEngineArgs
from lmi_dist.init_engine import engine_from_args
from lmi_dist.seq2seq_engine import Seq2SeqPreprocessor
from vllm import SamplingParams
from vllm.sampling_params import RequestOutputKind
from vllm.utils import AtomicCounter

from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params
@@ -140,6 +140,7 @@ def translate_lmi_dist_params(self, parameters: dict):

:return: The same parameters dict, but with lmi-dist style parameter names.
"""
parameters["output_kind"] = RequestOutputKind.DELTA
parameters["max_tokens"] = parameters.pop("max_new_tokens", 30)
# If `do_sample` is not provided, force temperature=0.0, i.e. greedy
# else set to user-provided value or default to 1.0
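The one-line addition of output_kind above is the heart of this PR. As context (not part of the diff), here is a minimal sketch of what that parameter controls, assuming the vllm 0.6.x-era API in which SamplingParams carries an output_kind field:

# Sketch only (not part of the diff); assumes vllm 0.6.x-era SamplingParams.
from vllm import SamplingParams
from vllm.sampling_params import RequestOutputKind

# What the added parameter controls, per engine step:
#   CUMULATIVE (the default) - RequestOutput carries all tokens generated so far
#   DELTA                    - RequestOutput carries only the tokens generated
#                              since the previous step
#   FINAL_ONLY               - a RequestOutput is emitted only when the request
#                              finishes
assert SamplingParams().output_kind == RequestOutputKind.CUMULATIVE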
@@ -74,7 +74,7 @@ def update_request_cache_with_output(request_cache: OrderedDict,
request_output.prompt_tokens_details.append(prompt_token)

# sets the details of all sequences
update_multiple_sequences(cache, request_output, vllm_request_output)
update_multiple_sequences(request_output, vllm_request_output)

# remove finished requests from cache
if vllm_request_output.finished:
@@ -89,49 +89,28 @@ def update_request_cache_with_output(request_cache: OrderedDict,
return request_cache


def update_multiple_sequences(cache, request_output, vllm_request_output):
def update_multiple_sequences(request_output, vllm_request_output):
for completion_output in vllm_request_output.outputs:

sequence_index = completion_output.index
if f"sequence_index_{sequence_index}" not in cache:
cache[f"sequence_index_{sequence_index}"] = {
"curr_length": 0,
"num_generated_tokens": 0
}

if sequence_index not in request_output.sequences:
request_output.sequences[sequence_index] = Sequence()

# set token of the sequence
# previous length of token ids generated
prev_len = cache[f"sequence_index_{sequence_index}"][
'num_generated_tokens']
# curr length of the token ids generated so far
cur_len = len(completion_output.token_ids)
cache[f"sequence_index_{sequence_index}"][
"num_generated_tokens"] = cur_len

# get the newly generated token_ids
new_token_ids = completion_output.token_ids[
prev_len:
cur_len] if prev_len < cur_len else completion_output.token_ids
new_token_ids = completion_output.token_ids

# get the newly generated token texts for speculative decoding
output_token_texts = []
if hasattr(completion_output, "output_token_texts"):
output_token_texts = completion_output.output_token_texts[
prev_len:
cur_len] if prev_len < cur_len else completion_output.output_token_texts
output_token_texts = completion_output.output_token_texts

top_tokens = []
token_texts = []
# calculate log probs and token_texts
if completion_output.logprobs:
new_logprobs_list = completion_output.logprobs[
prev_len:
cur_len] if prev_len < cur_len else completion_output.logprobs
new_logprobs = []
for token_id, logprobs in zip(new_token_ids, new_logprobs_list):
for token_id, logprobs in zip(new_token_ids,
completion_output.logprobs):
new_logprobs.append(logprobs[token_id].logprob)
decoded_token = logprobs[token_id].decoded_token if logprobs[
token_id].decoded_token else ""
@@ -141,13 +120,10 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
Token(id=token_id_key,
text=logprob.decoded_token,
log_prob=logprob.logprob))

elif new_token_ids:
# TODO: Test and remove this. logprobs is always set 1. This case should never happen.
new_logprobs = [None] * len(new_token_ids)
curr_length = cache[f"sequence_index_{sequence_index}"][
"curr_length"]
token_texts.append(completion_output.text[curr_length:])
token_texts.append(completion_output.text)

if not output_token_texts:
if len(token_texts) != len(new_token_ids):
@@ -186,9 +162,6 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
request_output.sequences[sequence_index].set_next_top_tokens(
top_tokens)

cache[f"sequence_index_{sequence_index}"]["curr_length"] = len(
completion_output.text)


def get_speculative_decoding_metrics_record(
completion_output: CompletionOutput,
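Taken together, the removals in this file simplify the sequence bookkeeping: once the engine emits deltas, the per-sequence cache of curr_length and num_generated_tokens has nothing left to track. The following self-contained sketch (the Dummy* classes are hypothetical stand-ins, not the real vllm CompletionOutput or djl_python Sequence) shows the accumulation pattern the new code relies on, using the two delta steps for sequence index 1 that appear in the updated test mocks below:

# Sketch only: why the prev_len / curr_length cache can be dropped for deltas.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class DummyCompletionOutput:   # stand-in for vllm CompletionOutput
    index: int
    text: str                  # only this step's text when output_kind=DELTA
    token_ids: List[int]       # only this step's token ids when output_kind=DELTA


@dataclass
class DummySequence:           # stand-in for djl_python request_io.Sequence
    token_ids: List[int] = field(default_factory=list)
    text: str = ""


def apply_delta(sequences: Dict[int, DummySequence],
                outputs: List[DummyCompletionOutput]) -> None:
    """Append each delta directly; no slicing against a cached prefix length."""
    for out in outputs:
        seq = sequences.setdefault(out.index, DummySequence())
        seq.token_ids.extend(out.token_ids)
        seq.text += out.text


# Two consecutive delta steps for sequence index 1, as in the updated mocks below.
sequences: Dict[int, DummySequence] = {}
apply_delta(sequences, [DummyCompletionOutput(index=1, text=' of', token_ids=[302])])
apply_delta(sequences, [DummyCompletionOutput(index=1, text=' the', token_ids=[272])])
assert sequences[1].text == ' of the'
assert sequences[1].token_ids == [302, 272]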
@@ -13,6 +13,7 @@
from collections import OrderedDict, defaultdict

from vllm import LLMEngine, SamplingParams
from vllm.sampling_params import RequestOutputKind
from vllm.utils import random_uuid, AtomicCounter

from djl_python.request import Request
@@ -78,6 +79,7 @@ def translate_vllm_params(self, parameters: dict) -> dict:

:return: The same parameters dict, but with VLLM style parameter names.
"""
parameters["output_kind"] = RequestOutputKind.DELTA
parameters["max_tokens"] = parameters.pop("max_new_tokens", 30)
if "seed" in parameters.keys():
parameters["seed"] = int(parameters["seed"])
108 changes: 10 additions & 98 deletions engines/python/setup/djl_python/tests/test_rb_vllm_utils.py
@@ -1,6 +1,5 @@
import sys
import unittest
import uuid
from dataclasses import dataclass
from typing import List, Optional, Dict, Union
from collections import OrderedDict
@@ -12,8 +11,8 @@
import djl_python
from djl_python.output_formatter import _json_output_formatter
from djl_python.request import Request
from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token, RequestInput
'''These Mock classes are in compliance with vllm RequestOutput version 0.5.3.post1'''
from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token
'''These Mock classes are in compliance with vllm RequestOutput version 0.6.3.post1'''


@dataclass
@@ -148,23 +147,10 @@ def __init__(
],
outputs=[
MockCompletionOutput(index=1,
text=' member of',
token_ids=[4292, 302],
text=' of',
token_ids=[302],
cumulative_logprob=-4.3041129764169455,
logprobs=[{
4292:
MockLogprob(logprob=-4.2740092277526855,
rank=4,
decoded_token=' member'),
2032:
MockLogprob(logprob=-3.0240092277526855,
rank=1,
decoded_token=' big'),
888:
MockLogprob(logprob=-4.4099884033203125,
rank=3,
decoded_token=' new'),
}, {
302:
MockLogprob(logprob=-0.03010374866425991,
rank=1,
@@ -181,27 +167,10 @@ def __init__(
finish_reason=None,
stop_reason=None),
MockCompletionOutput(index=0,
text=' consolidated',
token_ids=[22968, 601],
text='ated',
token_ids=[601],
cumulative_logprob=-13.402491569519043,
logprobs=[{
22968:
MockLogprob(logprob=-12.117759704589844,
rank=5308,
decoded_token=' consolid'),
2032:
MockLogprob(logprob=-3.0240092277526855,
rank=1,
decoded_token=' big'),
17372:
MockLogprob(logprob=-13.409988403320312,
rank=10489,
decoded_token=' crown'),
888:
MockLogprob(logprob=-4.4099884033203125,
rank=3,
decoded_token=' new'),
}, {
601:
MockLogprob(logprob=-1.2847318649291992,
rank=2,
@@ -235,37 +204,10 @@ def __init__(
],
outputs=[
MockCompletionOutput(index=1,
text=' member of the',
token_ids=[4292, 302,
272],
text=' the',
token_ids=[272],
cumulative_logprob=-4.815703457221389,
logprobs=[{
4292:
MockLogprob(logprob=-4.2740092277526855,
rank=4,
decoded_token=' member'),
2032:
MockLogprob(logprob=-3.0240092277526855,
rank=1,
decoded_token=' big'),
888:
MockLogprob(logprob=-4.4099884033203125,
rank=3,
decoded_token=' new'),
}, {
302:
MockLogprob(logprob=-0.03010374866425991,
rank=1,
decoded_token=' of'),
235290:
MockLogprob(logprob=-2.2026185989379883,
rank=1,
decoded_token='-'),
578:
MockLogprob(logprob=-2.2026185989379883,
rank=2,
decoded_token=' and')
}, {
272:
MockLogprob(logprob=-0.5115904808044434,
rank=1,
@@ -282,40 +224,10 @@ def __init__(
finish_reason='length',
stop_reason=None),
MockCompletionOutput(index=0,
text=' consolidated or',
token_ids=[22968, 601, 442],
text=' or',
token_ids=[442],
cumulative_logprob=-20.4010648727417,
logprobs=[{
22968:
MockLogprob(logprob=-12.117759704589844,
rank=5308,
decoded_token=' consolid'),
2032:
MockLogprob(logprob=-3.0240092277526855,
rank=1,
decoded_token=' big'),
17372:
MockLogprob(logprob=-13.409988403320312,
rank=10489,
decoded_token=' crown'),
888:
MockLogprob(logprob=-4.4099884033203125,
rank=3,
decoded_token=' new'),
}, {
601:
MockLogprob(logprob=-1.2847318649291992,
rank=2,
decoded_token='ated'),
1028:
MockLogprob(logprob=-0.909731924533844,
rank=1,
decoded_token='ator'),
1162:
MockLogprob(logprob=-0.8929234743118286,
rank=2,
decoded_token=' year')
}, {
442:
MockLogprob(logprob=-6.998573303222656,
rank=188,
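The updated mocks now model per-step deltas: each MockCompletionOutput carries only the token ids, text, and logprobs produced in that engine step. A small sketch (not part of the test file; MockLogprobDemo mirrors the MockLogprob shape above, logprob value rounded) of how such a delta step pairs token ids with their logprobs, as the new zip in the rolling-batch utils does:

# Sketch only: pairing a delta step's token ids with its logprobs.
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class MockLogprobDemo:   # same fields as the MockLogprob test double above
    logprob: float
    rank: int
    decoded_token: Optional[str]


# One DELTA step for sequence index 1: a single new token id 302 (" of").
new_token_ids: List[int] = [302]
step_logprobs: List[Dict[int, MockLogprobDemo]] = [{
    302: MockLogprobDemo(logprob=-0.0301, rank=1, decoded_token=' of'),
}]

# Both lists describe only this step, so they zip directly, with no
# prev_len / cur_len slicing against a cached prefix.
for token_id, logprobs in zip(new_token_ids, step_logprobs):
    chosen = logprobs[token_id]
    text = chosen.decoded_token if chosen.decoded_token else ""
    print(token_id, repr(text), chosen.logprob)  # prints: 302 ' of' -0.0301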