
Commit 41c518f
Renamed add_input to add_inputs in punica_cpu.py.
Signed-off-by: Oleg Mosalov <[email protected]>
mosalov committed Dec 17, 2024
1 parent 2840445 commit 41c518f
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions vllm/lora/punica_wrapper/punica_cpu.py
@@ -54,7 +54,7 @@ def _expand_prefill(
         y: torch.Tensor,
         x: torch.Tensor,
         w_t_all: torch.Tensor,
-        add_input: bool,
+        add_inputs: bool,
     ):
         #No LoRA request, so return directly
         if self.no_lora:
@@ -64,17 +64,17 @@ def _expand_prefill(
             w_t_all,
             y,
             *self.prefill_metadata,
-            add_input,
+            add_inputs,
         )

     def _expand_decode(
         self,
         y: torch.Tensor,
         x: torch.Tensor,
         w_t_all: torch.Tensor,
-        add_input: bool,
+        add_inputs: bool,
     ):
-        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)
+        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)

     def _expand_slice_prefill(
         self,
@@ -83,7 +83,7 @@ def _expand_slice_prefill(
         w_t_all: torch.Tensor,
         y_offset: int,
         y_slice_size: int,
-        add_input: bool,
+        add_inputs: bool,
     ):
         #No LoRA request, so return directly
         if self.no_lora:
@@ -95,7 +95,7 @@ def _expand_slice_prefill(
             *self.prefill_metadata,
             y_offset,
             y_slice_size,
-            add_input,
+            add_inputs,
         )

     def _expand_slice_decode(
@@ -105,10 +105,10 @@ def _expand_slice_decode(
         w_t_all: torch.Tensor,
         y_offset: int,
         y_slice_size: int,
-        add_input: bool,
+        add_inputs: bool,
     ):
         bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
-                          y_slice_size, add_input)
+                          y_slice_size, add_inputs)

     def _apply_expand(
         self,
@@ -117,7 +117,7 @@ def _apply_expand(
         w_t_all: torch.Tensor,
         y_offset: int,
         y_slice_size: int,
-        add_input: bool = True,
+        add_inputs: bool = True,
     ):
         """
         Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
@@ -128,7 +128,7 @@ def _apply_expand(
         expand_slice_fun: Callable = (self._expand_slice_prefill
                                       if self.is_prefill else
                                       self._expand_slice_decode)
-        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)
+        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)

     def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
                       w_t_all: torch.Tensor, scale: float):
@@ -181,7 +181,7 @@ def add_expand(self,
                    lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                    output_slices: Tuple[int, ...],
                    offset_start: int = 0,
-                   add_input=True,
+                   add_inputs=True,
                    **kwargs) -> None:
         """
         Performs GEMM and bias addition for multiple slices of lora_b.
@@ -200,7 +200,7 @@ def add_expand(self,
             lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
                 bias's weight
             output_slices (Tuple[int, ...]): Every slice's size
-            add_input (bool): Defaults to True.
+            add_inputs (bool): Defaults to True.
         """
         y_org = y
         y = y.view(-1, y.shape[-1])
@@ -215,7 +215,7 @@ def add_expand(self,
                     lora_b_stacked[slice_idx],
                     offset_left,
                     output_slices[slice_idx],
-                    add_input=add_input,
+                    add_inputs=add_inputs,
                 )
                 offset_left += output_slices[slice_idx]
         y = y.view_as(y_org)
@@ -224,7 +224,7 @@ def add_lora_embedding(self,
                            y: torch.Tensor,
                            x: torch.Tensor,
                            lora_b_stacked: torch.Tensor,
-                           add_input: bool = True,
+                           add_inputs: bool = True,
                            **kwargs) -> None:
         """
         Applies lora specifically for VocabParallelEmbeddingWithLoRA.
@@ -236,13 +236,13 @@ def add_lora_embedding(self,
             y (torch.Tensor): Output tensor.
             x (torch.Tensor): Input tensor.
             lora_b_stacked (torch.Tensor): lora_b's weights.
-            add_input (bool): Default to True.
+            add_inputs (bool): Default to True.
         """

         # Embedding layer only need expand op
         expand_fun: Callable = (self._expand_prefill
                                 if self.is_prefill else self._expand_decode)
-        expand_fun(y, x, lora_b_stacked, add_input)
+        expand_fun(y, x, lora_b_stacked, add_inputs)

     def add_lora_linear(self,
                         y: torch.Tensor,
@@ -298,7 +298,7 @@ def add_lora_linear(self,
                         lora_b_stacked,
                         None,
                         output_slices,
-                        add_input=True,
+                        add_inputs=True,
                         **kwargs)

     def add_lora_logits(self,
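The commit changes only the flag's name, not its behavior. As the `_apply_expand` docstring above notes, the expand op computes `y[:, y_offset:y_offset+y_slice_size] += x @ w_t_all`. A minimal dense sketch of that semantics, assuming `add_inputs` selects between accumulating into `y` and overwriting the slice; this helper is illustrative, not vLLM code, and the real bgmv/sgmv kernels gather per-token LoRA weights through index tensors such as `token_lora_indices`:

import torch

def expand_slice_reference(y: torch.Tensor,
                           x: torch.Tensor,
                           w_t_all: torch.Tensor,
                           y_offset: int,
                           y_slice_size: int,
                           add_inputs: bool = True) -> None:
    # Dense stand-in for bgmv_expand_slice/sgmv_expand_slice: project x
    # through one LoRA B matrix and write the result into a slice of y.
    out = x @ w_t_all  # x: (num_tokens, rank), w_t_all: (rank, y_slice_size)
    if add_inputs:
        y[:, y_offset:y_offset + y_slice_size] += out  # accumulate into y
    else:
        y[:, y_offset:y_offset + y_slice_size] = out   # overwrite the slice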

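Because `add_input` was a keyword argument in the public `add_expand` and `add_lora_embedding` signatures, any call site passing it by name needs the same rename. A hypothetical caller (the wrapper variable name is illustrative):

# Before this commit:
#     punica_wrapper.add_lora_embedding(y, x, lora_b_stacked, add_input=True)
# After this commit:
punica_wrapper.add_lora_embedding(y, x, lora_b_stacked, add_inputs=True)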