Skip to content

Commit

Permalink
Fix shape
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair committed Jun 1, 2024
1 parent 5f0cac7 commit 7898dba
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions server/lorax_server/models/flash_bert.py
Original file line number Diff line number Diff line change
@@ -226,13 +226,14 @@ def forward(self, batch: FlashEmbeddingBatch):

     @tracer.start_as_current_span("embed")
     def embed(self, batch: FlashEmbeddingBatch) -> Embedding:
-        embedding = self.model.forward(
+        embedding: torch.Tensor = self.model.forward(
             input_ids=batch.input_ids,
             token_type_ids=batch.token_type_ids,
             position_ids=batch.position_ids,
             cu_seqlens=batch.cu_seqlens,
             max_s=batch.max_s,
         )
-        cpu_results = embedding.view(-1).tolist()
+        embedding = embedding.reshape(embedding.shape[0], -1)[:, : self.hidden_size]

-        return Embedding(values=cpu_results[: self.hidden_size])
+        cpu_results = embedding.cpu().tolist()
+        return Embedding(values=cpu_results)

0 comments on commit 7898dba

Please sign in to comment.