diff --git a/router/src/config.rs b/router/src/config.rs
index 0be0b9b83..cf1eb7b78 100644
--- a/router/src/config.rs
+++ b/router/src/config.rs
@@ -175,6 +175,8 @@ pub enum Config {
     T5,
     Bert,
     Distilbert,
+    #[serde(rename = "xlm-roberta")]
+    XlmRoberta,
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 74788bda0..5626302c3 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -489,34 +489,31 @@ impl Infer {
             err
         })?;
 
-        // TODO(travis): support adapters
-        // let (adapter_source, adapter_parameters) = extract_adapter_params(
-        //     request.parameters.adapter_id.clone(),
-        //     request.parameters.adapter_source.clone(),
-        //     request.parameters.adapter_parameters.clone(),
-        // );
-
-        // let adapter_idx;
-        // {
-        //     // TODO(travis): can optimize concurrency here using RWLock
-        //     let mut adapter_to_index = self.adapter_to_index.lock().await;
-        //     let adapter_key = adapter_parameters.clone();
-        //     if adapter_to_index.contains_key(&adapter_key) {
-        //         adapter_idx = *adapter_to_index.get(&adapter_key).unwrap();
-        //     } else {
-        //         adapter_idx = adapter_to_index.len() as u32;
-        //         adapter_to_index.insert(adapter_key, adapter_idx);
-        //     }
-        // }
-
+        let (adapter_source, adapter_parameters) = extract_adapter_params(
+            request.parameters.adapter_id.clone(),
+            request.parameters.adapter_source.clone(),
+            request.parameters.adapter_parameters.clone(),
+        );
+
+        let adapter_idx;
+        {
+            // TODO(travis): can optimize concurrency here using RWLock
+            let mut adapter_to_index = self.adapter_to_index.lock().await;
+            let adapter_key = adapter_parameters.clone();
+            if adapter_to_index.contains_key(&adapter_key) {
+                adapter_idx = *adapter_to_index.get(&adapter_key).unwrap();
+            } else {
+                adapter_idx = adapter_to_index.len() as u32;
+                adapter_to_index.insert(adapter_key, adapter_idx);
+            }
+        }
+
+        let api_token = request.parameters.api_token.clone();
         let adapter = Adapter::new(
-            AdapterParameters {
-                adapter_ids: vec![BASE_MODEL_ADAPTER_ID.to_string()],
-                ..Default::default()
-            },
-            "hub".to_string(),
-            0,
-            None,
+            adapter_parameters,
+            adapter_source.unwrap(),
+            adapter_idx,
+            api_token,
         );
 
         // TODO(travis): robust validation
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 40253d5df..3eaeeca3f 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -921,9 +921,44 @@ pub(crate) enum CompletionFinishReason {
     ToolCalls,
 }
 
+#[derive(Clone, Debug, Deserialize, ToSchema)]
+pub(crate) struct EmbedParameters {
+    #[serde(default)]
+    #[schema(
+        nullable = true,
+        default = "null",
+        example = "arnavgrg/codealpaca-qlora"
+    )]
+    pub adapter_id: Option<String>,
+    #[serde(default)]
+    #[schema(nullable = true, default = "null", example = "hub")]
+    pub adapter_source: Option<String>,
+    #[serde(rename(deserialize = "merged_adapters"))]
+    #[schema(nullable = true, default = "null")]
+    pub adapter_parameters: Option<AdapterParameters>,
+    #[serde(default)]
+    #[schema(
+        nullable = true,
+        default = "null",
+        example = ""
+    )]
+    pub api_token: Option<String>,
+}
+
+fn default_embed_parameters() -> EmbedParameters {
+    EmbedParameters {
+        adapter_id: None,
+        adapter_source: None,
+        adapter_parameters: None,
+        api_token: None,
+    }
+}
+
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 struct EmbedRequest {
     inputs: String,
+    #[serde(default = "default_embed_parameters")]
+    pub parameters: EmbedParameters,
 }
 
 #[derive(Serialize, ToSchema)]
diff --git a/router/src/server.rs b/router/src/server.rs
index e2806dd87..3370114e9 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -11,11 +11,12 @@ use crate::{
     ChatCompletionResponseChoice, ChatCompletionStreamResponse, ChatCompletionStreamResponseChoice,
     ChatMessage, ClassifyRequest, CompatGenerateRequest, CompletionFinishReason, CompletionRequest,
     CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice,
-    CompletionStreamResponse, Details, EmbedRequest, EmbedResponse, Entity, ErrorResponse,
-    FinishReason, FunctionDefinition, GenerateParameters, GenerateRequest, GenerateResponse,
-    HubModelInfo, Infer, Info, JsonSchema, LogProbs, Message, OpenAiResponseFormat, PrefillToken,
-    ResponseFormat, ResponseFormatType, SimpleToken, StreamDetails, StreamResponse, Token,
-    TokenizeRequest, TokenizeResponse, Tool, ToolCall, ToolChoice, UsageInfo, Validation,
+    CompletionStreamResponse, Details, EmbedParameters, EmbedRequest, EmbedResponse, Entity,
+    ErrorResponse, FinishReason, FunctionDefinition, GenerateParameters, GenerateRequest,
+    GenerateResponse, HubModelInfo, Infer, Info, JsonSchema, LogProbs, Message,
+    OpenAiResponseFormat, PrefillToken, ResponseFormat, ResponseFormatType, SimpleToken,
+    StreamDetails, StreamResponse, Token, TokenizeRequest, TokenizeResponse, Tool, ToolCall,
+    ToolChoice, UsageInfo, Validation,
 };
 use crate::{json, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig};
 use axum::extract::Extension;
@@ -572,6 +573,12 @@ async fn health(
     if health.shard_info().supports_embeddings {
         let embed_request = EmbedRequest {
             inputs: "San Francisco".to_string(),
+            parameters: EmbedParameters {
+                adapter_id: None,
+                adapter_source: None,
+                adapter_parameters: None,
+                api_token: None,
+            },
         };
         match infer.embed(embed_request).await {
             Ok(_) => {}
diff --git a/server/lorax_server/models/__init__.py b/server/lorax_server/models/__init__.py
index d26e3d187..ea9197fa1 100644
--- a/server/lorax_server/models/__init__.py
+++ b/server/lorax_server/models/__init__.py
@@ -120,6 +120,18 @@ def get_model(
     if config_dict["architectures"][0] == "DistilBertForTokenClassification":
         return FlashDistilBert(model_id, revision=revision, dtype=dtype, classifcation_head=True)
 
+    if model_type == "xlm-roberta":
+        from lorax_server.models.flash_roberta import FlashXlmRoberta
+
+        return FlashXlmRoberta(
+            model_id,
+            adapter_id,
+            adapter_source,
+            revision=revision,
+            dtype=dtype,
+            merge_adapter_weights=merge_adapter_weights,
+        )
+
     flash_causal_lm_kwargs = dict(
         quantize=quantize,
         compile=compile,
diff --git a/server/lorax_server/models/custom_modeling/flash_roberta_modeling.py b/server/lorax_server/models/custom_modeling/flash_roberta_modeling.py
new file mode 100644
index 000000000..e92b0cae1
--- /dev/null
+++ b/server/lorax_server/models/custom_modeling/flash_roberta_modeling.py
@@ -0,0 +1,145 @@
+import torch
+from torch import nn
+from transformers.activations import ACT2FN
+
+from lorax_server.utils.flash_attn import attention
+from lorax_server.utils.layers import (
+    FastLayerNorm,
+    TensorParallelColumnLinear,
+    TensorParallelMultiAdapterLinear,
+)
+
+ATTN_Q = "self.query"
+ATTN_K = "self.key"
+ATTN_V = "self.value"
+
+
+class RobertaEmbeddings:
+    def __init__(self, prefix, weights, device, dtype, config):
+        self.word_embeddings_weight = weights.get_tensor(f"{prefix}.word_embeddings.weight").to(dtype).to(device)
+        self.token_type_embeddings_weight = (
+            weights.get_tensor(f"{prefix}.token_type_embeddings.weight").to(dtype).to(device)
+        )
+
+        if config.position_embedding_type == "absolute":
+            self.position_embeddings_weight = (
+                weights.get_tensor(f"{prefix}.position_embeddings.weight").to(dtype).to(device)
+            )
+        else:
NotImplementedError("FlashRoberta only supports absolute position embeddings") + self.pad_token_id = config.pad_token_id + + self.layer_norm = FastLayerNorm.load(prefix=f"{prefix}.LayerNorm", weights=weights, eps=config.layer_norm_eps) + + def forward(self, input_ids, token_type_ids, position_ids): + # position numbers begin at pad_token_id + 1 + # see transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids + position_ids += self.pad_token_id + 1 + + inputs_embeds = nn.functional.embedding(input_ids, self.word_embeddings_weight) + token_type_embeds = nn.functional.embedding(token_type_ids, self.token_type_embeddings_weight) + position_embeds = nn.functional.embedding(position_ids, self.position_embeddings_weight) + + inputs_embeds += position_embeds + + embeddings, _ = self.layer_norm.forward(inputs_embeds, token_type_embeds) + return embeddings + + +class RobertaAttention: + def __init__(self, prefix, layer_id, weights, device, dtype, config): + self.query_key_value = RobertaAttention.load_attention(config, prefix, weights, layer_id) + + self.dense_weight = weights.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device) + self.dense_bias = weights.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device) + + self.layer_norm = FastLayerNorm.load( + prefix=f"{prefix}.output.LayerNorm", weights=weights, eps=config.layer_norm_eps + ) + + self.head_size = config.hidden_size // config.num_attention_heads + self.softmax_scale = self.head_size**-0.5 + self.num_heads = config.num_attention_heads + self.layer_id = layer_id + + @staticmethod + def load_attention(config, prefix, weights, layer_id): + config.quantize = None + base_layer = RobertaAttention.load_attention_multi(config, prefix, weights) + return TensorParallelMultiAdapterLinear.load( + base_layer, + layer_id, + [ATTN_Q, ATTN_K, ATTN_V], + sizes=[ + config.hidden_size, + config.hidden_size, + config.hidden_size, + ], + process_group=weights.process_group, + ) + + @staticmethod + def load_attention_multi(config, prefix, weights): + prefixes = [f"{prefix}.{ATTN_Q}", f"{prefix}.{ATTN_K}", f"{prefix}.{ATTN_V}"] + return TensorParallelColumnLinear.load_multi( + config, + prefixes=prefixes, + dim=0, + weights=weights, + bias=True, + ) + + def forward(self, hidden_states, cu_seqlens, max_s, adapter_data): + residual = hidden_states + + qkv = self.query_key_value(hidden_states, adapter_data) + q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(self.num_heads, dim=1) + + attn_output = attention(q, k, v, None, None, cu_seqlens, max_s, self.softmax_scale, causal=False) + + hidden_states = torch.addmm( + self.dense_bias, + attn_output.view(-1, self.num_heads * self.head_size), + self.dense_weight, + ) + hidden_states, _ = self.layer_norm.forward(hidden_states, residual) + + return hidden_states + + +class RobertaLayer: + def __init__(self, prefix, layer_id, weights, device, dtype, config): + self.attention = RobertaAttention(f"{prefix}.attention", layer_id, weights, device, dtype, config) + + self.intermediate_weight = weights.get_tensor(f"{prefix}.intermediate.dense.weight").T.to(dtype).to(device) + self.intermediate_bias = weights.get_tensor(f"{prefix}.intermediate.dense.bias").to(dtype).to(device) + + act = config.hidden_act + self.intermediate_act_fn = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate="tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none", + ) + ) + + self.output_weight = 
weights.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device) + self.output_bias = weights.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device) + self.layer_norm = FastLayerNorm.load( + prefix=f"{prefix}.output.LayerNorm", weights=weights, eps=config.layer_norm_eps + ) + + def forward(self, hidden_states, cu_seqlens, max_s, adapter_data): + hidden_states = self.attention.forward(hidden_states, cu_seqlens, max_s, adapter_data) + residual = hidden_states + + hidden_states = torch.addmm(self.intermediate_bias, hidden_states, self.intermediate_weight) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = torch.addmm( + self.output_bias, + hidden_states, + self.output_weight, + ) + hidden_states, _ = self.layer_norm.forward(hidden_states, residual) + return hidden_states diff --git a/server/lorax_server/models/flash_roberta.py b/server/lorax_server/models/flash_roberta.py new file mode 100644 index 000000000..8e6d41d7e --- /dev/null +++ b/server/lorax_server/models/flash_roberta.py @@ -0,0 +1,226 @@ +from contextlib import nullcontext +from typing import Any, ContextManager, Optional, Type + +import torch +from loguru import logger +from opentelemetry import trace +from transformers import AutoTokenizer +from transformers.models.xlm_roberta import XLMRobertaConfig + +from lorax_server.adapters import AdapterBatchData +from lorax_server.models import Model +from lorax_server.models.custom_modeling.flash_roberta_modeling import ( + ATTN_K, + ATTN_Q, + ATTN_V, + RobertaEmbeddings, + RobertaLayer, +) +from lorax_server.models.types import FlashEmbeddingClassificationBatch +from lorax_server.pb.generate_pb2 import Embedding +from lorax_server.utils import ( + Weights, + initialize_torch_distributed, + weight_files, +) +from lorax_server.utils.adapter import create_merged_weight_files +from lorax_server.utils.state import FLASH_INFER + +tracer = trace.get_tracer(__name__) + + +class RobertaEncoder: + def __init__(self, prefix, weights, device, dtype, config): + self.layers = [ + RobertaLayer(f"{prefix}.layer.{i}", i, weights, device, dtype, config) + for i in range(config.num_hidden_layers) + ] + + def forward(self, hidden_states, cu_seqlens, max_s, adapter_data): + for layer in self.layers: + hidden_states = layer.forward(hidden_states, cu_seqlens, max_s, adapter_data) + return hidden_states + + +class FlashRobertaModel(torch.nn.Module): + def __init__(self, prefix, weights, device, dtype, config): + super().__init__() + self.config = config + self.embeddings = RobertaEmbeddings(f"{prefix}.embeddings", weights, device, dtype, config) + self.encoder = RobertaEncoder(f"{prefix}.encoder", weights, device, dtype, config) + + def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s, adapter_data): + embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids) + encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s, adapter_data) + return encoder_outputs[cu_seqlens[:-1]] + + +class FlashXlmRoberta(Model): + def __init__( + self, + model_id: str, + adapter_id: str, + adapter_source: str, + revision: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + merge_adapter_weights: bool = False, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + raise NotImplementedError("FlashXlmRoberta is only available on GPU") + + self.device = device + self.dtype 
+        self.dtype = dtype
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        self.tokenizer = tokenizer
+
+        config = XLMRobertaConfig.from_pretrained(model_id)
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        merged_weight_filenames = None
+        if merge_adapter_weights:
+            if len(adapter_id) > 0:
+                logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
+                # Need to pass the adapter source here
+                merged_weight_filenames = create_merged_weight_files(
+                    adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
+                )
+                self.dynamic_adapter_loading_enabled = False
+                self.adapter_id = adapter_id
+            else:
+                raise ValueError("Cannot merge adapter weights without an adapter_id")
+
+        weights = Weights(
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
+            merged_weight_filenames=merged_weight_filenames,
+        )
+        prefix = "roberta"
+        model = FlashRobertaModel(prefix, weights, device, dtype, config)
+
+        self.hidden_size = config.hidden_size
+        self.config = config
+
+        self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_attention_heads
+        self.head_size = config.hidden_size // config.num_attention_heads
+
+        if FLASH_INFER:
+            from lorax_server.utils.flashinfer_attention import create_prefill_state
+
+            self.prefill_state = create_prefill_state(device=device)
+
+        super(FlashXlmRoberta, self).__init__(
+            model_id=model_id,
+            model=model,
+            tokenizer=tokenizer,
+            dtype=dtype,
+            device=device,
+            adapter_id=adapter_id,
+            adapter_source=adapter_source,
+            rank=rank,
+            world_size=world_size,
+            requires_padding=False,
+        )
+
+    @property
+    def batch_type(self) -> Type[FlashEmbeddingClassificationBatch]:
+        return FlashEmbeddingClassificationBatch
+
+    @property
+    def supports_adapter_loading(self) -> bool:
+        return True
+
+    @property
+    def supports_embeddings(self) -> bool:
+        return True
+
+    @property
+    def supports_text_generation(self) -> bool:
+        return False
+
+    def warmup(self, batch: FlashEmbeddingClassificationBatch, max_new_tokens: int) -> int | None:
+        # Note: This is meant to 1) preallocate the memory by doing a forward pass
+        # and then just returning the max seqlen since for embeddings we are never generating
+        _ = self.embed(batch)
+        return batch.max_s
+
+    def generate_token(self, batch: FlashEmbeddingClassificationBatch) -> None:
+        if not self.supports_text_generation:
+            raise NotImplementedError("This model does not support text generation")
+        return None
+
+    def adapter_target_to_layer(self) -> dict[str, tuple[str, torch.Tensor]]:
+        layer_weights = {}
+
+        prefix = "roberta.encoder.layer"
+        for i, layer in enumerate(self.model.encoder.layers):
+            layer_weights[(i, ATTN_Q)] = (
+                f"{prefix}.{i}.attention.{ATTN_Q}",
+                layer.attention.query_key_value,
+            )
+            layer_weights[(i, ATTN_K)] = (
+                f"{prefix}.{i}.attention.{ATTN_K}",
+                layer.attention.query_key_value,
+            )
+            layer_weights[(i, ATTN_V)] = (
+                f"{prefix}.{i}.attention.{ATTN_V}",
+                layer.attention.query_key_value,
+            )
+        return layer_weights
+
+    @property
+    def adapter_layers(self) -> list[str]:
+        return [ATTN_Q, ATTN_V, ATTN_K]
+
+    @property
+    def default_traced_adapter_layers(self) -> list[str]:
+        return [ATTN_Q, ATTN_V]
+
+    def get_num_layers_for_type(self, layer_type: str) -> int:
+        return len(self.model.encoder.layers)
+
+    def _forward_context(
+        self,
+        *,
+        cu_seqlens: torch.Tensor,
+        state: Optional[Any] = None,
+    ) -> ContextManager:
+        if not FLASH_INFER:
+            return nullcontext()
+
+        from lorax_server.utils.flashinfer_attention import use_prefill_state
+
+        return use_prefill_state(
+            state=(state if state is not None else self.prefill_state),
+            cu_seqlens=cu_seqlens,
+            num_heads=self.num_heads,
+            num_kv_heads=self.num_kv_heads,
+            head_size=self.head_size,
+        )
+
+    def forward(self, batch: FlashEmbeddingClassificationBatch):
+        return self.embed(batch)
+
+    @tracer.start_as_current_span("embed")
+    def embed(self, batch: FlashEmbeddingClassificationBatch) -> Embedding:
+        adapter_data = AdapterBatchData.from_meta(batch.adapter_meta, self.layer_to_adapter_weights, False, None)
+
+        with self._forward_context(cu_seqlens=batch.cu_seqlens):
+            embedding: torch.Tensor = self.model.forward(
+                input_ids=batch.input_ids,
+                token_type_ids=batch.token_type_ids,
+                position_ids=batch.position_ids,
+                cu_seqlens=batch.cu_seqlens,
+                max_s=batch.max_s,
+                adapter_data=adapter_data,
+            )
+        embedding = embedding.reshape(embedding.shape[0], -1)[:, : self.hidden_size]
+
+        cpu_results = embedding.cpu().tolist()
+        return cpu_results
diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py
index 3d045b9a7..62ad44b02 100644
--- a/server/lorax_server/models/types.py
+++ b/server/lorax_server/models/types.py
@@ -6,8 +6,10 @@
 import torch
 from transformers import PreTrainedTokenizerBase
 
+from lorax_server.adapters import AdapterBatchMetadata
 from lorax_server.pb import generate_pb2
 from lorax_server.pb.generate_pb2 import FinishReason
+from lorax_server.utils.segments import find_segments
 from lorax_server.utils.tokenizer import TokenizerManager
 
 
@@ -140,6 +142,8 @@ class FlashEmbeddingClassificationBatch(ABC):
     max_s: int
     size: int
 
+    adapter_meta: AdapterBatchMetadata
+
     def __len__(self) -> int:
         return self.size
 
@@ -173,9 +177,11 @@ def from_pb(
         position_ids = []
         all_token_type_ids = []
        cu_seqlens = [0]
+        adapter_indices_list = []
 
         max_s = 0
         cumulative_length = 0
+        adapter_set = set()
 
         for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)):
             tokenized_input = tokenized_input[-r.truncate :]
@@ -192,20 +198,29 @@ def from_pb(
             request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
             position_ids.append(request_position_ids)
 
+            adapter_indices_list.append(torch.full((input_length,), r.adapter_index))
+            adapter_set.add(r.adapter_index)
+
             cumulative_length += input_length
 
         if len(pb.requests) > 1:
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             final_token_type_ids = np.concatenate(all_token_type_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
+            adapter_indices = torch.cat(adapter_indices_list)
         else:
             input_ids = all_input_ids[0]
             final_token_type_ids = all_token_type_ids[0]
             position_ids = position_ids[0]
+            adapter_indices = adapter_indices_list[0]
 
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         final_token_type_ids = torch.tensor(final_token_type_ids, dtype=torch.int64, device=device)
         position_ids = position_ids.to(device)
+        adapter_indices = adapter_indices.to(dtype=torch.int64, device=device)
+
+        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
+        adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device)
 
         return FlashEmbeddingClassificationBatch(
             request_ids=[r.id for r in pb.requests],
@@ -215,8 +230,27 @@ def from_pb(
             cu_seqlens=torch.tensor(cu_seqlens, dtype=torch.int32, device=device),
             max_s=max_s,
             size=len(batch_tokenized_inputs),
+            adapter_meta=AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            ),
         )
 
+    @classmethod
+    def from_pb_embed(
+        self,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        tokenizers: TokenizerManager,
+        processor,
+        config,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "FlashEmbeddingClassificationBatch":
+        return self.from_pb(pb, tokenizer, tokenizers, processor, config, dtype, device)
+
     @classmethod
     def to_pb_classify(self, batch, predicted_token_classes, confidence_scores) -> generate_pb2.ClassifyResponse:
         results = []
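
Note on using the new embed parameters: EmbedRequest now carries an optional "parameters" object with the same adapter fields as GenerateParameters (adapter_id, adapter_source, merged_adapters, api_token); when it is omitted, default_embed_parameters() supplies empty adapter settings and the base model is used. The sketch below is a minimal, hypothetical client call, assuming the router is reachable at localhost:8080 and that the embed handler touched here is exposed at an /embed route (the route definition itself is not part of this diff); the adapter id is a placeholder, not a value taken from this change.

    import requests

    # Hypothetical request against a locally running LoRAX router.
    # The adapter_id below is a placeholder; drop "parameters" entirely
    # to embed with the base XLM-RoBERTa model.
    response = requests.post(
        "http://localhost:8080/embed",
        json={
            "inputs": "San Francisco",
            "parameters": {
                "adapter_id": "my-org/my-xlm-roberta-lora",
                "adapter_source": "hub",
            },
        },
    )
    response.raise_for_status()
    print(response.json())

Because infer.rs keys its adapter_to_index map on the full adapter parameters, repeated embed calls with the same adapter settings reuse the same adapter index rather than registering the adapter again.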