Support for Embeddings with XLM-RoBERTa and Adapters #656

Merged · 2 commits · Oct 31, 2024
2 changes: 2 additions & 0 deletions router/src/config.rs
@@ -175,6 +175,8 @@ pub enum Config {
T5,
Bert,
Distilbert,
#[serde(rename = "xlm-roberta")]
XlmRoberta,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
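For context, the `#[serde(rename = "xlm-roberta")]` attribute maps the `model_type` field of a Hugging Face `config.json` onto the new `XlmRoberta` variant; the Python server makes the matching check in `get_model` further down in this PR. A minimal sketch of that detection (the file path and message are illustrative only, not part of the change):

import json

# Minimal sketch: detect an XLM-RoBERTa checkpoint from its config.json,
# mirroring the serde rename above and the model_type check in get_model().
with open("config.json") as f:
    config_dict = json.load(f)

if config_dict.get("model_type") == "xlm-roberta":
    print("route to FlashXlmRoberta for embeddings")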
49 changes: 23 additions & 26 deletions router/src/infer.rs
@@ -488,34 +488,31 @@ impl Infer {
err
})?;

// TODO(travis): support adapters
// let (adapter_source, adapter_parameters) = extract_adapter_params(
// request.parameters.adapter_id.clone(),
// request.parameters.adapter_source.clone(),
// request.parameters.adapter_parameters.clone(),
// );

// let adapter_idx;
// {
// // TODO(travis): can optimize concurrency here using RWLock
// let mut adapter_to_index = self.adapter_to_index.lock().await;
// let adapter_key = adapter_parameters.clone();
// if adapter_to_index.contains_key(&adapter_key) {
// adapter_idx = *adapter_to_index.get(&adapter_key).unwrap();
// } else {
// adapter_idx = adapter_to_index.len() as u32;
// adapter_to_index.insert(adapter_key, adapter_idx);
// }
// }
let (adapter_source, adapter_parameters) = extract_adapter_params(
request.parameters.adapter_id.clone(),
request.parameters.adapter_source.clone(),
request.parameters.adapter_parameters.clone(),
);

let adapter_idx;
{
// TODO(travis): can optimize concurrency here using RWLock
let mut adapter_to_index = self.adapter_to_index.lock().await;
let adapter_key = adapter_parameters.clone();
if adapter_to_index.contains_key(&adapter_key) {
adapter_idx = *adapter_to_index.get(&adapter_key).unwrap();
} else {
adapter_idx = adapter_to_index.len() as u32;
adapter_to_index.insert(adapter_key, adapter_idx);
}
}

let api_token = request.parameters.api_token.clone();
let adapter = Adapter::new(
AdapterParameters {
adapter_ids: vec![BASE_MODEL_ADAPTER_ID.to_string()],
..Default::default()
},
"hub".to_string(),
0,
None,
adapter_parameters,
adapter_source.unwrap(),
adapter_idx,
api_token,
);

// TODO(travis): robust validation
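A note on the now-active block above: each unique adapter configuration is assigned a stable, grow-only integer index under the `adapter_to_index` lock. Conceptually it is just the following bookkeeping (Python sketch with illustrative names, not the router code itself):

# Conceptual sketch of the adapter_to_index bookkeeping enabled above:
# a grow-only map from adapter configuration to a stable integer index.
adapter_to_index: dict[str, int] = {}

def index_for(adapter_key: str) -> int:
    if adapter_key in adapter_to_index:
        return adapter_to_index[adapter_key]
    idx = len(adapter_to_index)
    adapter_to_index[adapter_key] = idx
    return idx

print(index_for("adapter-a"), index_for("adapter-b"), index_for("adapter-a"))  # 0 1 0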
35 changes: 35 additions & 0 deletions router/src/lib.rs
@@ -921,9 +921,44 @@ pub(crate) enum CompletionFinishReason {
ToolCalls,
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct EmbedParameters {
#[serde(default)]
#[schema(
nullable = true,
default = "null",
example = "arnavgrg/codealpaca-qlora"
)]
pub adapter_id: Option<String>,
#[serde(default)]
#[schema(nullable = true, default = "null", example = "hub")]
pub adapter_source: Option<String>,
#[serde(rename(deserialize = "merged_adapters"))]
#[schema(nullable = true, default = "null")]
pub adapter_parameters: Option<AdapterParameters>,
#[serde(default)]
#[schema(
nullable = true,
default = "null",
example = "<token for private adapters>"
)]
pub api_token: Option<String>,
}

fn default_embed_parameters() -> EmbedParameters {
EmbedParameters {
adapter_id: None,
adapter_source: None,
adapter_parameters: None,
api_token: None,
}
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
struct EmbedRequest {
inputs: String,
#[serde(default = "default_embed_parameters")]
pub parameters: EmbedParameters,
}

#[derive(Serialize, ToSchema)]
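For illustration, a request exercising the new `EmbedParameters` could look like the sketch below. The payload shape follows the `EmbedRequest`/`EmbedParameters` structs above; the `/embed` route, port, and adapter id are assumptions for the example, not documented API.

import requests

# Sketch of an embedding request with a LoRA adapter attached.
# Field names mirror EmbedRequest/EmbedParameters above; the URL and the
# adapter_id value (taken from the schema example) are placeholders.
resp = requests.post(
    "http://127.0.0.1:8080/embed",
    json={
        "inputs": "San Francisco",
        "parameters": {
            "adapter_id": "arnavgrg/codealpaca-qlora",
            "adapter_source": "hub",
        },
    },
)
print(resp.json())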
17 changes: 12 additions & 5 deletions router/src/server.rs
@@ -11,11 +11,12 @@ use crate::{
ChatCompletionResponseChoice, ChatCompletionStreamResponse, ChatCompletionStreamResponseChoice,
ChatMessage, ClassifyRequest, CompatGenerateRequest, CompletionFinishReason, CompletionRequest,
CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice,
CompletionStreamResponse, Details, EmbedRequest, EmbedResponse, Entity, ErrorResponse,
FinishReason, FunctionDefinition, GenerateParameters, GenerateRequest, GenerateResponse,
HubModelInfo, Infer, Info, JsonSchema, LogProbs, Message, OpenAiResponseFormat, PrefillToken,
ResponseFormat, ResponseFormatType, SimpleToken, StreamDetails, StreamResponse, Token,
TokenizeRequest, TokenizeResponse, Tool, ToolCall, ToolChoice, UsageInfo, Validation,
CompletionStreamResponse, Details, EmbedParameters, EmbedRequest, EmbedResponse, Entity,
ErrorResponse, FinishReason, FunctionDefinition, GenerateParameters, GenerateRequest,
GenerateResponse, HubModelInfo, Infer, Info, JsonSchema, LogProbs, Message,
OpenAiResponseFormat, PrefillToken, ResponseFormat, ResponseFormatType, SimpleToken,
StreamDetails, StreamResponse, Token, TokenizeRequest, TokenizeResponse, Tool, ToolCall,
ToolChoice, UsageInfo, Validation,
};
use crate::{json, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig};
use axum::extract::Extension;
@@ -572,6 +573,12 @@ async fn health(
if health.shard_info().supports_embeddings {
let embed_request = EmbedRequest {
inputs: "San Francisco".to_string(),
parameters: EmbedParameters {
adapter_id: None,
adapter_source: None,
adapter_parameters: None,
api_token: None,
},
};
match infer.embed(embed_request).await {
Ok(_) => {}
12 changes: 12 additions & 0 deletions server/lorax_server/models/__init__.py
@@ -120,6 +120,18 @@ def get_model(
if config_dict["architectures"][0] == "DistilBertForTokenClassification":
return FlashDistilBert(model_id, revision=revision, dtype=dtype, classifcation_head=True)

if model_type == "xlm-roberta":
from lorax_server.models.flash_roberta import FlashXlmRoberta

return FlashXlmRoberta(
model_id,
adapter_id,
adapter_source,
revision=revision,
dtype=dtype,
merge_adapter_weights=merge_adapter_weights,
)

flash_causal_lm_kwargs = dict(
quantize=quantize,
compile=compile,
145 changes: 145 additions & 0 deletions server/lorax_server/models/custom_modeling/flash_roberta_modeling.py
@@ -0,0 +1,145 @@
import torch
from torch import nn
from transformers.activations import ACT2FN

from lorax_server.utils.flash_attn import attention
from lorax_server.utils.layers import (
FastLayerNorm,
TensorParallelColumnLinear,
TensorParallelMultiAdapterLinear,
)

ATTN_Q = "self.query"
ATTN_K = "self.key"
ATTN_V = "self.value"


class RobertaEmbeddings:
def __init__(self, prefix, weights, device, dtype, config):
self.word_embeddings_weight = weights.get_tensor(f"{prefix}.word_embeddings.weight").to(dtype).to(device)
self.token_type_embeddings_weight = (
weights.get_tensor(f"{prefix}.token_type_embeddings.weight").to(dtype).to(device)
)

if config.position_embedding_type == "absolute":
self.position_embeddings_weight = (
weights.get_tensor(f"{prefix}.position_embeddings.weight").to(dtype).to(device)
)
else:
raise NotImplementedError("FlashRoberta only supports absolute position embeddings")
self.pad_token_id = config.pad_token_id

self.layer_norm = FastLayerNorm.load(prefix=f"{prefix}.LayerNorm", weights=weights, eps=config.layer_norm_eps)

def forward(self, input_ids, token_type_ids, position_ids):
# position numbers begin at pad_token_id + 1
# see transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
position_ids += self.pad_token_id + 1

inputs_embeds = nn.functional.embedding(input_ids, self.word_embeddings_weight)
token_type_embeds = nn.functional.embedding(token_type_ids, self.token_type_embeddings_weight)
position_embeds = nn.functional.embedding(position_ids, self.position_embeddings_weight)

inputs_embeds += position_embeds

embeddings, _ = self.layer_norm.forward(inputs_embeds, token_type_embeds)
return embeddings


class RobertaAttention:
def __init__(self, prefix, layer_id, weights, device, dtype, config):
self.query_key_value = RobertaAttention.load_attention(config, prefix, weights, layer_id)

self.dense_weight = weights.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device)
self.dense_bias = weights.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device)

self.layer_norm = FastLayerNorm.load(
prefix=f"{prefix}.output.LayerNorm", weights=weights, eps=config.layer_norm_eps
)

self.head_size = config.hidden_size // config.num_attention_heads
self.softmax_scale = self.head_size**-0.5
self.num_heads = config.num_attention_heads
self.layer_id = layer_id

@staticmethod
def load_attention(config, prefix, weights, layer_id):
config.quantize = None
base_layer = RobertaAttention.load_attention_multi(config, prefix, weights)
return TensorParallelMultiAdapterLinear.load(
base_layer,
layer_id,
[ATTN_Q, ATTN_K, ATTN_V],
sizes=[
config.hidden_size,
config.hidden_size,
config.hidden_size,
],
process_group=weights.process_group,
)

@staticmethod
def load_attention_multi(config, prefix, weights):
prefixes = [f"{prefix}.{ATTN_Q}", f"{prefix}.{ATTN_K}", f"{prefix}.{ATTN_V}"]
return TensorParallelColumnLinear.load_multi(
config,
prefixes=prefixes,
dim=0,
weights=weights,
bias=True,
)

def forward(self, hidden_states, cu_seqlens, max_s, adapter_data):
residual = hidden_states

qkv = self.query_key_value(hidden_states, adapter_data)
q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(self.num_heads, dim=1)

attn_output = attention(q, k, v, None, None, cu_seqlens, max_s, self.softmax_scale, causal=False)

hidden_states = torch.addmm(
self.dense_bias,
attn_output.view(-1, self.num_heads * self.head_size),
self.dense_weight,
)
hidden_states, _ = self.layer_norm.forward(hidden_states, residual)

return hidden_states


class RobertaLayer:
def __init__(self, prefix, layer_id, weights, device, dtype, config):
self.attention = RobertaAttention(f"{prefix}.attention", layer_id, weights, device, dtype, config)

self.intermediate_weight = weights.get_tensor(f"{prefix}.intermediate.dense.weight").T.to(dtype).to(device)
self.intermediate_bias = weights.get_tensor(f"{prefix}.intermediate.dense.bias").to(dtype).to(device)

act = config.hidden_act
self.intermediate_act_fn = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate="tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none",
)
)

self.output_weight = weights.get_tensor(f"{prefix}.output.dense.weight").T.to(dtype).to(device)
self.output_bias = weights.get_tensor(f"{prefix}.output.dense.bias").to(dtype).to(device)
self.layer_norm = FastLayerNorm.load(
prefix=f"{prefix}.output.LayerNorm", weights=weights, eps=config.layer_norm_eps
)

def forward(self, hidden_states, cu_seqlens, max_s, adapter_data):
hidden_states = self.attention.forward(hidden_states, cu_seqlens, max_s, adapter_data)
residual = hidden_states

hidden_states = torch.addmm(self.intermediate_bias, hidden_states, self.intermediate_weight)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = torch.addmm(
self.output_bias,
hidden_states,
self.output_weight,
)
hidden_states, _ = self.layer_norm.forward(hidden_states, residual)
return hidden_states
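One detail of the modeling code worth noting: the fused QKV projection in `RobertaAttention.forward` returns a single tensor per token that is reshaped to `(tokens, 3 * num_heads, head_size)` and split into equal query/key/value chunks. A shape-only sketch (sizes arbitrary, for illustration only):

import torch

# Shape-only sketch of the fused QKV split in RobertaAttention.forward.
# Sizes are arbitrary: 7 tokens, 12 heads, head_size 64.
tokens, num_heads, head_size = 7, 12, 64
qkv = torch.randn(tokens, 3 * num_heads * head_size)

q, k, v = qkv.view(-1, num_heads * 3, head_size).split(num_heads, dim=1)
print(q.shape, k.shape, v.shape)  # each: torch.Size([7, 12, 64])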