Skip to content

Commit

Permalink
add warmup method
Browse files Browse the repository at this point in the history
  • Loading branch information
magdyksaleh committed Jun 6, 2024
1 parent bfac3b8 commit 652d497
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions server/lorax_server/models/flash_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def __init__(
else:
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")

self.device = device
# self.device = device
self.device = "cpu"
self.dtype = dtype

tokenizer = AutoTokenizer.from_pretrained(model_id)
Expand Down Expand Up @@ -214,7 +215,10 @@ def supports_text_generation(self) -> bool:
return False

def warmup(self, batch: FlashEmbeddingBatch, max_new_tokens: int) -> int | None:
return 42 # no-op for now
# Note: This is meant to 1) preallocate the memory by doing a forward pass
# and then just returning the max seqlen since for embeddings we are never generating
_ = self.embed(batch)
return batch.max_s

def generate_token(self, batch: FlashEmbeddingBatch) -> None:
if not self.supports_text_generation:
Expand Down

0 comments on commit 652d497

Please sign in to comment.