From 41db5b00775026aa672e86d4faca4c73df06066a Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Tue, 12 Mar 2024 22:22:51 -0700
Subject: [PATCH] Fixed

---
 server/lorax_server/utils/layers.py | 25 +++++++++++++++++++++++--
 server/lorax_server/utils/lora.py   |  6 ++++--
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py
index e3064d588..d560aff23 100644
--- a/server/lorax_server/utils/layers.py
+++ b/server/lorax_server/utils/layers.py
@@ -507,7 +507,6 @@ def __init__(self, base_layer, layer_id, process_group):
         self.base_layer = base_layer
         self.layer_id = layer_id
         self.process_group = process_group
-        self.use_sgmv = False
 
     def forward_layer_type(
         self,
@@ -521,7 +520,7 @@ def forward_layer_type(
         data = adapter_data.data.get(layer_type)
 
         if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
-            if not self.use_sgmv:
+            if self.process_group.rank() == 0 and self.layer_id == 0:
                 print("!!! USE SGMV")
             if end_idx - start_idx != result.shape[1]:
                 proj = torch.zeros_like(result[:, start_idx:end_idx])
@@ -541,9 +540,14 @@ def forward_layer_type(
                         self.layer_id,
                         r,
                     )
+                    if self.process_group.rank() == 0 and self.layer_id == 0:
+                        print("V", v.shape, v.norm().item())
 
                     if self.process_group.size() > 1:
                         v = self.collect_lora_a(v)
+
+                    if self.process_group.rank() == 0 and self.layer_id == 0:
+                        print("V collect", v.shape, v.norm().item())
 
                     lora_b_sgmv_cutlass(
                         proj,
@@ -554,6 +558,9 @@ def forward_layer_type(
                         rank_segments.segment_ends,
                         self.layer_id,
                     )
+
+                    if self.process_group.rank() == 0 and self.layer_id == 0:
+                        print("proj", proj.shape, proj.norm().item())
 
             if end_idx - start_idx != result.shape[1]:
                 result[:, start_idx:end_idx] += proj
@@ -577,12 +584,26 @@ def forward_lora(
         lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
 
         lora_a = orient_for_rank(lora_a, lora_b.size(0))
+
+        if self.process_group.rank() == 0 and self.layer_id == 0:
+            print("lora_a", lora_a.shape, lora_a.norm().item())
+            print("lora_b", lora_b.shape, lora_b.norm().item())
+
         a_out = input @ lora_a
+        if self.process_group.rank() == 0 and self.layer_id == 0:
+            print("V", a_out.shape, a_out.norm().item())
 
         if self.process_group.size() > 1:
             a_out = self.collect_lora_a(a_out)
+            if self.process_group.rank() == 0 and self.layer_id == 0:
+                print("V collect", a_out.shape, a_out.norm().item())
+
 
         result = (a_out @ lora_b) * adapter_mask
+
+        if self.process_group.rank() == 0 and self.layer_id == 0:
+            print("proj", result.shape, result.norm().item())
+
         return result
 
     def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
diff --git a/server/lorax_server/utils/lora.py b/server/lorax_server/utils/lora.py
index 825f28138..7511b384f 100644
--- a/server/lorax_server/utils/lora.py
+++ b/server/lorax_server/utils/lora.py
@@ -98,6 +98,9 @@ def __init__(
         weights_b: List[torch.Tensor],
         adapter_config: LoraConfig,
     ):
+        self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
+        self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
+
         # [num_layers, hidden_size, r]
         weights_a = [
             orient_for_rank(w, w.size(1)).contiguous()
@@ -187,8 +190,7 @@ def get_data(self, meta: AdapterBatchMetadata) -> AdapterWeightData:
         for segment_idx, adapter_idx in enumerate(segment_indices):
             if adapter_idx not in self.lora_weights:
                 continue
-            print("rank indices for segment", self.lora_weights[adapter_idx].weights_b.shape)
-            rank_indices[self.lora_weights[adapter_idx].weights_b.size(1)].append(segment_idx)
+            rank_indices[self.lora_weights[adapter_idx].lora_a_r].append(segment_idx)
 
         rank_data = {}
         for rank, indices in rank_indices.items():
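A minimal, self-contained sketch (not part of the patch itself) of the rank bucketing that get_data() switches to above: segment positions are grouped by the LoRA rank recorded from the original lora_a layout before the weight tensors are reoriented, instead of by a dimension of the already-reoriented weights_b. The LoraWeightsStub class, bucket_segments_by_rank function, and the shapes in the example are hypothetical stand-ins, not the actual lorax_server classes.

# Hypothetical sketch of bucketing adapter segments by LoRA rank; names are
# illustrative stand-ins for the real lorax_server objects.
from collections import defaultdict
from typing import Dict, List

import torch


class LoraWeightsStub:
    """Stand-in per-adapter weights object; records the rank before reorientation."""

    def __init__(self, weights_a: List[torch.Tensor], weights_b: List[torch.Tensor]):
        # Capture the rank from the original [hidden_size, r] lora_a layout, since
        # a later transpose/reorientation makes size(1) of weights_b unreliable.
        self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
        self.lora_b_r = weights_b[0].size(0) if len(weights_b) > 0 else 1


def bucket_segments_by_rank(
    segment_indices: List[int],
    lora_weights: Dict[int, LoraWeightsStub],
) -> Dict[int, List[int]]:
    """Group segment positions by each adapter's LoRA rank, as the patched get_data() does."""
    rank_indices: Dict[int, List[int]] = defaultdict(list)
    for segment_idx, adapter_idx in enumerate(segment_indices):
        if adapter_idx not in lora_weights:
            continue
        rank_indices[lora_weights[adapter_idx].lora_a_r].append(segment_idx)
    return dict(rank_indices)


if __name__ == "__main__":
    # Two adapters with ranks 8 and 16; hidden size 32 is arbitrary.
    adapters = {
        0: LoraWeightsStub([torch.zeros(32, 8)], [torch.zeros(8, 32)]),
        1: LoraWeightsStub([torch.zeros(32, 16)], [torch.zeros(16, 32)]),
    }
    print(bucket_segments_by_rank([0, 1, 0, 1], adapters))  # {8: [0, 2], 16: [1, 3]}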