Skip to content

Commit

Permalink
update batch from_pb sig
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffreyftang committed Feb 9, 2024
1 parent 262aef1 commit 31f1ca3
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 9 deletions.
6 changes: 3 additions & 3 deletions server/lorax_server/models/flash_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,14 @@ def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
tokenizers: TokenizerManager,
tokenizer_mgr: TokenizerManager,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
batch_inputs = []
max_truncation = 0
for r in pb.requests:
inputs = tokenizers.get_inputs(r, tokenizer)
inputs = tokenizer_mgr.get_inputs(r, tokenizer)
batch_inputs.append(inputs)
max_truncation = max(max_truncation, r.truncate)

Expand Down Expand Up @@ -240,7 +240,7 @@ def from_pb(
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device)

request_tokenizers = [
tokenizers.get_tokenizer(r.adapter_index, tokenizer)
tokenizer_mgr.get_tokenizer(r.adapter_index, tokenizer)
for r in pb.requests
]
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
Expand Down
6 changes: 3 additions & 3 deletions server/lorax_server/models/flash_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
tokenizers: TokenizerManager,
tokenizer_mgr: TokenizerManager,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
Expand All @@ -67,7 +67,7 @@ def from_pb(
batch_inputs = []
max_truncation = 0
for r in pb.requests:
inputs = tokenizers.get_inputs(r, tokenizer)
inputs = tokenizer_mgr.get_inputs(r, tokenizer)
batch_inputs.append(inputs)
max_truncation = max(max_truncation, r.truncate)

Expand Down Expand Up @@ -206,7 +206,7 @@ def from_pb(
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device)

request_tokenizers = [
tokenizers.get_tokenizer(r.adapter_index, tokenizer)
tokenizer_mgr.get_tokenizer(r.adapter_index, tokenizer)
for r in pb.requests
]
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
Expand Down
6 changes: 3 additions & 3 deletions server/lorax_server/models/flash_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
tokenizers: TokenizerManager,
tokenizer_mgr: TokenizerManager,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
Expand All @@ -74,7 +74,7 @@ def from_pb(
batch_inputs = []
max_truncation = 0
for r in pb.requests:
inputs = tokenizers.get_inputs(r, tokenizer)
inputs = tokenizer_mgr.get_inputs(r, tokenizer)
batch_inputs.append(inputs)
max_truncation = max(max_truncation, r.truncate)

Expand Down Expand Up @@ -213,7 +213,7 @@ def from_pb(
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device)

request_tokenizers = [
tokenizers.get_tokenizer(r.adapter_index, tokenizer)
tokenizer_mgr.get_tokenizer(r.adapter_index, tokenizer)
for r in pb.requests
]
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
Expand Down
2 changes: 2 additions & 0 deletions server/lorax_server/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from lorax_server.pb import generate_pb2
from lorax_server.pb.generate_pb2 import FinishReason
from lorax_server.utils.tokenizer import TokenizerManager


class Batch(ABC):
Expand All @@ -21,6 +22,7 @@ def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
tokenizer_mgr: TokenizerManager,
dtype: torch.dtype,
device: torch.device,
) -> "Batch":
Expand Down

0 comments on commit 31f1ca3

Please sign in to comment.