revert tokenizer_mgr rename and add handling for causallm
jeffreyftang committed Feb 9, 2024
1 parent 31f1ca3 commit 29573ce
Showing 7 changed files with 90 additions and 72 deletions.
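
In short, this commit renames the `tokenizer_mgr` parameter back to `tokenizers` across the batch `from_pb` signatures and extends the non-flash `CausalLM` path so that each request's `NextTokenChooser` receives a tokenizer resolved for that request's adapter, matching what the flash models already do through `HeterogeneousNextTokenChooser`. The sketch below illustrates the lookup-with-fallback behavior the diffs rely on; only `get_tokenizer(adapter_index, default)` and `get_inputs(request, default)` appear in the diff, so the storage and the `add_tokenizer` hook shown here are assumptions for illustration.

```python
from typing import Dict

from transformers import PreTrainedTokenizerBase


class TokenizerManagerSketch:
    """Illustrative stand-in for lorax_server.utils.tokenizer.TokenizerManager.

    Resolves an adapter-specific tokenizer, falling back to the base model
    tokenizer when the adapter has not registered one.
    """

    def __init__(self) -> None:
        # Assumed storage: adapter index -> adapter-specific tokenizer.
        self._tokenizers: Dict[int, PreTrainedTokenizerBase] = {}

    def add_tokenizer(self, adapter_index: int, tokenizer: PreTrainedTokenizerBase) -> None:
        # Assumed registration hook; not part of this diff.
        self._tokenizers[adapter_index] = tokenizer

    def get_tokenizer(
        self, adapter_index: int, default: PreTrainedTokenizerBase
    ) -> PreTrainedTokenizerBase:
        # The call shape used throughout the diffs below.
        return self._tokenizers.get(adapter_index, default)
```
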
98 changes: 50 additions & 48 deletions server/lorax_server/models/causal_lm.py
@@ -1,12 +1,9 @@
from collections import defaultdict
import json
import torch
import inspect

from dataclasses import dataclass
from typing import Optional, Tuple, List, Type, Dict

import torch
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from lorax_server.models import Model
from lorax_server.models.types import (
@@ -17,9 +14,9 @@
)
from lorax_server.pb import generate_pb2
from lorax_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from lorax_server.utils.tokenizer import TokenizerManager
from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights
from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata
from lorax_server.utils.segments import SegmentConcatBuilder, find_segments
from lorax_server.utils.tokenizer import TokenizerManager

tracer = trace.get_tracer(__name__)

@@ -95,7 +92,10 @@ def from_pb(
requests_idx_mapping[r.id] = i
req_inputs = tokenizers.get_inputs(r, tokenizer)
inputs.append(req_inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters,
device,
tokenizers.get_tokenizer(adapter_indices_list[i],
tokenizer)))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
@@ -212,7 +212,7 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
total_remaining_decode_tokens += remaining_decode_tokens
new_padding_right_offset = max(
@@ -226,12 +226,14 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
position_ids = self.position_ids[keep_indices]
adapter_indices = self.adapter_meta.adapter_indices[keep_indices]
self.attention_mask = self.attention_mask[
keep_indices,
-(self.padding_right_offset + max_input_length) : (
self.attention_mask.shape[1] - self.padding_right_offset
)
+ new_padding_right_offset,
]
keep_indices,
-(self.padding_right_offset + max_input_length): (
self.attention_mask.shape[
1] -
self.padding_right_offset
)
+ new_padding_right_offset,
]

# Ensure that past_key_values tensors can be updated in-place
if type(self.past_key_values[0]) == tuple:
@@ -371,17 +373,17 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
# and to remove unused allocated space
left_offset = max_input_length - batch.max_input_length
batch_left_offset = (
batch.attention_mask.shape[1]
- batch.max_input_length
- batch.padding_right_offset
batch.attention_mask.shape[1]
- batch.max_input_length
- batch.padding_right_offset
)
attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
start_index:end_index,
left_offset:-padding_right_offset,
] = batch.attention_mask[
:,
batch_left_offset : -batch.padding_right_offset,
]
batch_left_offset: -batch.padding_right_offset,
]

# Create empty tensor
# position_ids is always of shape [batch_size, 1]
@@ -405,7 +407,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":

# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length - batch.max_input_length
max_input_length - batch.max_input_length
) * len(batch)

start_index = end_index
@@ -447,12 +449,12 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
past_seq_len = batch.max_input_length - 1
if batch.keys_head_dim_last:
padded_past_keys[
start_index:end_index, :, -past_seq_len:, :
start_index:end_index, :, -past_seq_len:, :
] = past_keys[:, :, -past_seq_len:, :]
else:
# BLOOM case
padded_past_keys[
start_index:end_index, :, :, -past_seq_len:
start_index:end_index, :, :, -past_seq_len:
] = past_keys[:, :, :, -past_seq_len:]
del past_keys

@@ -472,7 +474,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
# We slice the past values to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
padded_past_values[
start_index:end_index, :, -past_seq_len:, :
start_index:end_index, :, -past_seq_len:, :
] = past_values[:, :, -past_seq_len:, :]
del past_values

@@ -525,7 +527,7 @@ def __init__(
):
if compile:
raise ValueError("`--compile` is not supported with CausalLM")

if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
@@ -580,7 +582,7 @@ def __init__(
@property
def batch_type(self) -> Type[CausalLMBatch]:
return CausalLMBatch

@property
def has_adapter_data(self) -> bool:
return False
@@ -592,19 +594,19 @@ def decode(self, generated_ids: List[int]) -> str:

def forward(
self,
input_ids,
attention_mask,
position_ids,
past_key_values: Optional = None,
input_ids,
attention_mask,
position_ids,
past_key_values: Optional = None,
adapter_data: Optional[AdapterBatchData] = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
kwargs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"input_ids" : input_ids,
"attention_mask" : attention_mask,
"past_key_values": past_key_values,
"use_cache": True,
"return_dict": True,
"use_cache" : True,
"return_dict" : True,
}
if self.has_position_ids:
kwargs["position_ids"] = position_ids
@@ -653,14 +655,14 @@ def generate_token(

# For each member of the batch
for i, (
request,
input_length,
prefix_offset,
read_offset,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
request,
input_length,
prefix_offset,
read_offset,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
@@ -693,7 +695,7 @@ def generate_token(
if stop:
# Decode generated tokens
output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :, 0]
all_input_ids[-stopping_criteria.current_tokens:, 0]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
@@ -713,8 +715,8 @@ def generate_token(
prefill_logprobs = [float("nan")] + torch.log_softmax(
logits, -1
).gather(1, all_input_ids[1:]).squeeze(1)[
-new_input_length:-1
].tolist()
-new_input_length:-1
].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
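
The functional change in this file is the third argument to `NextTokenChooser.from_pb`: each request's chooser is now built with the tokenizer resolved for that request's adapter via `tokenizers.get_tokenizer(adapter_indices_list[i], tokenizer)`, rather than from the request parameters and device alone. A minimal sketch of that call shape follows; the stub chooser and its stored fields are hypothetical scaffolding, and only the loop and the `get_tokenizer` call mirror the diff above.

```python
from typing import Any, List, Optional


class StubNextTokenChooser:
    """Hypothetical stand-in for NextTokenChooser; it records what from_pb received."""

    def __init__(self, parameters: Any, device: str, tokenizer: Optional[Any]):
        self.parameters = parameters
        self.device = device
        self.tokenizer = tokenizer  # now available for per-request schema-guided decoding

    @classmethod
    def from_pb(cls, parameters: Any, device: str, tokenizer: Optional[Any] = None):
        return cls(parameters, device, tokenizer)


def build_next_token_choosers(requests, adapter_indices_list, tokenizers, tokenizer, device="cpu"):
    # Mirrors the loop in CausalLMBatch.from_pb after this commit: resolve the
    # adapter-specific tokenizer for request i, then hand it to the chooser.
    choosers: List[StubNextTokenChooser] = []
    for i, r in enumerate(requests):
        choosers.append(
            StubNextTokenChooser.from_pb(
                r.parameters,
                device,
                tokenizers.get_tokenizer(adapter_indices_list[i], tokenizer),
            )
        )
    return choosers
```
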
6 changes: 3 additions & 3 deletions server/lorax_server/models/flash_causal_lm.py
@@ -111,14 +111,14 @@ def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
tokenizer_mgr: TokenizerManager,
tokenizers: TokenizerManager,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
batch_inputs = []
max_truncation = 0
for r in pb.requests:
inputs = tokenizer_mgr.get_inputs(r, tokenizer)
inputs = tokenizers.get_inputs(r, tokenizer)
batch_inputs.append(inputs)
max_truncation = max(max_truncation, r.truncate)

@@ -240,7 +240,7 @@ def from_pb(
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device)

request_tokenizers = [
tokenizer_mgr.get_tokenizer(r.adapter_index, tokenizer)
tokenizers.get_tokenizer(r.adapter_index, tokenizer)
for r in pb.requests
]
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
6 changes: 3 additions & 3 deletions server/lorax_server/models/flash_mistral.py
@@ -57,7 +57,7 @@ def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
tokenizer_mgr: TokenizerManager,
tokenizers: TokenizerManager,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
@@ -67,7 +67,7 @@ def from_pb(
batch_inputs = []
max_truncation = 0
for r in pb.requests:
inputs = tokenizer_mgr.get_inputs(r, tokenizer)
inputs = tokenizers.get_inputs(r, tokenizer)
batch_inputs.append(inputs)
max_truncation = max(max_truncation, r.truncate)

@@ -206,7 +206,7 @@ def from_pb(
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device)

request_tokenizers = [
tokenizer_mgr.get_tokenizer(r.adapter_index, tokenizer)
tokenizers.get_tokenizer(r.adapter_index, tokenizer)
for r in pb.requests
]
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
6 changes: 3 additions & 3 deletions server/lorax_server/models/flash_mixtral.py
@@ -64,7 +64,7 @@ def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
tokenizer_mgr: TokenizerManager,
tokenizers: TokenizerManager,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
@@ -74,7 +74,7 @@ def from_pb(
batch_inputs = []
max_truncation = 0
for r in pb.requests:
inputs = tokenizer_mgr.get_inputs(r, tokenizer)
inputs = tokenizers.get_inputs(r, tokenizer)
batch_inputs.append(inputs)
max_truncation = max(max_truncation, r.truncate)

@@ -213,7 +213,7 @@ def from_pb(
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device)

request_tokenizers = [
tokenizer_mgr.get_tokenizer(r.adapter_index, tokenizer)
tokenizers.get_tokenizer(r.adapter_index, tokenizer)
for r in pb.requests
]
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
2 changes: 1 addition & 1 deletion server/lorax_server/models/types.py
@@ -22,7 +22,7 @@ def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
tokenizer_mgr: TokenizerManager,
tokenizers: TokenizerManager,
dtype: torch.dtype,
device: torch.device,
) -> "Batch":
8 changes: 4 additions & 4 deletions server/lorax_server/utils/logits_process.py
@@ -432,16 +432,16 @@ class HeterogeneousSchemaLogitsProcessor(LogitsProcessor):

def __init__(
self,
schemas: List[Optional[str]] = None,
tokenizers: List[Optional[PreTrainedTokenizerBase]] = None,
schemas: Optional[List[Optional[str]]] = None,
tokenizers: Optional[List[Optional[PreTrainedTokenizerBase]]] = None,
):
if schemas is None:
schemas = []
if tokenizers is None:
tokenizers = []

self.sequence_processors = [
None if schema is None else OutlinesLogitsProcessor(schema, tokenizer)
None if schema is None or tokenizer is None else OutlinesLogitsProcessor(schema, tokenizer)
for schema, tokenizer in zip(schemas, tokenizers)
]

@@ -469,7 +469,7 @@ def concatenate(


# Source: https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
class OutlinesLogitsProcessor:
class OutlinesLogitsProcessor(LogitsProcessor):
def __init__(self, schema: str, tokenizer: PreTrainedTokenizerBase):
"""Compile the FSM that drives the regex-guided generation.
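
The change here makes the schema processor tolerant of missing tokenizers: a per-request `OutlinesLogitsProcessor` is only constructed when both a schema and a tokenizer are present, so requests lacking either simply skip schema-guided decoding instead of failing (and `OutlinesLogitsProcessor` is now declared as a `LogitsProcessor` subclass). A runnable sketch of that guard, with a stub standing in for `OutlinesLogitsProcessor`:

```python
from typing import List, Optional


class _StubOutlinesProcessor:
    """Hypothetical stand-in for OutlinesLogitsProcessor (schema + tokenizer pair)."""

    def __init__(self, schema: str, tokenizer: object):
        self.schema = schema
        self.tokenizer = tokenizer


def build_sequence_processors(
    schemas: Optional[List[Optional[str]]] = None,
    tokenizers: Optional[List[Optional[object]]] = None,
) -> List[Optional[_StubOutlinesProcessor]]:
    # Mirrors HeterogeneousSchemaLogitsProcessor.__init__ after this commit:
    # a slot stays None unless both a schema and a tokenizer are available.
    if schemas is None:
        schemas = []
    if tokenizers is None:
        tokenizers = []
    return [
        None if schema is None or tokenizer is None else _StubOutlinesProcessor(schema, tokenizer)
        for schema, tokenizer in zip(schemas, tokenizers)
    ]


# Requests with no schema, or no resolvable tokenizer, are left unguided.
processors = build_sequence_processors(
    schemas=['{"type": "object"}', None, '{"type": "string"}'],
    tokenizers=[object(), object(), None],
)
print([p is not None for p in processors])  # [True, False, False]
```
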