diff --git a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py
index 8a6bf7cd2..c5b46b82a 100644
--- a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -30,7 +30,9 @@
 from lorax_server.utils import flash_attn
 from lorax_server.utils import paged_attn
 from lorax_server.utils.layers import (
-    FastConv1D,
+    FastLinear,
+    TensorParallelAdapterRowLinear,
+    TensorParallelMultiAdapterLinear,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -39,9 +41,36 @@
     PositionRotaryEmbedding,
     get_linear,
 )
-
 from lorax_server.utils.lora import AdapterBatchData
 
+ATTN_C_ATTN = "attn.c_attn"
+ATTN_C_PROJ = "attn.c_proj"
+MLP_C_FC = "mlp.c_fc"
+MLP_C_PROJ = "mlp.c_proj"
+LM_HEAD = "lm_head"
+
+
+def load_attention_multi(config, prefix, weights, fan_in_fan_out=False):
+    return TensorParallelColumnLinear.load_multi(
+        config,
+        prefixes=[f"{prefix}.c_attn"],
+        dim=0,
+        weights=weights,
+        bias=True,
+        fan_in_fan_out=fan_in_fan_out,
+    )
+
+
+def load_attention(config, prefix, weights, layer_id, layer_names, fan_in_fan_out=False):
+    base_layer = load_attention_multi(config, prefix, weights, fan_in_fan_out=fan_in_fan_out)
+    projection_size = config.n_embd
+    return TensorParallelMultiAdapterLinear.load(
+        base_layer, layer_id, layer_names, sizes=[
+            3 * projection_size,
+        ], process_group=weights.process_group
+    )
+
 
 class FlashGPT2Attention(torch.nn.Module):
     def __init__(self, config, prefix, weights, layer_id):
@@ -81,8 +110,14 @@ def __init__(self, config, prefix, weights, layer_id):
         self.layer_idx = layer_id
         self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
 
-        self.c_attn = FastConv1D.load(config, prefix=f"{prefix}.c_attn", weights=weights)
-        self.c_proj = FastConv1D.load(config, prefix=f"{prefix}.c_proj", weights=weights)
+        self.c_attn = load_attention(config, prefix, weights, layer_id, [ATTN_C_ATTN], fan_in_fan_out=True)
+        self.c_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.c_proj",
+            weights=weights,
+            bias=True,
+            fan_in_fan_out=True,
+        ), layer_id, ATTN_C_PROJ, process_group=weights.process_group)
 
         self.pruned_heads = set()
 
@@ -115,8 +150,9 @@ def forward(
         slots,
         input_lengths,
         max_s,
+        adapter_data
    ):
-        qkv = self.c_attn(hidden_states)
+        qkv = self.c_attn(hidden_states, adapter_data)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
 
         paged_attn.reshape_and_cache(
@@ -154,21 +190,49 @@ def forward(
         )
 
         attn_output = attn_output.view(-1, self.num_heads * self.head_size)
 
-        out = self.c_proj(attn_output)
+        out = self.c_proj(attn_output, adapter_data)
 
         return out
 
 
 class GPT2MLP(nn.Module):
-    def __init__(self, config, prefix, weights):
+    def __init__(self, config, prefix, weights, layer_id):
         super().__init__()
-        self.c_fc = FastConv1D.load(config, prefix=f"{prefix}.c_fc", weights=weights)
-        self.c_proj = FastConv1D.load(config, prefix=f"{prefix}.c_proj", weights=weights)
+
+        c_fc = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.c_fc",
+            weights=weights,
+            bias=True,
+            fan_in_fan_out=True,
+        )
+        # https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Config.n_inner
+        n_inner = config.n_inner if config.n_inner is not None else config.n_embd * 4
+        self.c_fc = TensorParallelMultiAdapterLinear.load(
+            c_fc,
+            layer_id,
+            [MLP_C_FC],
+            sizes=[n_inner],
+            process_group=weights.process_group
+        )
+
+        self.c_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.c_proj",
+            weights=weights,
+            bias=True,
+            fan_in_fan_out=True,
+        ), layer_id, MLP_C_PROJ, process_group=weights.process_group)
+
         self.act = ACT2FN[config.activation_function]
 
-    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
-        hidden_states = self.c_fc(hidden_states)
+    def forward(
+        self,
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        adapter_data: AdapterBatchData,
+    ) -> torch.FloatTensor:
+        hidden_states = self.c_fc(hidden_states, adapter_data)
         hidden_states = self.act(hidden_states)
-        hidden_states = self.c_proj(hidden_states)
+        hidden_states = self.c_proj(hidden_states, adapter_data)
         return hidden_states
@@ -189,7 +253,7 @@ def __init__(self, layer_id, config, weights):
             prefix=f"{prefix}.ln_2", weights=weights, eps=layer_norm_eps
         )
 
-        self.mlp = GPT2MLP(config, prefix=f"{prefix}.mlp", weights=weights)
+        self.mlp = GPT2MLP(config, prefix=f"{prefix}.mlp", weights=weights, layer_id=layer_id)
         self.process_group = weights.process_group
 
     def forward(
@@ -201,6 +265,7 @@
         slots,
         input_lengths,
         max_s,
+        adapter_data,
     ):
         residual = hidden_states
         hidden_states, _ = self.ln_1(hidden_states)
@@ -212,6 +277,7 @@
             slots,
             input_lengths,
             max_s,
+            adapter_data,
         )
 
         # residual connection
@@ -219,7 +285,7 @@
         residual = hidden_states
         hidden_states, _ = self.ln_2(hidden_states)
 
-        feed_forward_hidden_states = self.mlp(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states, adapter_data)
 
         # residual connection
         hidden_states = feed_forward_hidden_states + residual
@@ -243,7 +309,7 @@ def __init__(self, config, weights):
         self.wte = TensorParallelEmbedding(prefix="wte", weights=weights)
         self.wpe = TensorParallelEmbedding(prefix="wpe", weights=weights)
 
-        self.layers = nn.ModuleList(
+        self.h = nn.ModuleList(
             [
                 GPT2Block(layer_id, config, weights)
                 for layer_id in range(config.num_hidden_layers)
@@ -257,9 +323,9 @@ def __init__(self, config, weights):
 
         self.gradient_checkpointing = False
 
-        self.head_size = self.layers[0].attn.head_size
-        self.num_heads = self.layers[0].attn.num_heads
-        self.num_key_value_heads = self.layers[0].attn.num_key_value_heads
+        self.head_size = self.h[0].attn.head_size
+        self.num_heads = self.h[0].attn.num_heads
+        self.num_key_value_heads = self.h[0].attn.num_key_value_heads
 
     def forward(
         self,
@@ -271,12 +337,13 @@
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        adapter_data: AdapterBatchData,
     ) -> torch.Tensor:
         inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
         hidden_states = inputs_embeds + position_embeds
 
-        for i, layer in enumerate(self.layers):
+        for i, layer in enumerate(self.h):
             hidden_states = layer(
                 hidden_states,
                 cu_seqlen_prefill,
@@ -285,6 +352,7 @@
                 slots,
                 input_lengths,
                 max_s,
+                adapter_data,
             )
 
         hidden_states, _ = self.ln_f(hidden_states)
@@ -294,7 +362,7 @@
 class FlashGPT2ForCausalLM(FlashGPT2PreTrainedModel):
     def __init__(self, config, weights):
         super().__init__(config)
-        self.model = FlashGPT2Model(config, weights)
+        self.transformer = FlashGPT2Model(config, weights)
 
     def forward(
         self,
@@ -309,7 +377,7 @@
         adapter_data: AdapterBatchData,  # TODO: plumb this through
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(
+        hidden_states = self.transformer(
             input_ids,
             position_ids,
             cu_seqlen_prefill,
@@ -318,12 +386,13 @@
             slots,
             input_lengths,
             max_s,
+            adapter_data,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
 
         # lm_head reuses the weights of the embedding layer
         # https://github.com/huggingface/transformers/issues/6291
-        logits = hidden_states @ self.model.wte.weight.T
-        logits = logits[:, :self.model.config.vocab_size]
+        logits = hidden_states @ self.transformer.wte.weight.T
+        logits = logits[:, :self.transformer.config.vocab_size]
         return logits
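Note on the fused c_attn handling above: GPT-2 uses a single c_attn projection that emits query, key, and value concatenated along the feature dimension, which is why load_attention declares a single adapter segment of size 3 * n_embd and why the forward pass reshapes the output into three chunks. A minimal standalone sketch of that layout (plain PyTorch with illustrative GPT-2-small dimensions, not LoRAX code):

import torch

n_embd, num_heads = 768, 12              # GPT-2 small; head_size = 64
head_size = n_embd // num_heads
hidden_states = torch.randn(4, n_embd)   # 4 flattened tokens

# c_attn maps n_embd -> 3 * n_embd: [query | key | value] on the last dim,
# matching sizes=[3 * projection_size] passed to TensorParallelMultiAdapterLinear.load
qkv = hidden_states @ torch.randn(n_embd, 3 * n_embd)

# the same reshape the modeling code performs
qkv = qkv.view(-1, 3, num_heads, head_size)
query, key, value = qkv[:, 0], qkv[:, 1], qkv[:, 2]
assert query.shape == (4, num_heads, head_size)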
diff --git a/server/lorax_server/models/flash_gpt2.py b/server/lorax_server/models/flash_gpt2.py
index 43275ce5e..db28c5d70 100644
--- a/server/lorax_server/models/flash_gpt2.py
+++ b/server/lorax_server/models/flash_gpt2.py
@@ -6,12 +6,17 @@
 from opentelemetry import trace
 from transformers import AutoTokenizer, GPT2Model
 from tqdm import tqdm
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Tuple
 
 from lorax_server.models import FlashCausalLM
 from lorax_server.models.custom_modeling.flash_gpt2_modeling import (
     FlashGPT2ForCausalLM,
     GPT2Config,
+    ATTN_C_ATTN,
+    ATTN_C_PROJ,
+    MLP_C_FC,
+    MLP_C_PROJ,
+    LM_HEAD,
 )
 from lorax_server.utils import (
     compute_delta_weight,
@@ -26,6 +31,9 @@
 
 tracer = trace.get_tracer(__name__)
 
+ADAPTER_LAYERS = [ATTN_C_ATTN, ATTN_C_PROJ, MLP_C_FC, MLP_C_PROJ]
+ROW_PARALLEL = {ATTN_C_PROJ, MLP_C_PROJ}
+
 
 class FlashGPT2(FlashCausalLM):
     def __init__(
@@ -95,19 +103,40 @@
         super(FlashGPT2, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
+            num_layers=len(model.transformer.h),
+            num_kv_heads=model.transformer.num_key_value_heads,
+            head_size=model.transformer.head_size,
             dtype=dtype,
             device=device,
             rank=rank,
             world_size=world_size,
         )
-
-    def get_adaptable_weights(self):
-        # TODO: enable dynamic adapter loading in LoRAX
-        return {}
-
+
     @property
     def supports_adapter_loading(self) -> bool:
         return True
+
+    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
+        layer_weights = {}
+
+        prefix = "transformer.h"
+        for i, layer in enumerate(self.model.transformer.h):
+            layer_weights[(i, ATTN_C_ATTN)] = (f"{prefix}.{i}.{ATTN_C_ATTN}", layer.attn.c_attn)
+            layer_weights[(i, ATTN_C_PROJ)] = (f"{prefix}.{i}.{ATTN_C_PROJ}", layer.attn.c_proj)
+
+            layer_weights[(i, MLP_C_FC)] = (f"{prefix}.{i}.{MLP_C_FC}", layer.mlp.c_fc)
+            layer_weights[(i, MLP_C_PROJ)] = (f"{prefix}.{i}.{MLP_C_PROJ}", layer.mlp.c_proj)
+
+        # TODO: make Embedding layers adapter-compatible
+        # layer_weights[(0, LM_HEAD)] = ("transformer.wte", self.model.transformer.wte)
+        return layer_weights
+
+    @property
+    def adapter_layers(self) -> List[str]:
+        return ADAPTER_LAYERS
+
+    def get_num_layers_for_type(self, layer_type: str) -> int:
+        return 1 if layer_type == LM_HEAD else len(self.model.transformer.h)
+
+    def is_row_parallel(self, layer_type: str) -> bool:
+        return layer_type in ROW_PARALLEL
diff --git a/server/lorax_server/utils/adapter.py b/server/lorax_server/utils/adapter.py
index 40c8d10f4..a3ed66c12 100644
--- a/server/lorax_server/utils/adapter.py
+++ b/server/lorax_server/utils/adapter.py
@@ -130,7 +130,8 @@ def merge_adapter_weights(
         # transpose delta weight if necessary
         # TODO(geoffrey): I believe this is required when using Conv1D layers (gpt2).
         # We can likely take this out once we've switched to using Linear layers.
-        if delta_weight.T.shape == model_weights[weight_name].shape:
+        if (delta_weight.shape != model_weights[weight_name].shape and
+                delta_weight.T.shape == model_weights[weight_name].shape):
             delta_weight = delta_weight.T
         merged_weights[weight_name] = model_weights[weight_name] + delta_weight
     return merged_weights, processed_adapter_weight_names
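One detail worth calling out in the merge_adapter_weights change: when a weight is square (for GPT-2, attention c_proj is n_embd x n_embd), the old condition delta_weight.T.shape == target.shape is always true, so deltas that were already correctly oriented got transposed anyway. Requiring the shapes to differ first limits the transpose to genuinely Conv1D-style (transposed) weights. A small self-contained illustration of the corrected check (hypothetical helper, not part of the LoRAX codebase):

import torch

def maybe_transpose(delta: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # mirror of the fixed condition: transpose only when the shapes disagree
    # as-is but agree after a transpose
    if delta.shape != target.shape and delta.T.shape == target.shape:
        return delta.T
    return delta

target = torch.zeros(768, 3072)                 # Conv1D-style (in, out) weight
delta = torch.zeros(3072, 768)                  # delta computed in (out, in) layout
assert maybe_transpose(delta, target).shape == target.shape

square_delta = torch.arange(4.0).reshape(2, 2)  # square case: already aligned
assert torch.equal(maybe_transpose(square_delta, torch.zeros(2, 2)), square_delta)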
diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py
index 2c0ee4447..cd9fe0806 100644
--- a/server/lorax_server/utils/layers.py
+++ b/server/lorax_server/utils/layers.py
@@ -208,7 +208,12 @@ def forward(self, x: torch.Tensor):
         return out
 
 
-def get_linear(weight, bias, quantize):
+def get_linear(weight, bias, quantize, fan_in_fan_out=False):
+    # https://huggingface.co/docs/peft/package_reference/tuners#peft.LoraConfig.fan_in_fan_out
+    # Set to True if replacing a Conv1D layer with a Linear layer
+    if fan_in_fan_out:
+        weight = weight.T
+
     if quantize is None:
         linear = FastLinear(weight, bias)
     elif quantize == "bitsandbytes":
@@ -330,22 +335,31 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
 class TensorParallelColumnLinear(SuperLayer):
     @classmethod
-    def load_qkv(cls, config, prefix: str, weights, bias: bool):
+    def load_qkv(cls, config, prefix: str, weights, bias: bool, fan_in_fan_out=False):
         """Specific method when the QKV was joined after the fact"""
         weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
         if bias:
             raise NotImplementedError("packed_qkv only implemented for baichuan")
         else:
             bias = None
-        linear = get_linear(weight, bias, config.quantize)
+        linear = get_linear(weight, bias, config.quantize, fan_in_fan_out=fan_in_fan_out)
         return cls(linear)
 
     @classmethod
-    def load(cls, config, prefix: str, weights, bias: bool):
-        return cls.load_multi(config, [prefix], weights, bias, dim=0)
+    def load(cls, config, prefix: str, weights, bias: bool, fan_in_fan_out: bool = False):
+        return cls.load_multi(
+            config, [prefix], weights, bias, dim=0, fan_in_fan_out=fan_in_fan_out)
 
     @classmethod
-    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
+    def load_multi(
+        cls,
+        config,
+        prefixes: List[str],
+        weights,
+        bias: bool,
+        dim: int,
+        fan_in_fan_out=False
+    ):
         weight = weights.get_multi_weights_col(
             prefixes, quantize=config.quantize, dim=dim
         )
@@ -355,7 +369,7 @@ def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
             bias = torch.cat(b, dim=dim)
         else:
             bias = None
-        linear = get_linear(weight, bias, config.quantize)
+        linear = get_linear(weight, bias, config.quantize, fan_in_fan_out=fan_in_fan_out)
         return cls(linear)
 
 
@@ -514,7 +528,14 @@ def __init__(self, linear, process_group):
         self.process_group = process_group
 
     @classmethod
-    def load(cls, config, prefix: str, weights, bias: bool):
+    def load(
+        cls,
+        config,
+        prefix: str,
+        weights,
+        bias: bool,
+        fan_in_fan_out: bool = False
+    ):
         weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
 
         if bias and weights.process_group.rank() == 0:
@@ -523,7 +544,7 @@ def load(cls, config, prefix: str, weights, bias: bool):
         else:
             bias = None
         return cls(
-            get_linear(weight, bias, config.quantize),
+            get_linear(weight, bias, config.quantize, fan_in_fan_out=fan_in_fan_out),
             process_group=weights.process_group,
         )
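For context on fan_in_fan_out: transformers' Conv1D (the layer GPT-2 checkpoints use) stores its weight as (in_features, out_features) and computes x @ W + b, whereas nn.Linear and FastLinear expect (out_features, in_features) and compute x @ W.T + b, so loading a Conv1D checkpoint into a Linear layer amounts to a single transpose, which is what get_linear now performs when fan_in_fan_out=True. A quick standalone check of that equivalence (plain PyTorch/transformers sketch, independent of the LoRAX layers):

import torch
from transformers.pytorch_utils import Conv1D

in_features, out_features = 768, 3072
conv = Conv1D(out_features, in_features)        # Conv1D(nf, nx): weight shape is (nx, nf) = (in, out)

linear = torch.nn.Linear(in_features, out_features)
with torch.no_grad():
    linear.weight.copy_(conv.weight.T)          # the fan_in_fan_out transpose
    linear.bias.copy_(conv.bias)

x = torch.randn(2, in_features)
torch.testing.assert_close(linear(x), conv(x))  # same outputs once the weight is transposed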