
Commit

Check for key in lora weights
tgaddair committed Nov 5, 2024
1 parent 848b4c7 commit 107be9a
Showing 1 changed file with 3 additions and 1 deletion.
server/lorax_server/utils/layers.py
@@ -77,8 +77,10 @@ def forward_layer_type(
         can_vectorize = data is not None and data.can_vectorize(self.process_group)
 
         # Triton Punica kernels
+        key = (layer_type, self.layer_id)
         if (
             adapter_data.punica_wrapper is not None and adapter_data.punica_wrapper.enabled
+            and key in adapter_data.layer_to_lora_weights
             and input.shape[0] <= adapter_data.punica_wrapper.max_batch_size
             and can_vectorize
         ):
@@ -89,7 +91,7 @@ def forward_layer_type(
             y_offset = None
             y_slice_size = None
 
-            lora_a_weights, lora_b_weights = adapter_data.layer_to_lora_weights[(layer_type, self.layer_id)]
+            lora_a_weights, lora_b_weights = adapter_data.layer_to_lora_weights[key]
             adapter_data.punica_wrapper.add_lora(
                 result,
                 input,
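For context, a minimal sketch of the guard this commit introduces. The name layer_to_lora_weights and the (layer_type, layer_id) key come from the diff above; everything else is a simplified stand-in, not LoRAX's real API. Checking membership before taking the vectorized Punica path means a layer with no LoRA weights falls through to the default path instead of raising a KeyError on the lookup.

    # Simplified stand-in for adapter_data.layer_to_lora_weights: maps
    # (layer_type, layer_id) to that layer's LoRA A/B weights. Here an
    # adapter targets q_proj but not k_proj, so k_proj has no entry.
    layer_to_lora_weights = {
        ("q_proj", 0): ("lora_a_q0", "lora_b_q0"),
    }

    def forward_layer_type(layer_type: str, layer_id: int) -> str:
        key = (layer_type, layer_id)
        # Before this commit, reaching the Punica branch for a layer with
        # no weights would raise KeyError on the dict lookup. Checking
        # membership first routes such layers to the fallback path.
        if key in layer_to_lora_weights:
            lora_a, lora_b = layer_to_lora_weights[key]
            return f"punica path with {lora_a}, {lora_b}"
        return "fallback (non-vectorized) path"

    print(forward_layer_type("q_proj", 0))  # punica path
    print(forward_layer_type("k_proj", 0))  # fallback path, no KeyError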
