Commit
fix
magdyksaleh committed Jun 6, 2024
1 parent 652d497 commit 9c688c2
Showing 3 changed files with 18 additions and 7 deletions.
2 changes: 1 addition & 1 deletion server/lorax_server/models/flash_bert.py
@@ -240,4 +240,4 @@ def embed(self, batch: FlashEmbeddingBatch) -> Embedding:
         embedding = embedding.reshape(embedding.shape[0], -1)[:, : self.hidden_size]
 
         cpu_results = embedding.cpu().tolist()
-        return Embedding(values=cpu_results)
+        return cpu_results
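
With this change embed() hands back the raw nested list from embedding.cpu().tolist() instead of wrapping it in a single Embedding proto, leaving it to the caller to pair each row with its request. A minimal sketch of the shape involved, assuming a batch of two requests and a hidden size of 4 (the random tensor is only a stand-in for the pooled model output):

import torch

# Stand-in for the pooled, reshaped embedding tensor: one row per request.
embedding = torch.randn(2, 4)

# What embed() now returns: a plain list of per-request float lists,
# rather than a single Embedding protobuf message.
cpu_results = embedding.cpu().tolist()
assert len(cpu_results) == 2 and len(cpu_results[0]) == 4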
17 changes: 13 additions & 4 deletions server/lorax_server/models/types.py
@@ -57,7 +57,6 @@ def to_pb(self) -> generate_pb2.GeneratedText:
             seed=self.seed,
         )
 
-
 @dataclass
 class PrefillTokens:
     token_ids: List[int]
@@ -129,6 +128,7 @@ def to_pb(self) -> generate_pb2.Generation:
 
 @dataclass
 class FlashEmbeddingBatch(ABC):
+    request_ids: List[int]
     input_ids: torch.Tensor
     token_type_ids: torch.Tensor
     position_ids: torch.Tensor
@@ -162,18 +162,23 @@ def from_pb(
             truncation=True,
             max_length=max_truncation,
         )
+
         batch_tokenized_inputs = batch_inputs["input_ids"]
+        batch_token_type_ids = batch_inputs["token_type_ids"]
 
         all_input_ids = []
         position_ids = []
+        all_token_type_ids = []
         cu_seqlens = [0]
 
         max_s = 0
         cumulative_length = 0
 
-        for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)):
+        for i, (r, tokenized_input, token_type_ids) in enumerate(zip(pb.requests, batch_tokenized_inputs, batch_token_type_ids)):
             tokenized_input = tokenized_input[-r.truncate :]
+            token_type_ids = token_type_ids[-r.truncate :]
             all_input_ids.append(tokenized_input)
+            all_token_type_ids.append(token_type_ids)
 
             input_length = len(tokenized_input)
             max_s = max(max_s, input_length)
@@ -187,17 +192,21 @@
 
         if len(pb.requests) > 1:
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
+            final_token_type_ids = np.concatenate(all_token_type_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
         else:
             input_ids = all_input_ids[0]
+            final_token_type_ids = all_token_type_ids[0]
             position_ids = position_ids[0]
 
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+        final_token_type_ids = torch.tensor(final_token_type_ids, dtype=torch.int64, device=device)
         position_ids = position_ids.to(device)
 
         return FlashEmbeddingBatch(
+            request_ids=[r.id for r in pb.requests],
             input_ids=input_ids,
-            token_type_ids=torch.tensor(batch_inputs["token_type_ids"], dtype=torch.int32, device=device),
+            token_type_ids=final_token_type_ids,
             position_ids=position_ids,
             cu_seqlens=torch.tensor(cu_seqlens, dtype=torch.int32, device=device),
             max_s=max_s,
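
For context, a minimal, self-contained sketch of the flattened batch layout from_pb builds, assuming two already-tokenized requests with made-up token ids; the cu_seqlens accumulation mirrors the cumulative_length bookkeeping that is only partially visible in the hunk above:

import numpy as np
import torch

# Hypothetical per-request token ids and token type ids for two requests.
all_input_ids = [[101, 2023, 102], [101, 7592, 2088, 102]]
all_token_type_ids = [[0, 0, 0], [0, 0, 0, 0]]

cu_seqlens = [0]
max_s = 0
cumulative_length = 0
for ids in all_input_ids:
    input_length = len(ids)
    max_s = max(max_s, input_length)
    cumulative_length += input_length
    cu_seqlens.append(cumulative_length)

# With more than one request, the per-request lists are concatenated into
# flat 1-D tensors; cu_seqlens marks where each request's tokens start and end.
input_ids = torch.tensor(np.concatenate(all_input_ids, dtype=np.int64))
token_type_ids = torch.tensor(np.concatenate(all_token_type_ids, dtype=np.int64))
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)

print(input_ids.shape, token_type_ids.shape, cu_seqlens.tolist())
# torch.Size([7]) torch.Size([7]) [0, 3, 7]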
6 changes: 4 additions & 2 deletions server/lorax_server/server.py
@@ -96,7 +96,6 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context):
         )
 
     async def Embed(self, request: generate_pb2.EmbedRequest, context):
-        print("!!! EMBED")
         if not self.model.supports_embeddings:
             raise ValueError("Model does not support embeddings")
 
@@ -108,7 +107,10 @@ async def Embed(self, request: generate_pb2.EmbedRequest, context):
             self.model.device,
         )
         embeddings = self.model.embed(batch)
-        return generate_pb2.EmbedResponse(embeddings=embeddings)
+        embeddings_proto = []
+        for i, embedding in enumerate(embeddings):
+            embeddings_proto.append(generate_pb2.Embedding(request_id=batch.request_ids[i], values=embedding))
+        return generate_pb2.EmbedResponse(embeddings=embeddings_proto)
 
     async def Decode(self, request: generate_pb2.DecodeRequest, context):
         if len(request.batches) == 0:
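
The Embed handler now returns one Embedding message per request instead of passing the model output straight through. A minimal sketch of that pairing; the Embedding and EmbedResponse dataclasses below are hypothetical stand-ins for the generate_pb2 protobuf messages, used only to keep the example self-contained:

from dataclasses import dataclass
from typing import List

@dataclass
class Embedding:  # stand-in for generate_pb2.Embedding
    request_id: int
    values: List[float]

@dataclass
class EmbedResponse:  # stand-in for generate_pb2.EmbedResponse
    embeddings: List[Embedding]

def build_embed_response(request_ids: List[int], embeddings: List[List[float]]) -> EmbedResponse:
    # model.embed() yields one float list per request (cpu_results);
    # each row is tagged with the id of the request that produced it.
    return EmbedResponse(
        embeddings=[
            Embedding(request_id=rid, values=vals)
            for rid, vals in zip(request_ids, embeddings)
        ]
    )

response = build_embed_response([7, 8], [[0.1, 0.2], [0.3, 0.4]])
assert response.embeddings[1].request_id == 8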
