diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index b1420f49e..eafc0ded0 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -446,6 +446,10 @@ struct Args {
     /// The backend to use for the model. Can be `fa2` or `flashinfer`.
     #[clap(default_value = "fa2", long, env, value_enum)]
     backend: Backend,
+
+    /// The embedding dimension to use for the model.
+    #[clap(long, env)]
+    embedding_dim: Option<usize>,
 }
 
 #[derive(Debug)]
@@ -487,6 +491,7 @@ fn shard_manager(
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
     _shutdown_sender: mpsc::Sender<()>,
+    embedding_dim: Option<usize>,
 ) {
     // Enter shard-manager tracing span
     let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
@@ -585,6 +590,15 @@ fn shard_manager(
         shard_args.push(otlp_endpoint);
     }
 
+    // Embedding dimension
+    if let Some(embedding_dim) = embedding_dim {
+        shard_args.push("--embedding-dim".to_string());
+        shard_args.push(embedding_dim.to_string())
+    }
+
+    // Copy current process env
+    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
+
     // Torch Distributed Env vars
     envs.push(("RANK".into(), rank.to_string().into()));
     envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
@@ -910,6 +924,12 @@ fn download_convert_model(
         download_args.push(adapter_id.to_string());
     }
 
+    // Embedding dimension
+    if let Some(embedding_dim) = args.embedding_dim {
+        download_args.push("--embedding-dim".to_string());
+        download_args.push(embedding_dim.to_string())
+    }
+
     // Copy current process env
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
 
@@ -1054,6 +1074,7 @@ fn spawn_shards(
         let adapter_memory_fraction = args.adapter_memory_fraction;
         let prefix_caching = args.prefix_caching;
         let backend = args.backend;
+        let embedding_dim = args.embedding_dim;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -1087,6 +1108,7 @@ fn spawn_shards(
                 status_sender,
                 shutdown,
                 shutdown_sender,
+                embedding_dim,
             )
         });
     }
diff --git a/server/lorax_server/cli.py b/server/lorax_server/cli.py
index 1024a1694..168a37dfd 100644
--- a/server/lorax_server/cli.py
+++ b/server/lorax_server/cli.py
@@ -49,6 +49,7 @@ def serve(
     speculative_tokens: int = 0,
     preloaded_adapter_ids: Optional[List[str]] = typer.Option(None),
     preloaded_adapter_source: Optional[str] = None,
+    embedding_dim: Optional[int] = None,
 ):
     preloaded_adapter_ids = preloaded_adapter_ids or []
     preloaded_adapter_source = preloaded_adapter_source or adapter_source
@@ -101,6 +102,7 @@ def serve(
         speculative_tokens,
         preloaded_adapter_ids,
         preloaded_adapter_source,
+        embedding_dim,
     )
 
 
@@ -116,6 +118,7 @@ def download_weights(
     adapter_id: str = "",
     adapter_source: str = "hub",
     api_token: Optional[str] = None,
+    embedding_dim: Optional[int] = None,
 ):
     # Remove default handler
     logger.remove()
@@ -128,7 +131,7 @@ def download_weights(
         backtrace=True,
         diagnose=False,
     )
-    _download_weights(model_id, revision, extension, auto_convert, source, api_token)
+    _download_weights(model_id, revision, extension, auto_convert, source, api_token, embedding_dim)
     if adapter_id:
         _download_weights(adapter_id, revision, extension, auto_convert, adapter_source, api_token)
 
diff --git a/server/lorax_server/models/__init__.py b/server/lorax_server/models/__init__.py
index a68dfaa20..9fbe7e277 100644
--- a/server/lorax_server/models/__init__.py
+++ b/server/lorax_server/models/__init__.py
@@ -52,6 +52,7 @@ def get_model(
     trust_remote_code: bool,
     source: str,
     adapter_source: str,
+    embedding_dim: Optional[int] = None,
 ) -> Model:
     config_dict = None
     if source == "s3":
@@ -255,6 +256,7 @@ def get_model(
            compile=compile,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
+            embedding_dim=embedding_dim,
        )
 
    if model_type in ["phi-msft", "phi"]:
diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py
index 0ab6ec73c..cb65f60a0 100644
--- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -419,7 +419,6 @@ def forward(
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids, max_s, hidden_states.dtype)
 
-        residual = None
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
@@ -496,7 +495,64 @@ def forward(
             prefill_cache_indices,
             adapter_data,
         )
+
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
         return logits, speculative_logits
+
+
+class FlashQwen2ForEmbeddings(torch.nn.Module):
+    def __init__(self, config, weights):
+        super().__init__()
+        self.config = config
+
+        self.model = FlashQwen2Model(config, weights)
+        self.max_past = config.sliding_window
+        self.output_weight = weights.get_tensor("linear.weight")
+        self.output_bias = weights.get_tensor("linear.bias")
+        # To satisfy the parent class interface
+        # TODO: fix
+        self.lm_head = None
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+        adapter_data: AdapterBatchData,
+        prefill_cache_indices: Optional[torch.Tensor] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if prefill_cache_indices is not None:
+            # Slots also need to be sliced as it has the same size as the whole kv tensor
+            slots = slots[prefill_cache_indices]
+        elif self.max_past is not None:
+            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
+            # kernel requires the true values
+            max_s = min(self.max_past, max_s)
+            input_lengths = torch.clamp(input_lengths, max=self.max_past)
+
+        hidden_states = self.model(
+            input_ids,
+            position_ids,
+            cu_seqlen_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
+            max_s,
+            prefill_cache_indices,
+            adapter_data,
+        )
+        batch_size = hidden_states.shape[0] // max_s
+        hidden_states = hidden_states.reshape(batch_size, max_s, -1)
+        mean_hidden_states = hidden_states.mean(1)
+        embeddings = nn.functional.linear(mean_hidden_states, self.output_weight, self.output_bias)
+        return embeddings, None
+
diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index f137994cb..c193d8ecf 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -125,6 +125,13 @@ def to_pb(self) -> generate_pb2.CachedBatch:
             max_tokens=self.num_blocks * BLOCK_SIZE,
         )
 
+    @classmethod
+    def to_pb_embed(self, batch, embeddings) -> generate_pb2.EmbedResponse:
+        embeddings_proto = []
+        for i, embedding in enumerate(embeddings):
+            embeddings_proto.append(generate_pb2.Embedding(request_id=batch.requests[i].id, values=embedding))
+        return generate_pb2.EmbedResponse(embeddings=embeddings_proto)
+
     @classmethod
     def from_pb(
         cls,
@@ -399,6 +406,10 @@ def from_pb(
            ),
            prefill_cache_indices=prefill_cache_indices if SLIDING_WINDOW is not None else None,
        )
+
+    @classmethod
+    def from_pb_embed(self, pb: generate_pb2.EmbedRequest, tokenizer: PreTrainedTokenizerBase, tokenizers: TokenizerManager, processor, config, dtype, device) -> "FlashCausalLMBatch":
+        return self.from_pb(pb, tokenizer, tokenizers, None, None, dtype, device)
 
    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
@@ -873,7 +884,7 @@ def adapter_memory_size(self) -> int:
         total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
         return ADAPTER_MEMORY_FRACTION * total_gpu_memory
 
-    def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int):
+    def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model: bool = False):
         max_total_tokens = batch.max_seqlen + max_new_tokens + get_speculative_tokens()
 
         torch.cuda.empty_cache()
@@ -887,17 +898,18 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int):
                 self.device,
             )
 
-            with warmup_mode():
-                logger.info("Warming up to max_new_tokens: {}", max_new_tokens)
-                with tqdm(total=max_new_tokens, desc="Warmup to max_total_tokens") as pbar:
-                    for _ in range(max_new_tokens):
-                        cur_seqlen = batch.max_seqlen
-                        _, batch = self.generate_token(batch, is_warmup=True)
-                        new_seqlen = batch.max_seqlen
-                        pbar.update(new_seqlen - cur_seqlen)
-                        if new_seqlen >= max_total_tokens - get_speculative_tokens():
-                            break
-                logger.info("Finished generating warmup tokens")
+            if not embedding_model:
+                with warmup_mode():
+                    logger.info("Warming up to max_new_tokens: {}", max_new_tokens)
+                    with tqdm(total=max_new_tokens, desc="Warmup to max_total_tokens") as pbar:
+                        for _ in range(max_new_tokens):
+                            cur_seqlen = batch.max_seqlen
+                            _, batch = self.generate_token(batch, is_warmup=True)
+                            new_seqlen = batch.max_seqlen
+                            pbar.update(new_seqlen - cur_seqlen)
+                            if new_seqlen >= max_total_tokens - get_speculative_tokens():
+                                break
+                    logger.info("Finished generating warmup tokens")
         except RuntimeError as e:
             if "CUDA out of memory" in str(e) or isinstance(e, torch.cuda.OutOfMemoryError):
                 raise RuntimeError(
diff --git a/server/lorax_server/models/flash_qwen2.py b/server/lorax_server/models/flash_qwen2.py
index 37158d0db..7b524a401 100644
--- a/server/lorax_server/models/flash_qwen2.py
+++ b/server/lorax_server/models/flash_qwen2.py
@@ -6,6 +6,7 @@
 from transformers import AutoTokenizer
 from transformers.models.qwen2 import Qwen2Config
 
+from lorax_server.adapters import AdapterBatchData
 from lorax_server.models import FlashCausalLM
 from lorax_server.models.custom_modeling.flash_qwen2_modeling import (
     ATTN_K_PROJ,
@@ -16,6 +17,7 @@
     MLP_GATE_PROJ,
     MLP_UP_PROJ,
     FlashQwen2ForCausalLM,
+    FlashQwen2ForEmbeddings,
 )
 from lorax_server.utils import (
     Weights,
@@ -51,6 +53,7 @@ def __init__(
         compile: bool = False,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        embedding_dim: Optional[int] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -74,7 +77,7 @@ def __init__(
 
         torch.distributed.barrier(group=self.process_group)
 
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors", embedding_dim=embedding_dim)
         weights = Weights(
             filenames,
             device,
@@ -83,7 +86,19 @@ def __init__(
         )
         weights._set_config(model_id, config)
 
-        model = FlashQwen2ForCausalLM(config, weights)
+        self._supports_embeddings = embedding_dim is not None
+
+        if not weights.has_tensor("lm_head.weight") and not self._supports_embeddings:
+            raise ValueError(
+                "Model does not have lm head so it is presumed to be for embeddings. "
+                "No embedding_dim was provided so we cannot load the model. "
+                "Please pass in an embedding_dim to the model."
+            )
+
+        if self._supports_embeddings:
+            model = FlashQwen2ForEmbeddings(config, weights)
+        else:
+            model = FlashQwen2ForCausalLM(config, weights)
 
         self.config = config
 
@@ -111,6 +126,14 @@ def __init__(
     def supports_adapter_loading(self) -> bool:
         return True
 
+    @property
+    def supports_embeddings(self) -> bool:
+        return self._supports_embeddings
+
+    @property
+    def supports_text_generation(self) -> bool:
+        return not self._supports_embeddings
+
     def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
         layer_weights = {}
 
@@ -156,3 +179,15 @@ def get_num_layers_for_type(self, layer_type: str) -> int:
 
     def is_row_parallel(self, layer_type: str) -> bool:
         return layer_type in ROW_PARALLEL
+
+    def embed(self, batch) -> torch.Tensor:
+        adapter_meta = batch.adapter_meta
+        prefill = False
+        adapter_data = AdapterBatchData.from_meta(
+            adapter_meta, self.layer_to_adapter_weights, prefill, batch.prefill_head_indices
+        )
+        embedding, _ = self.forward(batch, adapter_data=adapter_data)
+        return embedding.cpu().tolist()
+
+    def warmup(self, batch, max_new_tokens):
+        return super().warmup(batch, max_new_tokens, embedding_model=self._supports_embeddings)
diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py
index 1479f2875..78a564a0a 100644
--- a/server/lorax_server/server.py
+++ b/server/lorax_server/server.py
@@ -127,7 +127,7 @@ async def Embed(self, request: generate_pb2.EmbedRequest, context):
         if not self.model.supports_embeddings:
             raise ValueError("Model does not support embeddings")
 
-        batch = self.model.batch_type.from_pb(
+        batch = self.model.batch_type.from_pb_embed(
             request.batch,
             self.model.tokenizer,
             self.model.tokenizers,
@@ -249,6 +249,7 @@ def serve(
     speculative_tokens: int,
     preloaded_adapter_ids: List[str],
     preloaded_adapter_source: str,
+    embedding_dim: Optional[int] = None,
 ):
     async def serve_inner(
         model_id: str,
@@ -263,6 +264,7 @@ async def serve_inner(
         speculative_tokens: int,
         preloaded_adapter_ids: List[str],
         preloaded_adapter_source: str,
+        embedding_dim: Optional[int] = None,
     ):
         unix_socket_template = "unix://{}-{}"
         if sharded:
@@ -284,6 +286,7 @@ async def serve_inner(
                 trust_remote_code,
                 source,
                 adapter_source,
+                embedding_dim,
             )
         except Exception:
             logger.exception("Error when initializing model")
@@ -421,5 +424,6 @@ def load_adapter(adapter_info: generate_pb2.PreloadedAdapter) -> bool:
             speculative_tokens,
             preloaded_adapter_ids,
             preloaded_adapter_source,
+            embedding_dim,
         )
     )
diff --git a/server/lorax_server/utils/sources/__init__.py b/server/lorax_server/utils/sources/__init__.py
index ba9af397a..47cc566a5 100644
--- a/server/lorax_server/utils/sources/__init__.py
+++ b/server/lorax_server/utils/sources/__init__.py
@@ -76,9 +76,10 @@ def get_model_source(
     revision: Optional[str] = None,
     extension: str = ".safetensors",
     api_token: Optional[str] = None,
+    embedding_dim: Optional[int] = None,
 ):
     if source == HUB:
-        return HubModelSource(model_id, revision, extension, api_token)
+        return HubModelSource(model_id, revision, extension, api_token, embedding_dim)
     elif source == S3:
         return S3ModelSource(model_id, revision, extension)
     elif source == LOCAL:
diff --git a/server/lorax_server/utils/sources/hub.py b/server/lorax_server/utils/sources/hub.py
index 0b6f6bb6f..fb6b14e5a 100644
--- a/server/lorax_server/utils/sources/hub.py
+++ b/server/lorax_server/utils/sources/hub.py
@@ -28,19 +28,41 @@ def weight_hub_files(
     revision: Optional[str] = None,
     extension: str = ".safetensors",
     api_token: Optional[str] = None,
+    embedding_dim: Optional[int] = None,
 ) -> List[str]:
     """Get the weights filenames on the hub"""
     api = get_hub_api(token=api_token)
     info = api.model_info(model_id, revision=revision)
-    filenames = [
-        s.rfilename
-        for s in info.siblings
-        if s.rfilename.endswith(extension)
-        and len(s.rfilename.split("/")) == 1
-        and "arguments" not in s.rfilename
-        and "args" not in s.rfilename
-        and "training" not in s.rfilename
-    ]
+    if embedding_dim is not None:
+        filenames = [
+            s.rfilename
+            for s in info.siblings
+            if s.rfilename.endswith(extension)
+            and len(s.rfilename.split("/")) <= 2
+            and "arguments" not in s.rfilename
+            and "args" not in s.rfilename
+            and "training" not in s.rfilename
+        ]
+        # Only include the layer for the correct embedding dim
+        embedding_tensor_file = f"2_Dense_{embedding_dim}/model.safetensors"
+        if embedding_tensor_file not in filenames:
+            raise ValueError(f"No embedding tensor file found for embedding dim {embedding_dim}")
+        filenames = [
+            filename
+            for filename in filenames
+            if len(filename.split("/")) < 2
+            or filename == embedding_tensor_file
+        ]
+    else:
+        filenames = [
+            s.rfilename
+            for s in info.siblings
+            if s.rfilename.endswith(extension)
+            and len(s.rfilename.split("/")) == 1
+            and "arguments" not in s.rfilename
+            and "args" not in s.rfilename
+            and "training" not in s.rfilename
+        ]
 
     if not filenames:
         raise EntryNotFoundError(
@@ -56,6 +78,7 @@ def weight_files(
     revision: Optional[str] = None,
     extension: str = ".safetensors",
     api_token: Optional[str] = None,
+    embedding_dim: Optional[int] = None,
 ) -> List[Path]:
     """Get the local files"""
     # Local model
@@ -66,7 +89,7 @@ def weight_files(
         return local_files
 
     try:
-        filenames = weight_hub_files(model_id, revision, extension, api_token)
+        filenames = weight_hub_files(model_id, revision, extension, api_token, embedding_dim)
     except EntryNotFoundError as e:
         if extension != ".safetensors":
             raise e
@@ -160,11 +183,13 @@ def __init__(
         revision: Optional[str] = None,
         extension: str = ".safetensors",
         api_token: Optional[str] = None,
+        embedding_dim: Optional[int] = None,
     ):
         self.model_id = model_id
         self.revision = revision
         self.extension = extension
         self._api_token = api_token
+        self.embedding_dim = embedding_dim
 
     @property
     def api_token(self) -> Optional[str]:
@@ -172,11 +197,11 @@ def api_token(self) -> Optional[str]:
 
     def remote_weight_files(self, extension: str = None):
         extension = extension or self.extension
-        return weight_hub_files(self.model_id, self.revision, extension, self.api_token)
+        return weight_hub_files(self.model_id, self.revision, extension, self.api_token, self.embedding_dim)
 
     def weight_files(self, extension=None):
         extension = extension or self.extension
-        return weight_files(self.model_id, self.revision, extension, self.api_token)
+        return weight_files(self.model_id, self.revision, extension, self.api_token, self.embedding_dim)
 
     def download_weights(self, filenames):
         return download_weights(filenames, self.model_id, self.revision, self.api_token)
diff --git a/server/lorax_server/utils/weights.py b/server/lorax_server/utils/weights.py
index a3185bd6c..cbdb33401 100644
--- a/server/lorax_server/utils/weights.py
+++ b/server/lorax_server/utils/weights.py
@@ -461,12 +461,13 @@ def download_weights(
     auto_convert: bool = True,
     source: str = "hub",
     api_token: Optional[str] = None,
+    embedding_dim: Optional[int] = None,
 ):
     # Import here after the logger is added to log potential import exceptions
     from lorax_server import utils
     from lorax_server.utils import sources
 
-    model_source = sources.get_model_source(source, model_id, revision, extension, api_token)
+    model_source = sources.get_model_source(source, model_id, revision, extension, api_token, embedding_dim)
 
     # Test if files were already download
     try:
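
Not part of the diff: a minimal standalone sketch of the Hub file-selection rule that `weight_hub_files` applies above when `embedding_dim` is set. The filenames and the `select_weight_files` helper are hypothetical and only illustrate the filtering; the `2_Dense_{embedding_dim}/model.safetensors` path follows the sentence-transformers style layout the change assumes.

```python
# Illustration only: mirrors the filtering logic added to weight_hub_files()
# when embedding_dim is provided. Filenames below are made up.
from typing import List, Optional


def select_weight_files(repo_files: List[str], embedding_dim: Optional[int]) -> List[str]:
    # Keep root-level safetensors files (and, for embeddings, one nested dense layer).
    max_depth = 2 if embedding_dim is not None else 1
    candidates = [
        f for f in repo_files
        if f.endswith(".safetensors")
        and len(f.split("/")) <= max_depth
        and not any(skip in f for skip in ("arguments", "args", "training"))
    ]
    if embedding_dim is None:
        return candidates

    # Only the dense head matching the requested dimension is allowed through.
    dense_file = f"2_Dense_{embedding_dim}/model.safetensors"
    if dense_file not in candidates:
        raise ValueError(f"No embedding tensor file found for embedding dim {embedding_dim}")
    return [f for f in candidates if len(f.split("/")) < 2 or f == dense_file]


if __name__ == "__main__":
    files = ["model.safetensors", "2_Dense_768/model.safetensors", "2_Dense_1024/model.safetensors"]
    print(select_weight_files(files, embedding_dim=1024))
    # ['model.safetensors', '2_Dense_1024/model.safetensors']
```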