Optimize SGMV kernel code path to reduce mallocs (#139)
tgaddair authored Dec 18, 2023
1 parent af59e54 commit 5080877
Showing 3 changed files with 23 additions and 8 deletions.
9 changes: 7 additions & 2 deletions server/lorax_server/utils/layers.py
@@ -404,8 +404,12 @@ def forward_layer_type(
         end_idx: int,
     ) -> torch.Tensor:
         data = adapter_data.data.get(layer_type)
+
         if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
-            proj = torch.zeros_like(result[:, start_idx:end_idx])
+            if end_idx - start_idx != result.shape[1]:
+                proj = torch.zeros_like(result[:, start_idx:end_idx])
+            else:
+                proj = result

             for r, rank_segments in data.rank_data.items():
                 lora_a_ptr = rank_segments.lora_a_ptr
@@ -433,7 +437,8 @@ def forward_layer_type(
                         self.layer_id,
                     )

-            result[:, start_idx:end_idx] += proj
+            if end_idx - start_idx != result.shape[1]:
+                result[:, start_idx:end_idx] += proj
         else:
             for adapter_index in adapter_data.meta.adapter_set:
                 if data is not None and data.has_adapter(adapter_index):
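The layers.py change skips the scratch buffer whenever the adapter projection spans the full width of the output tensor. The sketch below is illustrative only (not the repository's code; the kernel call is elided and the names mirror the arguments above) and shows the same allocation-avoidance pattern:

# Minimal sketch of the full-width fast path: when [start_idx:end_idx] covers
# every column of `result`, accumulate directly into `result` instead of
# allocating a zero buffer and adding it back afterwards.
import torch

def accumulate_projection(result: torch.Tensor, start_idx: int, end_idx: int) -> torch.Tensor:
    full_width = (end_idx - start_idx) == result.shape[1]
    proj = result if full_width else torch.zeros_like(result[:, start_idx:end_idx])

    # ... the SGMV kernels would add the LoRA contribution into `proj` here ...

    if not full_width:
        result[:, start_idx:end_idx] += proj  # copy back only in the partial-slice case
    return result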
2 changes: 1 addition & 1 deletion server/lorax_server/utils/lora.py
@@ -85,7 +85,7 @@ def __init__(
     ):
         # [num_layers, hidden_size, r]
         weights_a = [
-            orient_for_rank(w, adapter_config.r)
+            orient_for_rank(w, adapter_config.r).contiguous()
             for w in weights_a
         ]
         self.weights_a = torch.stack(weights_a)
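The one-line lora.py change makes each LoRA A weight contiguous before it reaches the SGMV kernels, which work from raw data pointers and therefore expect a compact memory layout; orient_for_rank may return a transposed view, which is not contiguous. A standalone illustration (not repository code) of what .contiguous() does:

# A transposed view shares storage with the original tensor and has
# non-row-major strides; .contiguous() materializes a compact copy.
import torch

w = torch.randn(1024, 16)                 # e.g. [hidden_size, r]
view = w.transpose(0, 1)                  # stand-in for what orient_for_rank may return

print(view.is_contiguous())               # False
print(view.contiguous().is_contiguous())  # True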
20 changes: 15 additions & 5 deletions server/lorax_server/utils/sgmv.py
@@ -1,3 +1,4 @@
+from functools import lru_cache
 import os
 import warnings
 from typing import Tuple
@@ -84,6 +85,17 @@ def _add_lora_sgmv_cutlass_legacy(
     _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)


+@lru_cache(maxsize=1)
+def get_tmp_tensor(device: torch.device) -> torch.Tensor:
+    return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)
+
+
+@lru_cache(maxsize=1)
+def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
+    tmp_size = _kernels.sgmv_cutlass_tmp_size(size)
+    return torch.empty((tmp_size,), dtype=torch.uint8, device=device)
+
+
 def lora_a_sgmv_cutlass(
     x: torch.Tensor,
     wa_ptr: torch.Tensor,
@@ -94,13 +106,11 @@ def lora_a_sgmv_cutlass(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
     if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
-        tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
-        tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
-        tmp = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
+        tmp1 = get_tmp_tensor(x.device)
+        tmp = get_tmp_tensor_for_size(wa_ptr.size(0), x.device)
         _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
     else:
-        tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
-        tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
+        tmp = get_tmp_tensor_for_size(wa_ptr.size(0), x.device)
         _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
     return v, tmp
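The sgmv.py change moves the temporary workspace allocations behind functools.lru_cache helpers, so repeated kernel invocations reuse the same buffers instead of calling torch.empty each time; with maxsize=1, only the most recently used argument combination stays cached. A self-contained sketch of the caching behavior (made-up helper name and sizes, torch only, not repository code):

# lru_cache returns the same tensor object for repeated calls with identical
# arguments, so the allocation happens exactly once per cached entry.
from functools import lru_cache

import torch

@lru_cache(maxsize=1)
def get_workspace(num_bytes: int, device: str) -> torch.Tensor:
    return torch.empty((num_bytes,), dtype=torch.uint8, device=device)

a = get_workspace(8 * 1024 * 1024, "cpu")
b = get_workspace(8 * 1024 * 1024, "cpu")
assert a.data_ptr() == b.data_ptr()  # same buffer, no second allocation

c = get_workspace(1024, "cpu")       # a different size evicts the previous entry (maxsize=1)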
