Use SGMV for prefill, BGMV for decode (#464)
tgaddair authored May 14, 2024
1 parent a7e8175 commit 7306d49
Showing 24 changed files with 707 additions and 438 deletions.
4 changes: 3 additions & 1 deletion .vscode/settings.json
@@ -48,6 +48,8 @@
"numeric": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"atomic": "cpp"
"atomic": "cpp",
"__locale": "cpp",
"ios": "cpp"
}
}
145 changes: 118 additions & 27 deletions server/lorax_server/adapters/lora.py
@@ -9,7 +9,14 @@
from lorax_server.adapters.config import AdapterConfig, ModuleMap
from lorax_server.adapters.types import LORA
from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights
from lorax_server.utils.sgmv import MAX_RANK_CUSTOM, get_tmp_tensors, orient_for_rank, pad_rank
from lorax_server.utils.sgmv import (
BGMV_MAX_RANK,
MAX_RANK_CUSTOM,
get_tmp_tensors,
orient_for_rank,
pad_rank,
use_cutlass_shrink,
)

if TYPE_CHECKING:
from lorax_server.models.model import Model
@@ -87,15 +94,49 @@ def __init__(
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1

self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
self._is_transposed = False

# [num_layers, hidden_size, r]
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
self.weights_a = torch.stack(weights_a)
self._weights_a = torch.stack(weights_a)

# [num_layers, r, hidden_size]
self.weights_b = torch.stack(weights_b)
self._weights_b = torch.stack(weights_b)

self.adapter_config = adapter_config

@property
def weights_a(self) -> torch.Tensor:
if self._is_transposed:
self._transpose_weights()
return self._weights_a

@property
def weights_b(self) -> torch.Tensor:
if self._is_transposed:
self._transpose_weights()
return self._weights_b

@property
def weights_a_t(self) -> torch.Tensor:
if not self._is_transposed:
self._transpose_weights()
return self._weights_a

@property
def weights_b_t(self) -> torch.Tensor:
if not self._is_transposed:
self._transpose_weights()
return self._weights_b

def _transpose_weights(self):
if self._use_cutlass_shrink:
# weights_a only needs transposing when the cutlass shrink kernel is used; otherwise SGMV and BGMV already share the same orientation for A
self._weights_a = self._weights_a.transpose(1, 2).contiguous()
self._weights_b = self._weights_b.transpose(1, 2).contiguous()
self._is_transposed = not self._is_transposed

@classmethod
def get_batch_type(cls) -> BatchAdapterWeights:
return BatchLoraWeights
@@ -162,20 +203,27 @@ def load(
@dataclass
class RankSegments:
rank: int
tmp_shrink: torch.Tensor
tmp_expand: torch.Tensor

lora_a_ptr: torch.Tensor
lora_b_ptr: torch.Tensor

# prefill (sgmv)
tmp_shrink: torch.Tensor
tmp_expand: torch.Tensor
segment_starts: torch.Tensor
segment_ends: torch.Tensor

# decode (bgmv)
indices: torch.Tensor


@dataclass
class BatchLoraWeights(BatchAdapterWeights):
lora_a: Dict[int, torch.Tensor]
lora_b: Dict[int, torch.Tensor]
adapter_index_configs: Dict[int, LoraConfig]
rank_data: Dict[int, RankSegments]
use_sgmv: bool

def has_adapter(self, adapter_index: int) -> bool:
return adapter_index in self.adapter_index_configs
@@ -188,36 +236,63 @@ def key(cls) -> str:
return LORA

@classmethod
def load(self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMetadata) -> "BatchLoraWeights":
def load(
self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMetadata, prefill: bool
) -> "BatchLoraWeights":
adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)}

first_weights = list(adapter_weights.values())[0]
device = first_weights.weights_a.device
segment_indices = meta.segment_indices

lora_a = {idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights}
lora_a_ptr = torch.tensor(
[
(adapter_weights[idx].weights_a.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b = {idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights}
lora_b_ptr = torch.tensor(
[
(adapter_weights[idx].weights_b.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)

max_rank = max(adapter_weights[idx].lora_a_r for idx in segment_indices if idx in adapter_weights)

if prefill or max_rank > BGMV_MAX_RANK:
use_sgmv = True
lora_a_ptr = torch.tensor(
[
(adapter_weights[idx].weights_a.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(adapter_weights[idx].weights_b.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
else:
use_sgmv = False
lora_a_ptr = torch.tensor(
[
(adapter_weights[idx].weights_a_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(adapter_weights[idx].weights_b_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)

adapter_index_configs = {
idx: adapter_weights[idx].adapter_config for idx in segment_indices if idx in adapter_weights
}

adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}

rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
if adapter_idx not in adapter_weights:
@@ -226,24 +301,40 @@ def load(self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMet

rank_data = {}
for rank, indices in rank_indices.items():
lora_a_ptr_indices = lora_a_ptr[indices]
tmp_shrink, tmp_expand = get_tmp_tensors(lora_a_ptr_indices.size(0), rank, device)
tmp_shrink = None
tmp_expand = None
segment_starts = None
segment_ends = None
batch_indices = None

if use_sgmv:
lora_a_ptr_indices = lora_a_ptr[indices]
tmp_shrink, tmp_expand = get_tmp_tensors(lora_a_ptr_indices.size(0), rank, device)
segment_starts = meta.adapter_segments[indices]
segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
else:
# map each token to its adapter segment; -1 marks tokens not handled at this rank
rank_segment_ids = set(indices)
batch_indices = [adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()]
batch_indices = [idx if idx in rank_segment_ids else -1 for idx in batch_indices]
batch_indices = torch.tensor(batch_indices, dtype=torch.int64, device=device)

rank_data[rank] = RankSegments(
rank=rank,
tmp_shrink=tmp_shrink,
tmp_expand=tmp_expand,
lora_a_ptr=lora_a_ptr_indices,
lora_a_ptr=lora_a_ptr[indices],
lora_b_ptr=lora_b_ptr[indices],
segment_starts=meta.adapter_segments[indices],
segment_ends=meta.adapter_segments[[i + 1 for i in indices]],
segment_starts=segment_starts,
segment_ends=segment_ends,
indices=batch_indices,
)

return BatchLoraWeights(
lora_a=lora_a,
lora_b=lora_b,
adapter_index_configs=adapter_index_configs,
rank_data=rank_data,
use_sgmv=use_sgmv,
)


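As a reading aid for the hunk above, here is a minimal, self-contained sketch of the kernel-input selection it introduces. The function name pick_kernel_inputs, its parameter names, and the BGMV_MAX_RANK value are illustrative assumptions, not part of the repository; the branch simply mirrors BatchLoraWeights.load, where SGMV (prefill) consumes contiguous [start, end) token segments per adapter and BGMV (decode) consumes one entry per token, with -1 marking tokens owned by a different rank bucket.

import torch

BGMV_MAX_RANK = 64  # assumed value for illustration; the real constant comes from lorax_server.utils.sgmv

def pick_kernel_inputs(prefill, max_rank, adapter_segments, adapter_indices, rank_segments, adapter_to_segment):
    # Illustrative mirror of the SGMV/BGMV branch in BatchLoraWeights.load, not the actual implementation.
    use_sgmv = prefill or max_rank > BGMV_MAX_RANK
    if use_sgmv:
        # SGMV: contiguous [start, end) token ranges, one per adapter segment handled at this rank
        starts = adapter_segments[rank_segments]
        ends = adapter_segments[[i + 1 for i in rank_segments]]
        return use_sgmv, starts, ends, None
    # BGMV: one entry per token; -1 means the token belongs to a different rank bucket
    owned = set(rank_segments)
    per_token = [adapter_to_segment[a] for a in adapter_indices.tolist()]
    per_token = [s if s in owned else -1 for s in per_token]
    return use_sgmv, None, None, torch.tensor(per_token, dtype=torch.int64)

# Toy decode-time call: adapter 0 covers tokens 0-1, adapter 1 covers token 2, both at rank 16
segments = torch.tensor([0, 2, 3])
indices = torch.tensor([0, 0, 1])
print(pick_kernel_inputs(False, 16, segments, indices, [0, 1], {0: 0, 1: 1}))
# -> (False, None, None, tensor([0, 0, 1]))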
4 changes: 3 additions & 1 deletion server/lorax_server/adapters/medusa.py
@@ -269,7 +269,9 @@ def __call__(self, x, lm_head):
return lm_head(x)

@classmethod
def load(cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata") -> "BatchMedusaWeights":
def load(
cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata", prefill: bool
) -> "BatchMedusaWeights":
adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, MedusaWeights)}
default_medusa = adapter_weights.get(0)

40 changes: 30 additions & 10 deletions server/lorax_server/adapters/weights.py
@@ -45,7 +45,9 @@ def key(cls) -> str:
pass

@abstractclassmethod
def load(cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata") -> "BatchAdapterWeights":
def load(
cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata", prefill: bool
) -> "BatchAdapterWeights":
pass


@@ -70,15 +72,15 @@ def max_speculative_tokens(self) -> int:
def is_empty(self) -> bool:
return len(self.adapter_weights) == 0

def get_data(self, meta: AdapterBatchMetadata) -> Dict[str, BatchAdapterWeights]:
def get_data(self, meta: AdapterBatchMetadata, prefill: bool) -> Dict[str, BatchAdapterWeights]:
# bucket adapters by batch class
adapter_batch_types: Dict[Type[BatchAdapterWeights], Dict[int, AdapterWeights]] = defaultdict(dict)
for adapter_index, adapter_weights in self.adapter_weights.items():
adapter_batch_types[adapter_weights.get_batch_type()][adapter_index] = adapter_weights

batch_data = {}
for batch_type, adapter_weights in adapter_batch_types.items():
batch_data[batch_type.key()] = batch_type.load(adapter_weights, meta)
batch_data[batch_type.key()] = batch_type.load(adapter_weights, meta, prefill)
return batch_data


@@ -89,22 +91,40 @@ class AdapterBatchData:
# layer type -> adapter type -> batch weight data
data: Dict[str, Dict[str, BatchAdapterWeights]]

prefill: bool

@staticmethod
def from_meta(meta: AdapterBatchMetadata, weights: Dict[str, LayerAdapterWeights]) -> "AdapterBatchData":
def from_meta(
meta: AdapterBatchMetadata, weights: Dict[str, LayerAdapterWeights], prefill: bool
) -> "AdapterBatchData":
data = {}
for k, v in weights.items():
if v.is_empty():
continue
data[k] = v.get_data(meta)
return AdapterBatchData(meta=meta, data=data)
data[k] = v.get_data(meta, prefill)
return AdapterBatchData(meta=meta, data=data, prefill=prefill)

def ranks(self) -> Set[int]:
# TODO(travis): refactor to be less coupled to lora implementation
lora_data = self.data.get(LORA)
if lora_data is None:
return set()
ranks = set()
for layer_data in self.data.values():
lora_data = layer_data.get(LORA)
if lora_data is None:
continue

for rank_data in lora_data.rank_data.values():
ranks.add(rank_data.rank)

return ranks

def layer_names(self) -> Set[str]:
return set(self.data.keys())

return set(rank_data.rank for layer_data in self.data.values() for rank_data in lora_data.rank_data.values())
def adapter_keys(self) -> Set[str]:
adapter_keys = set()
for layer_data in self.data.values():
adapter_keys.update(layer_data.keys())
return adapter_keys

@property
def max_rank(self) -> int:
3 changes: 2 additions & 1 deletion server/lorax_server/models/causal_lm.py
@@ -588,7 +588,8 @@ def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Option

# Assign pointers to LoRA weights
# TODO(travis): don't update this if indices haven't changed
adapter_data = AdapterBatchData.from_meta(batch.adapter_meta, self.batched_lora_weights)
# Use prefill=True in all cases to force use of SGMV, as the batch is heterogeneous
adapter_data = AdapterBatchData.from_meta(batch.adapter_meta, self.layer_to_adapter_weights, prefill=True)

logits, past = self.forward(
batch.input_ids,
3 changes: 2 additions & 1 deletion server/lorax_server/models/flash_causal_lm.py
@@ -796,6 +796,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int):
self.model,
self.device,
self.adapter_layers,
self.default_traced_adapter_layers,
max_total_tokens,
self.sliding_window_blocks,
)
@@ -951,7 +952,7 @@ def generate_token(

# Assign pointers to adapter weights
# TODO(travis): don't update this if indices haven't changed
adapter_data = AdapterBatchData.from_meta(adapter_meta, self.batched_lora_weights)
adapter_data = AdapterBatchData.from_meta(adapter_meta, self.layer_to_adapter_weights, prefill)

out, speculative_logits = self._try_generate_token(batch, adapter_data)

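For orientation, a hedged usage sketch of how the new prefill flag reaches AdapterBatchData.from_meta at the two call sites above; the wrapper name build_adapter_data is illustrative, and the import path is assumed to follow the file layout shown in this diff.

from lorax_server.adapters.weights import AdapterBatchData

def build_adapter_data(batch, layer_to_adapter_weights, prefill: bool):
    # prefill=True forces SGMV for every rank bucket (causal_lm.py always does this because
    # its batches are heterogeneous); prefill=False lets a bucket fall back to BGMV whenever
    # its rank fits within BGMV_MAX_RANK.
    return AdapterBatchData.from_meta(batch.adapter_meta, layer_to_adapter_weights, prefill)

Downstream, the adapter_data.prefill field and the per-bucket use_sgmv flag on BatchLoraWeights presumably let the LoRA layers dispatch the matching kernel.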
4 changes: 4 additions & 0 deletions server/lorax_server/models/flash_cohere.py
@@ -123,6 +123,10 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS

@property
def default_traced_adapter_layers(self) -> List[str]:
return [Q_PROJ, V_PROJ]

def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == LM_HEAD else len(self.model.model.layers)

4 changes: 4 additions & 0 deletions server/lorax_server/models/flash_dbrx.py
@@ -137,6 +137,10 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS

@property
def default_traced_adapter_layers(self) -> List[str]:
return [ATTN_WQKV]

def get_num_layers_for_type(self, layer_type: str) -> int:
return len(self.model.model.layers)

4 changes: 4 additions & 0 deletions server/lorax_server/models/flash_gemma.py
@@ -118,6 +118,10 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS

@property
def default_traced_adapter_layers(self) -> List[str]:
return [Q_PROJ, V_PROJ]

def get_num_layers_for_type(self, layer_type: str) -> int:
return len(self.model.model.layers)

4 changes: 4 additions & 0 deletions server/lorax_server/models/flash_gpt2.py
@@ -110,6 +110,10 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS

@property
def default_traced_adapter_layers(self) -> List[str]:
return [ATTN_C_ATTN]

def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == LM_HEAD else len(self.model.transformer.h)

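The same two-property pattern repeats across the flash models in this commit; a minimal sketch for a hypothetical model follows (the class name and constants are illustrative, and the role of the default list during warmup is inferred from the flash_causal_lm.py warmup change above).

from typing import List

Q_PROJ = "q_proj"  # illustrative layer-name constants
V_PROJ = "v_proj"

class FlashExampleModel:  # hypothetical model, standing in for flash_cohere.py, flash_gpt2.py, etc.
    @property
    def adapter_layers(self) -> List[str]:
        # every layer type an adapter is allowed to target
        return [Q_PROJ, "k_proj", V_PROJ, "o_proj"]

    @property
    def default_traced_adapter_layers(self) -> List[str]:
        # smaller subset handed to warmup, presumably the layers traced by default
        return [Q_PROJ, V_PROJ]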
(Remaining changed files not shown.)
