From 652d4971a03a6332e8f86864e2c16b57b04237a4 Mon Sep 17 00:00:00 2001
From: Magdy Saleh
Date: Thu, 6 Jun 2024 17:32:58 +0000
Subject: [PATCH] add warmup method

---
 server/lorax_server/models/flash_bert.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/server/lorax_server/models/flash_bert.py b/server/lorax_server/models/flash_bert.py
index 81699c38f..d148b3471 100644
--- a/server/lorax_server/models/flash_bert.py
+++ b/server/lorax_server/models/flash_bert.py
@@ -172,7 +172,8 @@ def __init__(
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")

-        self.device = device
+        # self.device = device
+        self.device = "cpu"
         self.dtype = dtype

         tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -214,7 +215,10 @@ def supports_text_generation(self) -> bool:
         return False

     def warmup(self, batch: FlashEmbeddingBatch, max_new_tokens: int) -> int | None:
-        return 42  # no-op for now
+        # Note: this is meant to (1) preallocate memory by doing a forward pass,
+        # then (2) return the max seqlen, since embeddings never generate tokens.
+        _ = self.embed(batch)
+        return batch.max_s

     def generate_token(self, batch: FlashEmbeddingBatch) -> None:
         if not self.supports_text_generation:
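
Below is a minimal, self-contained sketch (not part of the patch) of the warmup contract the diff above introduces. StubEmbeddingBatch and StubEmbeddingModel are illustrative stand-ins, not the real lorax_server classes: they mirror only the two things the patched warmup relies on, an embed() forward pass that preallocates memory and a max_s attribute to report back.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class StubEmbeddingBatch:
    # Hypothetical stand-in for FlashEmbeddingBatch, with only the
    # fields the patched warmup touches.
    input_ids: List[List[int]]
    max_s: int  # max sequence length across the batch


class StubEmbeddingModel:
    def embed(self, batch: StubEmbeddingBatch) -> List[List[float]]:
        # The real FlashBert runs a flash-attention forward pass here;
        # running it once during warmup is what preallocates the memory.
        return [[0.0, 0.0] for _ in batch.input_ids]

    def warmup(self, batch: StubEmbeddingBatch, max_new_tokens: int) -> Optional[int]:
        # Same shape as the patched method: one forward pass, then
        # return the max seqlen, since no tokens are ever generated.
        _ = self.embed(batch)
        return batch.max_s


if __name__ == "__main__":
    batch = StubEmbeddingBatch(input_ids=[[101, 7592, 102], [101, 102]], max_s=3)
    print(StubEmbeddingModel().warmup(batch, max_new_tokens=0))  # -> 3

Returning the batch's actual max seqlen after a real forward pass, rather than the old fixed sentinel (return 42), gives the caller a meaningful number to size its buffers from.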