diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py
index 4c883c214..35d39689d 100644
--- a/server/lorax_server/utils/layers.py
+++ b/server/lorax_server/utils/layers.py
@@ -404,8 +404,12 @@ def forward_layer_type(
         end_idx: int,
     ) -> torch.Tensor:
         data = adapter_data.data.get(layer_type)
+
         if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
-            proj = torch.zeros_like(result[:, start_idx:end_idx])
+            if end_idx - start_idx != result.shape[1]:
+                proj = torch.zeros_like(result[:, start_idx:end_idx])
+            else:
+                proj = result
 
             for r, rank_segments in data.rank_data.items():
                 lora_a_ptr = rank_segments.lora_a_ptr
@@ -433,7 +437,8 @@ def forward_layer_type(
                         self.layer_id,
                     )
 
-            result[:, start_idx:end_idx] += proj
+            if end_idx - start_idx != result.shape[1]:
+                result[:, start_idx:end_idx] += proj
         else:
             for adapter_index in adapter_data.meta.adapter_set:
                 if data is not None and data.has_adapter(adapter_index):
diff --git a/server/lorax_server/utils/lora.py b/server/lorax_server/utils/lora.py
index 05b2e1bc3..72eea7e2e 100644
--- a/server/lorax_server/utils/lora.py
+++ b/server/lorax_server/utils/lora.py
@@ -85,7 +85,7 @@ def __init__(
     ):
         # [num_layers, hidden_size, r]
         weights_a = [
-            orient_for_rank(w, adapter_config.r)
+            orient_for_rank(w, adapter_config.r).contiguous()
             for w in weights_a
         ]
         self.weights_a = torch.stack(weights_a)
diff --git a/server/lorax_server/utils/sgmv.py b/server/lorax_server/utils/sgmv.py
index 99853cc48..874ca20a8 100644
--- a/server/lorax_server/utils/sgmv.py
+++ b/server/lorax_server/utils/sgmv.py
@@ -1,3 +1,4 @@
+from functools import lru_cache
 import os
 import warnings
 from typing import Tuple
@@ -84,6 +85,17 @@ def _add_lora_sgmv_cutlass_legacy(
     _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
 
 
+@lru_cache(maxsize=1)
+def get_tmp_tensor(device: torch.device) -> torch.Tensor:
+    return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)
+
+
+@lru_cache(maxsize=1)
+def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
+    tmp_size = _kernels.sgmv_cutlass_tmp_size(size)
+    return torch.empty((tmp_size,), dtype=torch.uint8, device=device)
+
+
 def lora_a_sgmv_cutlass(
     x: torch.Tensor,
     wa_ptr: torch.Tensor,
@@ -94,13 +106,11 @@ def lora_a_sgmv_cutlass(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
     if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
-        tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
-        tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
-        tmp = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
+        tmp1 = get_tmp_tensor(x.device)
+        tmp = get_tmp_tensor_for_size(wa_ptr.size(0), x.device)
         _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
     else:
-        tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
-        tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
+        tmp = get_tmp_tensor_for_size(wa_ptr.size(0), x.device)
         _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
     return v, tmp
 
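
For context on the sgmv.py change: lru_cache memoizes the scratch-tensor allocation on its arguments, so repeated calls with the same size and device return one cached buffer instead of allocating a fresh tensor on every invocation. Below is a minimal, standalone sketch of that pattern, not part of the diff; get_scratch_buffer is a hypothetical name, and the kernel-specific sizing (_kernels.sgmv_cutlass_tmp_size) is omitted.

from functools import lru_cache

import torch


@lru_cache(maxsize=1)
def get_scratch_buffer(num_bytes: int, device: torch.device) -> torch.Tensor:
    # Allocated on the first call; with maxsize=1 only the most recently used
    # (num_bytes, device) combination stays cached and is returned on repeat calls.
    return torch.empty((num_bytes,), dtype=torch.uint8, device=device)


buf1 = get_scratch_buffer(8 * 1024 * 1024, torch.device("cpu"))
buf2 = get_scratch_buffer(8 * 1024 * 1024, torch.device("cpu"))
assert buf1 is buf2  # the cached allocation is reused, not re-created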