Skip to content

Commit

Permalink
Fixed warmup
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair committed Jun 1, 2024
1 parent 7898dba commit d3a84e9
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions server/lorax_server/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass
from typing import List, Optional

import numpy as np
import torch
from transformers import PreTrainedTokenizerBase

Expand Down Expand Up @@ -155,21 +156,24 @@ def from_pb(
batch_inputs.append(inputs)
max_truncation = max(max_truncation, r.truncate)

batch_tokenized_inputs = tokenizer(
batch_inputs = tokenizer(
batch_inputs,
return_token_type_ids=True,
truncation=True,
max_length=max_truncation,
)
batch_tokenized_inputs = batch_inputs["input_ids"]

max_s = 0
all_input_ids = []
position_ids = []
cu_seqlens = [0]

max_s = 0
cumulative_length = 0
cu_seqlens = [0]

for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)):
tokenized_input = tokenized_input[-r.truncate :]
all_input_ids.append(tokenized_input)

input_length = len(tokenized_input)
max_s = max(max_s, input_length)
Expand All @@ -180,11 +184,21 @@ def from_pb(
position_ids.append(request_position_ids)

cumulative_length += input_length

if len(pb.requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]

input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
position_ids = position_ids.to(device)

return FlashEmbeddingBatch(
input_ids=torch.tensor(batch_tokenized_inputs["input_ids"], dtype=torch.int32, device=device),
token_type_ids=torch.tensor(batch_tokenized_inputs["token_type_ids"], dtype=torch.int32, device=device),
position_ids=torch.tensor(position_ids, dtype=torch.int32, device=device),
input_ids=input_ids,
token_type_ids=torch.tensor(batch_inputs["token_type_ids"], dtype=torch.int32, device=device),
position_ids=position_ids,
cu_seqlens=torch.tensor(cu_seqlens, dtype=torch.int32, device=device),
max_s=max_s,
size=len(batch_inputs),
Expand Down

0 comments on commit d3a84e9

Please sign in to comment.