Added Bloom dynamic adapter loading (#187)
tgaddair authored Jan 17, 2024
1 parent 76fcc5d commit fb3883c
Showing 24 changed files with 453 additions and 232 deletions.
56 changes: 54 additions & 2 deletions server/lorax_server/models/bloom.py
@@ -1,7 +1,7 @@
import torch
import torch.distributed

from typing import Optional, Type
from typing import Dict, List, Optional, Tuple, Type

from transformers import (
AutoTokenizer,
@@ -10,6 +10,11 @@
)

from lorax_server.models.custom_modeling.bloom_modeling import (
ATTN_DENSE,
ATTN_QKV,
LM_HEAD,
MLP_DENSE_4H_TO_H,
MLP_DENSE_H_TO_4H,
BloomForCausalLM,
)
from lorax_server.models import CausalLM
@@ -21,6 +26,10 @@
Weights,
)
from lorax_server.utils.tokenizer import TokenizerManager
from lorax_server.utils.lora import AdapterBatchData

ADAPTER_LAYERS = [ATTN_QKV, ATTN_DENSE, MLP_DENSE_H_TO_4H, MLP_DENSE_4H_TO_H]
ROW_PARALLEL = {ATTN_DENSE, MLP_DENSE_4H_TO_H}


class BloomCausalLMBatch(CausalLMBatch):
@@ -89,6 +98,7 @@ def __init__(

torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
@@ -98,20 +108,62 @@ def __init__(
world_size=world_size,
)

self.dynamic_adapter_loading_enabled = True


@property
def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch

@property
def has_adapter_data(self) -> bool:
return True

def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
self,
input_ids,
attention_mask,
position_ids,
past_key_values: Optional = None,
adapter_data: Optional[AdapterBatchData] = None
):
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
adapter_data=adapter_data,
)

logits = outputs.logits
return logits, outputs.past_key_values

@property
def supports_adapter_loading(self) -> bool:
return True

def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}

prefix = "transformer.h"
for i, layer in enumerate(self.model.transformer.h):
layer_weights[(i, ATTN_QKV)] = (f"{prefix}.{i}.self_attention.query_key_value", layer.self_attention.query_key_value)
layer_weights[(i, ATTN_DENSE)] = (f"{prefix}.{i}.self_attention.dense", layer.self_attention.dense)

layer_weights[(i, MLP_DENSE_H_TO_4H)] = (f"{prefix}.{i}.mlp.dense_h_to_4h", layer.mlp.dense_h_to_4h)
layer_weights[(i, MLP_DENSE_4H_TO_H)] = (f"{prefix}.{i}.mlp.dense_4h_to_h", layer.mlp.dense_4h_to_h)

# TODO: make Embedding layers adapter-compatible
# layer_weights[(0, LM_HEAD)] = ("transformer.wte", self.model.transformer.wte)
return layer_weights

@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS

def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == LM_HEAD else len(self.model.transformer.h)

def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL
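Note on the hooks added above: adapter_target_to_layer, adapter_layers, get_num_layers_for_type, and is_row_parallel together describe every module a LoRA adapter can attach to in Bloom. As an illustrative sketch only (the iter_adapter_targets helper below is hypothetical and not part of this commit), they could be combined like this to enumerate the adapter targets:

def iter_adapter_targets(model):
    # `model` is an instance of the Bloom model class defined in this file,
    # exposing the hooks added in this commit.
    layer_weights = model.adapter_target_to_layer()
    for layer_type in model.adapter_layers:  # ATTN_QKV, ATTN_DENSE, MLP_* layers
        for layer_id in range(model.get_num_layers_for_type(layer_type)):
            weight_name, module = layer_weights[(layer_id, layer_type)]
            yield layer_type, layer_id, weight_name, module, model.is_row_parallel(layer_type)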
86 changes: 77 additions & 9 deletions server/lorax_server/models/causal_lm.py
@@ -1,3 +1,4 @@
from collections import defaultdict
import json
import torch
import inspect
@@ -17,6 +18,8 @@
from lorax_server.pb import generate_pb2
from lorax_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from lorax_server.utils.tokenizer import TokenizerManager
from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights
from lorax_server.utils.segments import SegmentConcatBuilder, find_segments

tracer = trace.get_tracer(__name__)

@@ -46,7 +49,7 @@ class CausalLMBatch(Batch):
stopping_criterias: List[StoppingCriteria]

# Adapter metadata for each request
adapter_indices: torch.Tensor
adapter_meta: AdapterBatchMetadata

# Metadata used for padding
max_input_length: int
@@ -87,6 +90,7 @@ def from_pb(
padding_right_offset = 0
max_decode_tokens = 0
adapter_indices_list = []
adapter_set = set()
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
req_inputs = tokenizers.get_inputs(r, tokenizer)
@@ -102,6 +106,7 @@
padding_right_offset, stopping_criteria.max_new_tokens
)
adapter_indices_list.append(r.adapter_index)
adapter_set.add(r.adapter_index)

adapter_indices = torch.tensor(adapter_indices_list, dtype=torch.int64, device=device)

@@ -135,6 +140,9 @@ def from_pb(

max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device)

return cls(
batch_id=pb.id,
requests=pb.requests,
@@ -152,7 +160,12 @@
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
adapter_indices=adapter_indices,
adapter_meta=AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
)

@tracer.start_as_current_span("filter")
@@ -173,7 +186,7 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
all_input_ids = []
max_input_length = 0

# TODO(travis): adapter indices
adapter_set = set()

next_token_choosers = []
stopping_criterias = []
@@ -206,9 +219,12 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
new_padding_right_offset, remaining_decode_tokens
)

adapter_set.add(self.requests[idx].adapter_index)

# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
input_ids = self.input_ids[keep_indices]
position_ids = self.position_ids[keep_indices]
adapter_indices = self.adapter_meta.adapter_indices[keep_indices]
self.attention_mask = self.attention_mask[
keep_indices,
-(self.padding_right_offset + max_input_length) : (
@@ -239,6 +255,10 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:

max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

device = self.input_ids.device
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32, device=device)

self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
self.input_ids = input_ids
@@ -252,6 +272,12 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
self.adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
)

return self

@@ -285,6 +311,12 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
past_key_values = []
adapter_indices = None

total_indices_size = sum(b.adapter_meta.adapter_indices.shape[0] for b in batches)
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(total_indices_size)
adapter_set = set()
adapter_segment_builder = SegmentConcatBuilder()
cumulative_adapter_indices_size = 0

# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
@@ -319,10 +351,15 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
# Copy to correct indices
input_ids[start_index:end_index] = batch.input_ids

# Create adapter indices
if adapter_indices is None:
adapter_indices = batch.adapter_indices.new_empty((total_batch_size,))
adapter_indices[start_index:end_index] = batch.adapter_indices
# Copy over adapter indices
adapter_start_index = cumulative_adapter_indices_size
adapter_end_index = cumulative_adapter_indices_size + batch.adapter_meta.adapter_indices.shape[0]
adapter_indices[adapter_start_index:adapter_end_index] = batch.adapter_meta.adapter_indices
cumulative_adapter_indices_size = adapter_end_index
adapter_set.update(batch.adapter_meta.adapter_set)

# Update adapter segments
adapter_segment_builder.concat(batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices)

# Create padded tensor
if attention_mask is None:
@@ -444,6 +481,8 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":

past_key_values.append([padded_past_keys, padded_past_values])

adapter_segments, adapter_segment_indices = adapter_segment_builder.build()

return cls(
batch_id=batches[0].batch_id,
requests=requests,
@@ -462,7 +501,12 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens,
adapter_indices=adapter_indices,
adapter_meta=AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
)

def __len__(self):
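The concatenate path above folds each batch's adapter segments through SegmentConcatBuilder, whose implementation lives in lorax_server.utils.segments and is not shown in this diff. Below is a hypothetical sketch of the behavior its call sites imply, namely shifting each later batch's boundary offsets by the rows accumulated so far; the real class may differ, for example by also merging adjacent segments that share an adapter index:

from typing import List, Tuple

import torch


class SegmentConcatBuilderSketch:
    def __init__(self) -> None:
        self.offsets: List[int] = [0]   # running segment boundary offsets
        self.indices: List[int] = []    # adapter index owning each segment

    def concat(self, segments: torch.Tensor, segment_indices: List[int]) -> None:
        # `segments` starts at 0 for its own batch; drop that leading 0 and
        # shift the remaining boundaries by the rows already accumulated.
        base = self.offsets[-1]
        self.offsets.extend((segments[1:] + base).tolist())
        self.indices.extend(segment_indices)

    def build(self) -> Tuple[torch.Tensor, List[int]]:
        return torch.tensor(self.offsets, dtype=torch.int32), self.indices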
@@ -523,24 +567,36 @@ def __init__(
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)

self.dynamic_adapter_loading_enabled = False

@property
def batch_type(self) -> Type[CausalLMBatch]:
return CausalLMBatch

@property
def has_adapter_data(self) -> bool:
return False

def decode(self, generated_ids: List[int]) -> str:
return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)

def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
self,
input_ids,
attention_mask,
position_ids,
past_key_values: Optional = None,
adapter_data: Optional[AdapterBatchData] = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
kwargs = {
@@ -552,6 +608,8 @@ def forward(
}
if self.has_position_ids:
kwargs["position_ids"] = position_ids
if self.has_adapter_data:
kwargs["adapter_data"] = adapter_data

outputs = self.model.forward(**kwargs)
return outputs.logits, outputs.past_key_values
@@ -563,11 +621,16 @@ def generate_token(
# slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

# Assign pointers to LoRA weights
# TODO(travis): don't update this if indices haven't changed
adapter_data = AdapterBatchData.from_meta(batch.adapter_meta, self.batched_lora_weights)

logits, past = self.forward(
batch.input_ids,
attention_mask,
batch.position_ids,
batch.past_key_values,
adapter_data,
)

# Results
@@ -586,6 +649,8 @@ def generate_token(
batch.all_input_ids,
)

next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(len(batch))

# For each member of the batch
for i, (
request,
@@ -684,6 +749,7 @@ def generate_token(
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.max_input_length = max(batch.max_input_length, new_input_length)
next_adapter_indices[i] = request.adapter_index

# We finished all generations in the batch; there is no next batch
if stopped:
@@ -703,4 +769,6 @@ def generate_token(
# Update past key values
batch.past_key_values = past

batch.adapter_meta.adapter_indices = next_adapter_indices

return generations, batch
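Both from_pb and filter above compress the per-request adapter_indices into segments via find_segments from lorax_server.utils.segments, which this diff does not show. A minimal sketch of what such a helper plausibly returns, assuming contiguous-run boundary offsets plus the adapter index owning each run (the actual implementation may differ):

from typing import List, Tuple

import torch


def find_segments_sketch(adapter_indices: torch.Tensor) -> Tuple[List[int], List[int]]:
    # Example: adapter indices [0, 0, 1, 1, 1, 0] yield
    # segments [0, 2, 5, 6] and segment_indices [0, 1, 0].
    indices = adapter_indices.tolist()
    segments = [0]        # segment i covers rows [segments[i], segments[i + 1])
    segment_indices = []  # adapter index for each segment
    current = None
    for i, adapter_idx in enumerate(indices):
        if adapter_idx != current:
            if current is not None:
                segments.append(i)
                segment_indices.append(current)
            current = adapter_idx
    if current is not None:
        segments.append(len(indices))
        segment_indices.append(current)
    return segments, segment_indices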