LoRAX-compatible GPT-2 (#109)

geoffreyangus authored Dec 7, 2023
1 parent f2f0521 commit 7e803aa
Showing 4 changed files with 162 additions and 42 deletions.
115 changes: 92 additions & 23 deletions server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -30,7 +30,9 @@
from lorax_server.utils import flash_attn
from lorax_server.utils import paged_attn
from lorax_server.utils.layers import (
FastConv1D,
FastLinear,
TensorParallelAdapterRowLinear,
TensorParallelMultiAdapterLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
@@ -39,9 +41,36 @@
PositionRotaryEmbedding,
get_linear,
)

from lorax_server.utils.lora import AdapterBatchData

ATTN_C_ATTN = "attn.c_attn"
ATTN_C_PROJ = "attn.c_proj"
MLP_C_FC = "mlp.c_fc"
MLP_C_PROJ = "mlp.c_proj"
LM_HEAD = "lm_head"


def load_attention_multi(config, prefix, weights, fan_in_fan_out=False):
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.c_attn"],
dim=0,
weights=weights,
bias=True,
fan_in_fan_out=fan_in_fan_out,
)


def load_attention(config, prefix, weights, layer_id, layer_names, fan_in_fan_out=False):
base_layer = load_attention_multi(config, prefix, weights, fan_in_fan_out=fan_in_fan_out)
projection_size = config.n_embd
return TensorParallelMultiAdapterLinear.load(
base_layer, layer_id, layer_names, sizes=[
3 * projection_size,
], process_group=weights.process_group
)
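
The sizes=[3 * projection_size] passed above corresponds to the full width of GPT-2's fused QKV projection: Q, K, and V are packed into a single c_attn output of width 3 * n_embd. A rough arithmetic check, assuming GPT-2 small's n_embd of 768 (values chosen for illustration, not read from this commit):

    # Illustrative only -- not LoRAX code. GPT-2 fuses Q, K, V into a single
    # c_attn projection, so the adapter wrapper is given the packed output
    # width via sizes=[3 * projection_size].
    n_embd = 768                      # GPT-2 small hidden size (assumed)
    qkv_width = 3 * n_embd            # matches sizes=[3 * projection_size]
    q_width = k_width = v_width = qkv_width // 3
    assert (q_width, k_width, v_width) == (768, 768, 768)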



class FlashGPT2Attention(torch.nn.Module):
def __init__(self, config, prefix, weights, layer_id):
@@ -81,8 +110,14 @@ def __init__(self, config, prefix, weights, layer_id):
self.layer_idx = layer_id
self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

self.c_attn = FastConv1D.load(config, prefix=f"{prefix}.c_attn", weights=weights)
self.c_proj = FastConv1D.load(config, prefix=f"{prefix}.c_proj", weights=weights)
self.c_attn = load_attention(config, prefix, weights, layer_id, [ATTN_C_ATTN], fan_in_fan_out=True)
self.c_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.c_proj",
weights=weights,
bias=True,
fan_in_fan_out=True,
), layer_id, ATTN_C_PROJ, process_group=weights.process_group)

self.pruned_heads = set()

@@ -115,8 +150,9 @@ def forward(
slots,
input_lengths,
max_s,
adapter_data
):
qkv = self.c_attn(hidden_states)
qkv = self.c_attn(hidden_states, adapter_data)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)

paged_attn.reshape_and_cache(
@@ -154,21 +190,49 @@ def forward(
)

attn_output = attn_output.view(-1, self.num_heads * self.head_size)
out = self.c_proj(attn_output)
out = self.c_proj(attn_output, adapter_data)
return out
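
The fan_in_fan_out=True arguments in the loaders above reflect GPT-2's use of Hugging Face Conv1D modules, which store their weight as (in_features, out_features), the transpose of nn.Linear's layout. A minimal sketch of that layout difference, with shapes chosen for GPT-2 small (illustrative, not taken from this commit):

    # Illustrative only -- not LoRAX code. Conv1D computes x @ W, while
    # nn.Linear computes x @ W.T, so a Conv1D checkpoint weight has to be
    # treated as transposed when loaded into a Linear-style kernel.
    import torch
    import torch.nn.functional as F

    w_conv1d = torch.randn(768, 2304)   # Conv1D layout: (n_embd, 3 * n_embd)
    w_linear = w_conv1d.T               # Linear layout: (out_features, in_features)
    x = torch.randn(4, 768)
    assert torch.allclose(x @ w_conv1d, F.linear(x, w_linear), atol=1e-5)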


class GPT2MLP(nn.Module):
def __init__(self, config, prefix, weights):
def __init__(self, config, prefix, weights, layer_id):
super().__init__()
self.c_fc = FastConv1D.load(config, prefix=f"{prefix}.c_fc", weights=weights)
self.c_proj = FastConv1D.load(config, prefix=f"{prefix}.c_proj", weights=weights)

c_fc = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.c_fc",
weights=weights,
bias=True,
fan_in_fan_out=True,
)
# https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Config.n_inner
n_inner = config.n_inner if config.n_inner is not None else config.n_embd * 4
self.c_fc = TensorParallelMultiAdapterLinear.load(
c_fc,
layer_id,
[MLP_C_FC],
sizes=[n_inner],
process_group=weights.process_group
)

self.c_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.c_proj",
weights=weights,
bias=True,
fan_in_fan_out=True,
), layer_id, MLP_C_PROJ, process_group=weights.process_group)

self.act = ACT2FN[config.activation_function]

def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
adapter_data: AdapterBatchData,
) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states, adapter_data)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
hidden_states = self.c_proj(hidden_states, adapter_data)
return hidden_states
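
The n_inner fallback in GPT2MLP.__init__ above follows the GPT2Config behavior linked in the comment: when n_inner is None, the MLP width defaults to 4 * n_embd. A quick worked check with GPT-2 small values (assumed for illustration):

    # Worked check of the n_inner fallback (GPT-2 small values assumed).
    n_embd, n_inner = 768, None
    mlp_width = n_inner if n_inner is not None else n_embd * 4
    assert mlp_width == 3072   # the value passed as sizes=[n_inner] to the c_fc wrapper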


@@ -189,7 +253,7 @@ def __init__(self, layer_id, config, weights):
prefix=f"{prefix}.ln_2", weights=weights, eps=layer_norm_eps
)

self.mlp = GPT2MLP(config, prefix=f"{prefix}.mlp", weights=weights)
self.mlp = GPT2MLP(config, prefix=f"{prefix}.mlp", weights=weights, layer_id=layer_id)
self.process_group = weights.process_group

def forward(
@@ -201,6 +265,7 @@ def forward(
slots,
input_lengths,
max_s,
adapter_data,
):
residual = hidden_states
hidden_states, _ = self.ln_1(hidden_states)
@@ -212,14 +277,15 @@ def forward(
slots,
input_lengths,
max_s,
adapter_data,
)

# residual connection
hidden_states = attn_outputs + residual

residual = hidden_states
hidden_states, _ = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states, adapter_data)
# residual connection
hidden_states = feed_forward_hidden_states + residual

@@ -243,7 +309,7 @@ def __init__(self, config, weights):
self.wte = TensorParallelEmbedding(prefix="wte", weights=weights)
self.wpe = TensorParallelEmbedding(prefix="wpe", weights=weights)

self.layers = nn.ModuleList(
self.h = nn.ModuleList(
[
GPT2Block(layer_id, config, weights)
for layer_id in range(config.num_hidden_layers)
@@ -257,9 +323,9 @@ def __init__(self, config, weights):

self.gradient_checkpointing = False

self.head_size = self.layers[0].attn.head_size
self.num_heads = self.layers[0].attn.num_heads
self.num_key_value_heads = self.layers[0].attn.num_key_value_heads
self.head_size = self.h[0].attn.head_size
self.num_heads = self.h[0].attn.num_heads
self.num_key_value_heads = self.h[0].attn.num_key_value_heads

def forward(
self,
@@ -271,12 +337,13 @@ def forward(
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
adapter_data: AdapterBatchData,
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds

for i, layer in enumerate(self.layers):
for i, layer in enumerate(self.h):
hidden_states = layer(
hidden_states,
cu_seqlen_prefill,
@@ -285,6 +352,7 @@ def forward(
slots,
input_lengths,
max_s,
adapter_data,
)

hidden_states, _ = self.ln_f(hidden_states)
@@ -294,7 +362,7 @@
class FlashGPT2ForCausalLM(FlashGPT2PreTrainedModel):
def __init__(self, config, weights):
super().__init__(config)
self.model = FlashGPT2Model(config, weights)
self.transformer = FlashGPT2Model(config, weights)

def forward(
self,
@@ -309,7 +377,7 @@
adapter_data: AdapterBatchData, # TODO: plumb this through
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
hidden_states = self.transformer(
input_ids,
position_ids,
cu_seqlen_prefill,
@@ -318,12 +386,13 @@
slots,
input_lengths,
max_s,
adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

# lm_head reuses the weights of the embedding layer
# https://github.com/huggingface/transformers/issues/6291
logits = hidden_states @ self.model.wte.weight.T
logits = logits[:, :self.model.config.vocab_size]
logits = hidden_states @ self.transformer.wte.weight.T
logits = logits[:, :self.transformer.config.vocab_size]
return logits
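
The weight tying noted in the comment above (lm_head reusing wte) matches stock GPT-2 in transformers, where the LM head and the token embedding share storage. A small sanity check, assuming the transformers package and the public "gpt2" checkpoint are available (not part of this commit):

    # Sanity check that GPT-2 ties lm_head to the token embedding, which is
    # why the modeling code projects hidden states with wte.weight.T directly.
    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    assert model.lm_head.weight.data_ptr() == model.transformer.wte.weight.data_ptr()
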
47 changes: 38 additions & 9 deletions server/lorax_server/models/flash_gpt2.py
@@ -6,12 +6,17 @@
from opentelemetry import trace
from transformers import AutoTokenizer, GPT2Model
from tqdm import tqdm
from typing import Dict, Optional
from typing import Dict, List, Optional, Tuple

from lorax_server.models import FlashCausalLM
from lorax_server.models.custom_modeling.flash_gpt2_modeling import (
FlashGPT2ForCausalLM,
GPT2Config,
ATTN_C_ATTN,
ATTN_C_PROJ,
MLP_C_FC,
MLP_C_PROJ,
LM_HEAD,
)
from lorax_server.utils import (
compute_delta_weight,
@@ -26,6 +31,9 @@

tracer = trace.get_tracer(__name__)

ADAPTER_LAYERS = [ATTN_C_ATTN, ATTN_C_PROJ, MLP_C_FC, MLP_C_PROJ]
ROW_PARALLEL = {ATTN_C_PROJ, MLP_C_PROJ}


class FlashGPT2(FlashCausalLM):
def __init__(
@@ -95,19 +103,40 @@ def __init__(
super(FlashGPT2, self).__init__(
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
num_layers=len(model.transformer.h),
num_kv_heads=model.transformer.num_key_value_heads,
head_size=model.transformer.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

def get_adaptable_weights(self):
# TODO: enable dynamic adapter loading in LoRAX
return {}


@property
def supports_adapter_loading(self) -> bool:
return True

def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}

prefix = "transformer.h"
for i, layer in enumerate(self.model.transformer.h):
layer_weights[(i, ATTN_C_ATTN)] = (f"{prefix}.{i}.{ATTN_C_ATTN}", layer.attn.c_attn)
layer_weights[(i, ATTN_C_PROJ)] = (f"{prefix}.{i}.{ATTN_C_PROJ}", layer.attn.c_proj)

layer_weights[(i, MLP_C_FC)] = (f"{prefix}.{i}.{MLP_C_FC}", layer.mlp.c_fc)
layer_weights[(i, MLP_C_PROJ)] = (f"{prefix}.{i}.{MLP_C_PROJ}", layer.mlp.c_proj)

# TODO: make Embedding layers adapter-compatible
# layer_weights[(0, LM_HEAD)] = ("transformer.wte", self.model.transformer.wte)
return layer_weights

@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS

def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == LM_HEAD else len(self.model.transformer.h)

def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL
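
adapter_target_to_layer above keys each adaptable module by (layer index, layer type) and pairs it with a weight name of the form transformer.h.{i}.{layer_type}, the naming GPT-2 LoRA checkpoints typically use. A rough illustration of the names this mapping produces, assuming a 12-layer GPT-2 (constants repeated here for the example, not re-imported):

    # Illustrative only -- not LoRAX code. Shows the weight-name prefixes the
    # (layer index, layer type) keys map to for a 12-layer GPT-2.
    ADAPTER_LAYERS = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"]
    n_layer = 12   # GPT-2 small (assumed)

    expected_keys = {
        (i, layer_type): f"transformer.h.{i}.{layer_type}"
        for i in range(n_layer)
        for layer_type in ADAPTER_LAYERS
    }
    print(expected_keys[(3, "attn.c_attn")])   # -> transformer.h.3.attn.c_attn
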
3 changes: 2 additions & 1 deletion server/lorax_server/utils/adapter.py
@@ -130,7 +130,8 @@ def merge_adapter_weights(
# transpose delta weight if necessary
# TODO(geoffrey): I believe this is required when using Conv1D layers (gpt2).
# We can likely take this out once we've switched to using Linear layers.
if delta_weight.T.shape == model_weights[weight_name].shape:
if (delta_weight.shape != model_weights[weight_name].shape and
delta_weight.T.shape == model_weights[weight_name].shape):
delta_weight = delta_weight.T
merged_weights[weight_name] = model_weights[weight_name] + delta_weight
return merged_weights, processed_adapter_weight_names
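
The extra shape check added above avoids spuriously transposing a delta that already matches the model weight (which the old condition could do for square matrices), while still transposing Linear-style LoRA deltas onto GPT-2's Conv1D-layout weights. A minimal sketch of that guard in isolation, with shapes assumed for illustration:

    # Illustrative only -- mirrors the guard added above, outside LoRAX.
    import torch

    def maybe_transpose(delta, model_weight):
        if (delta.shape != model_weight.shape and
                delta.T.shape == model_weight.shape):
            return delta.T
        return delta

    conv1d_weight = torch.zeros(768, 2304)   # GPT-2 Conv1D stores (in, out)
    lora_delta = torch.zeros(2304, 768)      # Linear-style delta is (out, in)
    assert maybe_transpose(lora_delta, conv1d_weight).shape == conv1d_weight.shape

    square = torch.zeros(768, 768)
    assert maybe_transpose(square, square) is square   # square deltas left untouched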