Fix decoder_input_details bug (#705)
ajtejankar authored Dec 10, 2024
1 parent 63c5eb3 commit 2af302d
Showing 6 changed files with 10 additions and 33 deletions.
11 changes: 1 addition & 10 deletions proto/generate.proto
@@ -214,15 +214,6 @@ message GeneratedText {
     optional uint64 seed = 5;
 }

-message PrefillTokens {
-    /// Prefill Token IDs
-    repeated uint32 ids = 1;
-    /// Prefill Logprobs
-    repeated float logprobs = 2;
-    /// Prefill tokens
-    repeated string texts = 3;
-}
-
 message AlternativeTokens {
     /// Alternative Token IDs
     repeated uint32 ids = 1;
@@ -249,7 +240,7 @@ message Generation {
     /// Request ID
     uint64 request_id = 1;
     /// Prefill tokens (optional)
-    PrefillTokens prefill_tokens = 2;
+    NextTokens prefill_tokens = 2;
     /// Next tokens
     NextTokens next_tokens = 3;
     /// Complete generated text
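With the dedicated PrefillTokens message removed, prefill and decode tokens now share a single wire type. A minimal sketch of the resulting shape using the generated Python stubs; the texts field on NextTokens is assumed to mirror the removed PrefillTokens message, and the values are illustrative:

```python
from lorax_server.pb import generate_pb2

# Prefill tokens are now serialized as a NextTokens message as well.
prefill = generate_pb2.NextTokens(
    ids=[1, 15043],                  # prompt token ids
    logprobs=[float("nan"), -0.31],  # the first prompt token has no logprob
    texts=["<s>", "Hello"],          # assumed field, mirroring PrefillTokens.texts
)

# Field 2 of Generation now accepts the same type as field 3.
generation = generate_pb2.Generation(request_id=0, prefill_tokens=prefill)
```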
3 changes: 1 addition & 2 deletions router/client/src/lib.rs
@@ -15,8 +15,7 @@ pub use pb::generate::v1::{
     input_chunk, AdapterParameters, AlternativeTokens, Batch, CachedBatch, ClassifyPredictionList,
     DownloadAdapterResponse, Embedding, Entity, EntityList, FinishReason, GeneratedText,
     Generation, Image, InputChunk, MajoritySignMethod, MergeStrategy, NextTokenChooserParameters,
-    NextTokens, PrefillTokens, PreloadedAdapter, Request, StoppingCriteriaParameters,
-    TokenizedInputs,
+    NextTokens, PreloadedAdapter, Request, StoppingCriteriaParameters, TokenizedInputs,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
4 changes: 2 additions & 2 deletions router/src/infer.rs
@@ -17,7 +17,7 @@ use itertools::izip;
 use itertools::multizip;
 use lorax_client::{
     Batch, CachedBatch, ClassifyPredictionList, ClientError, Embedding, GeneratedText, Generation,
-    PrefillTokens, PreloadedAdapter, ShardedClient,
+    NextTokens, PreloadedAdapter, ShardedClient,
 };
 use minijinja::{Environment, ErrorKind, Template};
 use minijinja_contrib::pycompat;
@@ -1661,7 +1661,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
 pub(crate) enum InferStreamResponse {
     // Optional first message
     Prefill {
-        tokens: Option<PrefillTokens>,
+        tokens: Option<NextTokens>,
         tokens_length: u32,
         prefill_time: Instant,
     },
3 changes: 1 addition & 2 deletions server/lorax_server/models/causal_lm.py
@@ -12,7 +12,6 @@
     GeneratedText,
     Generation,
     NextTokens,
-    PrefillTokens,
 )
 from lorax_server.pb import generate_pb2
 from lorax_server.utils import NextTokenChooser, Sampling, StoppingCriteria
@@ -697,7 +696,7 @@ def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Option
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
-                prefill_tokens = PrefillTokens(prefill_token_ids, prefill_logprobs, prefill_texts)
+                prefill_tokens = NextTokens(prefill_token_ids, prefill_logprobs, prefill_texts, None, None)
                 prefill_tokens_length = len(prefill_tokens.token_ids)
             else:
                 prefill_tokens = None
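The two trailing None arguments fill the last two fields of the NextTokens dataclass (see the types.py hunk further down). A hedged restatement of the new call with keyword arguments, using illustrative stand-ins for the values computed above; the texts field name is assumed from the removed PrefillTokens:

```python
from lorax_server.models.types import NextTokens

prefill_token_ids = [1, 15043]            # illustrative prompt ids
prefill_logprobs = [float("nan"), -0.31]  # illustrative logprobs
prefill_texts = ["<s>", "Hello"]          # illustrative decoded pieces

prefill_tokens = NextTokens(
    token_ids=prefill_token_ids,
    logprobs=prefill_logprobs,
    texts=prefill_texts,          # assumed field name
    is_special=None,              # per-token special flags, unused for prefill
    alternative_tokens=None,      # alternatives are not computed during prefill
)
prefill_tokens_length = len(prefill_tokens.token_ids)
```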
5 changes: 3 additions & 2 deletions server/lorax_server/models/seq2seq_lm.py
@@ -11,7 +11,6 @@
     GeneratedText,
     Generation,
     NextTokens,
-    PrefillTokens,
 )
 from lorax_server.pb import generate_pb2
 from lorax_server.utils import NextTokenChooser, Sampling, StoppingCriteria
@@ -662,10 +661,12 @@ def generate_token(self, batch: Seq2SeqLMBatch) -> Tuple[List[Generation], Optio

             # Prefill
             if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
-                prefill_tokens = PrefillTokens(
+                prefill_tokens = NextTokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],
                     [self.tokenizer.bos_token],
+                    None,
+                    None
                 )
                 prefill_tokens_length = len(prefill_tokens.token_ids)
             else:
17 changes: 2 additions & 15 deletions server/lorax_server/models/types.py
@@ -64,19 +64,6 @@ def to_pb(self) -> generate_pb2.GeneratedText:
         )


-@dataclass
-class PrefillTokens:
-    token_ids: List[int]
-    logprobs: List[float]
-    texts: List[str]
-
-    def to_pb(self) -> generate_pb2.PrefillTokens:
-        return generate_pb2.PrefillTokens(ids=self.token_ids, logprobs=self.logprobs, texts=self.texts)
-
-    def __len__(self):
-        return len(self.token_ids)
-
-
 @dataclass
 class AlternativeTokens:
     token_ids: List[int]
@@ -98,7 +85,7 @@ class NextTokens:
     is_special: List[bool]
     alternative_tokens: Optional[List[AlternativeTokens]]

-    def to_pb(self) -> generate_pb2.PrefillTokens:
+    def to_pb(self) -> generate_pb2.NextTokens:
         return generate_pb2.NextTokens(
             ids=self.token_ids,
             logprobs=self.logprobs,
@@ -118,7 +105,7 @@ def __len__(self):
 @dataclass
 class Generation:
     request_id: int
-    prefill_tokens: Optional[PrefillTokens]
+    prefill_tokens: Optional[NextTokens]
     prefill_tokens_length: int
     next_tokens: NextTokens
     generated_text: Optional[GeneratedText]
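Note that the to_pb body above was already building a generate_pb2.NextTokens; only the stale return annotation still named the deleted PrefillTokens type. A short round-trip sketch under that reading, with illustrative values and the texts field name assumed as before:

```python
from lorax_server.models.types import NextTokens

# The unified dataclass serializes to the wire type the router now expects.
prefill = NextTokens([1], [float("nan")], ["<s>"], None, None)
pb = prefill.to_pb()                      # a generate_pb2.NextTokens message
assert list(pb.ids) == prefill.token_ids  # to_pb maps token_ids onto ids
```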
