From 107be9a9aa0da505c5854c6dbae18e0571a10488 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Tue, 5 Nov 2024 10:26:35 -0800
Subject: [PATCH] Check for key in lora weights

---
 server/lorax_server/utils/layers.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py
index d25be2128..0feaae609 100644
--- a/server/lorax_server/utils/layers.py
+++ b/server/lorax_server/utils/layers.py
@@ -77,8 +77,10 @@ def forward_layer_type(
         can_vectorize = data is not None and data.can_vectorize(self.process_group)
         # Triton Punica kernels
+        key = (layer_type, self.layer_id)
         if (
             adapter_data.punica_wrapper is not None
             and adapter_data.punica_wrapper.enabled
+            and key in adapter_data.layer_to_lora_weights
             and input.shape[0] <= adapter_data.punica_wrapper.max_batch_size
             and can_vectorize
         ):
@@ -89,7 +91,7 @@ def forward_layer_type(
             y_offset = None
             y_slice_size = None

-            lora_a_weights, lora_b_weights = adapter_data.layer_to_lora_weights[(layer_type, self.layer_id)]
+            lora_a_weights, lora_b_weights = adapter_data.layer_to_lora_weights[key]
             adapter_data.punica_wrapper.add_lora(
                 result,
                 input,
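
Below is a minimal, self-contained Python sketch of the guard pattern this patch adds, not the LoRAX implementation itself. The toy dictionary, the placeholder weight strings, and the standalone forward_layer_type function are hypothetical stand-ins; in LoRAX, layer_to_lora_weights lives on adapter_data and maps (layer_type, layer_id) tuples to the LoRA A/B weight tensors consumed by the Punica add_lora kernel.

    # Sketch of the guard (hypothetical stand-ins, not LoRAX code).
    layer_to_lora_weights = {
        ("q_proj", 0): ("lora_a", "lora_b"),  # toy placeholder weights
    }

    def forward_layer_type(layer_type, layer_id):
        key = (layer_type, layer_id)
        if key in layer_to_lora_weights:  # the membership check this patch adds
            lora_a, lora_b = layer_to_lora_weights[key]  # safe: key is present
            return f"punica fast path ({lora_a}, {lora_b})"
        # Layers with no vectorized LoRA weights fall through to the
        # slower non-Punica path instead of raising KeyError.
        return "fallback path"

    print(forward_layer_type("q_proj", 0))  # punica fast path
    print(forward_layer_type("k_proj", 1))  # fallback path

Computing key once keeps the membership test and the later lookup in agreement, so a layer whose weights were never registered takes the fallback path rather than raising KeyError inside the fast path.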