
Commit

Merge branch 'main' into medusa
flozi00 authored Dec 19, 2023
2 parents 505ad9d + 9ae65b3 commit a1d01c6
Showing 4 changed files with 28 additions and 10 deletions.
4 changes: 2 additions & 2 deletions server/lorax_server/utils/gptq/custom_autotune.py
@@ -88,9 +88,9 @@ def kernel_call():
             # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses
             # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
             return triton.testing.do_bench(
-                kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40
+                kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40
             )
-        except triton.compiler.OutOfResources:
+        except triton.OutOfResources:
             return (float("inf"), float("inf"), float("inf"))

     def run(self, *args, **kwargs):
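Note on the hunk above: newer Triton releases rename the percentiles argument of triton.testing.do_bench to quantiles and expose OutOfResources at the top level rather than under triton.compiler. A minimal sketch of the updated benchmarking pattern, assuming a Triton build with the new API and a CUDA device; bench_or_inf and dummy_workload are illustrative names, not part of the repository:

import torch
import triton


def bench_or_inf(kernel_call):
    try:
        # Returns the 50th/20th/80th percentile timings (ms) over 40 repetitions.
        return triton.testing.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40)
    except triton.OutOfResources:
        # Configs that exceed shared memory or register limits are scored as
        # infinitely slow so the autotuner discards them.
        return (float("inf"), float("inf"), float("inf"))


def dummy_workload():
    # Stand-in for a real Triton kernel launch.
    torch.empty(1 << 20, device="cuda").mul_(2)


print(bench_or_inf(dummy_workload))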
9 changes: 7 additions & 2 deletions server/lorax_server/utils/layers.py
@@ -432,8 +432,12 @@ def forward_layer_type(
         end_idx: int,
     ) -> torch.Tensor:
         data = adapter_data.data.get(layer_type)

         if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
-            proj = torch.zeros_like(result[:, start_idx:end_idx])
+            if end_idx - start_idx != result.shape[1]:
+                proj = torch.zeros_like(result[:, start_idx:end_idx])
+            else:
+                proj = result

             for r, rank_segments in data.rank_data.items():
                 lora_a_ptr = rank_segments.lora_a_ptr
@@ -461,7 +465,8 @@ def forward_layer_type(
                         self.layer_id,
                     )

-            result[:, start_idx:end_idx] += proj
+            if end_idx - start_idx != result.shape[1]:
+                result[:, start_idx:end_idx] += proj
         else:
             for adapter_index in adapter_data.meta.adapter_set:
                 if data is not None and data.has_adapter(adapter_index):
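The layers.py change avoids a round trip when the LoRA projection spans the full width of result: previously a zeroed buffer was always allocated and then added back into result, whereas now the SGMV kernels accumulate directly into result and the copy-back only happens when a column slice is targeted. A standalone sketch of that buffer-selection pattern, with illustrative names rather than the lorax_server API:

import torch


def accumulation_buffer(result: torch.Tensor, start_idx: int, end_idx: int) -> torch.Tensor:
    # Full-width case: let the kernel write straight into `result`, skipping the
    # extra allocation and the later copy-back.
    if end_idx - start_idx != result.shape[1]:
        return torch.zeros_like(result[:, start_idx:end_idx])
    return result


result = torch.randn(4, 16)
proj = accumulation_buffer(result, 0, 16)
assert proj is result          # no scratch buffer needed
proj = accumulation_buffer(result, 0, 8)
assert proj.shape == (4, 8)    # sliced case still uses a zeroed scratch buffer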
5 changes: 4 additions & 1 deletion server/lorax_server/utils/lora.py
@@ -85,7 +85,10 @@ def __init__(
         adapter_config: LoraConfig,
     ):
         # [num_layers, hidden_size, r]
-        weights_a = [orient_for_rank(w, adapter_config.r) for w in weights_a]
+        weights_a = [
+            orient_for_rank(w, adapter_config.r).contiguous()
+            for w in weights_a
+        ]
         self.weights_a = torch.stack(weights_a)

         # [num_layers, r, hidden_size]
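The only change to lora.py appends .contiguous() to each oriented weight. If orient_for_rank reorients a tensor by transposing it, the result is a strided view rather than densely packed memory, which pointer-based CUDA kernels generally cannot consume safely; .contiguous() materializes a packed copy. A hypothetical illustration of that effect (orient_for_rank_stub only mimics the transpose and is not the real helper):

import torch


def orient_for_rank_stub(w: torch.Tensor, rank: int) -> torch.Tensor:
    # Hypothetical stand-in: flip to [hidden_size, r] when the rank dim comes first.
    return w.transpose(0, 1) if w.shape[0] == rank else w


w = torch.randn(8, 64)                        # [r, hidden_size]
oriented = orient_for_rank_stub(w, 8)         # [hidden_size, r], same storage
print(oriented.is_contiguous())               # False: strided view
print(oriented.contiguous().is_contiguous())  # True: densely packed copy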
20 changes: 15 additions & 5 deletions server/lorax_server/utils/sgmv.py
@@ -1,3 +1,4 @@
+from functools import lru_cache
 import os
 import warnings
 from typing import Tuple
@@ -87,6 +88,17 @@ def _add_lora_sgmv_cutlass_legacy(
     _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)


+@lru_cache(maxsize=1)
+def get_tmp_tensor(device: torch.device) -> torch.Tensor:
+    return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)
+
+
+@lru_cache(maxsize=1)
+def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
+    tmp_size = _kernels.sgmv_cutlass_tmp_size(size)
+    return torch.empty((tmp_size,), dtype=torch.uint8, device=device)
+
+
 def lora_a_sgmv_cutlass(
     x: torch.Tensor,
     wa_ptr: torch.Tensor,
@@ -97,13 +109,11 @@ def lora_a_sgmv_cutlass(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
     if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
-        tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
-        tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
-        tmp = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
+        tmp1 = get_tmp_tensor(x.device)
+        tmp = get_tmp_tensor_for_size(wa_ptr.size(0), x.device)
         _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
     else:
-        tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
-        tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
+        tmp = get_tmp_tensor_for_size(wa_ptr.size(0), x.device)
         _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
     return v, tmp

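The sgmv.py change moves the scratch allocations behind two lru_cache helpers, so the 8 MiB shrink buffer and the cutlass temporary are allocated once and reused across calls instead of being re-created on every invocation. With maxsize=1 each helper keeps only the most recently requested size/device combination alive. A standalone sketch of the same caching pattern, using a plain byte count in place of _kernels.sgmv_cutlass_tmp_size:

from functools import lru_cache

import torch


@lru_cache(maxsize=1)
def cached_scratch(num_bytes: int, device: torch.device) -> torch.Tensor:
    # Allocated on the first call; identical (num_bytes, device) arguments
    # return the same tensor instead of allocating again.
    return torch.empty((num_bytes,), dtype=torch.uint8, device=device)


dev = torch.device("cpu")  # the real path uses the CUDA device of the inputs
a = cached_scratch(8 * 1024 * 1024, dev)
b = cached_scratch(8 * 1024 * 1024, dev)
assert a is b  # cache hit: no second allocation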
