diff --git a/server/lorax_server/models/flash_bert.py b/server/lorax_server/models/flash_bert.py
index d148b3471..beaa7f904 100644
--- a/server/lorax_server/models/flash_bert.py
+++ b/server/lorax_server/models/flash_bert.py
@@ -240,4 +240,4 @@ def embed(self, batch: FlashEmbeddingBatch) -> Embedding:
         embedding = embedding.reshape(embedding.shape[0], -1)[:, : self.hidden_size]
 
         cpu_results = embedding.cpu().tolist()
-        return Embedding(values=cpu_results)
+        return cpu_results
diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py
index 84ca35c6c..18accb3ff 100644
--- a/server/lorax_server/models/types.py
+++ b/server/lorax_server/models/types.py
@@ -57,7 +57,6 @@ def to_pb(self) -> generate_pb2.GeneratedText:
             seed=self.seed,
         )
 
-
 @dataclass
 class PrefillTokens:
     token_ids: List[int]
@@ -129,6 +128,7 @@ def to_pb(self) -> generate_pb2.Generation:
 
 @dataclass
 class FlashEmbeddingBatch(ABC):
+    request_ids: List[int]
     input_ids: torch.Tensor
     token_type_ids: torch.Tensor
     position_ids: torch.Tensor
@@ -162,18 +162,23 @@ def from_pb(
             truncation=True,
             max_length=max_truncation,
         )
+        batch_tokenized_inputs = batch_inputs["input_ids"]
+        batch_token_type_ids = batch_inputs["token_type_ids"]
 
         all_input_ids = []
         position_ids = []
+        all_token_type_ids = []
         cu_seqlens = [0]
 
         max_s = 0
         cumulative_length = 0
-
-        for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)):
+
+        for i, (r, tokenized_input, token_type_ids) in enumerate(zip(pb.requests, batch_tokenized_inputs, batch_token_type_ids)):
             tokenized_input = tokenized_input[-r.truncate :]
+            token_type_ids = token_type_ids[-r.truncate :]
             all_input_ids.append(tokenized_input)
+            all_token_type_ids.append(token_type_ids)
 
             input_length = len(tokenized_input)
             max_s = max(max_s, input_length)
 
@@ -187,17 +192,21 @@
 
         if len(pb.requests) > 1:
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
+            final_token_type_ids = np.concatenate(all_token_type_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
         else:
             input_ids = all_input_ids[0]
+            final_token_type_ids = all_token_type_ids[0]
             position_ids = position_ids[0]
 
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+        final_token_type_ids = torch.tensor(final_token_type_ids, dtype=torch.int64, device=device)
         position_ids = position_ids.to(device)
 
         return FlashEmbeddingBatch(
+            request_ids=[r.id for r in pb.requests],
             input_ids=input_ids,
-            token_type_ids=torch.tensor(batch_inputs["token_type_ids"], dtype=torch.int32, device=device),
+            token_type_ids=final_token_type_ids,
             position_ids=position_ids,
             cu_seqlens=torch.tensor(cu_seqlens, dtype=torch.int32, device=device),
             max_s=max_s,
diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py
index f98d7a9f8..63cc39c77 100644
--- a/server/lorax_server/server.py
+++ b/server/lorax_server/server.py
@@ -96,7 +96,6 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context):
         )
 
     async def Embed(self, request: generate_pb2.EmbedRequest, context):
-        print("!!! EMBED")
         if not self.model.supports_embeddings:
             raise ValueError("Model does not support embeddings")
 
@@ -108,7 +107,10 @@ async def Embed(self, request: generate_pb2.EmbedRequest, context):
             self.model.device,
         )
         embeddings = self.model.embed(batch)
-        return generate_pb2.EmbedResponse(embeddings=embeddings)
+        embeddings_proto = []
+        for i, embedding in enumerate(embeddings):
+            embeddings_proto.append(generate_pb2.Embedding(request_id=batch.request_ids[i], values=embedding))
+        return generate_pb2.EmbedResponse(embeddings=embeddings_proto)
 
     async def Decode(self, request: generate_pb2.DecodeRequest, context):
         if len(request.batches) == 0: